package main import ( "fmt" "git.kingecg.top/kingecg/gotensor" ) func main() { fmt.Println("=== 卷积神经网络示例 ===") // 创建一个简单的4x4灰度图像(批大小为1) imageData := []float64{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, } // 形状为 [batch, channel, height, width] imageShape := []int{1, 1, 4, 4} imageTensor, err := gotensor.NewTensor(imageData, imageShape) if err != nil { panic(err) } fmt.Printf("输入图像:\n%s\n", imageTensor.String()) // 创建一个简单的卷积核 (3x3) kernelData := []float64{ 1.0, 0.0, -1.0, 1.0, 0.0, -1.0, 1.0, 0.0, -1.0, } // 形状为 [input_channels, output_channels, kernel_height, kernel_width] kernelShape := []int{1, 1, 3, 3} kernelTensor, err := gotensor.NewTensor(kernelData, kernelShape) if err != nil { panic(err) } fmt.Printf("卷积核:\n%s\n", kernelTensor.String()) // 执行卷积操作 (stride=1, padding=0) convResult, err := imageTensor.Conv2D(kernelTensor, 1, 0) if err != nil { panic(err) } fmt.Printf("卷积结果:\n%s\n", convResult.String()) // 应用ReLU激活函数 reluResult := convResult.ReLU() fmt.Printf("ReLU结果:\n%s\n", reluResult.String()) // 执行最大池化 (kernel_size=2, stride=2) poolResult, err := reluResult.MaxPool2D(2, 2) if err != nil { panic(err) } fmt.Printf("池化结果:\n%s\n", poolResult.String()) // 展平操作 flattened := poolResult.Flatten() fmt.Printf("展平后大小: %d\n", flattened.Size()) // 创建一些随机权重进行全连接层操作 weightsData := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6} weightsShape := []int{flattened.Size(), 2} // 输出2类(猫/狗) weights, err := gotensor.NewTensor(weightsData, weightsShape) if err != nil { panic(err) } // 全连接层计算 (矩阵乘法) fcResult, err := flattened.MatMul(weights) if err != nil { panic(err) } fmt.Printf("全连接输出:\n%s\n", fcResult.String()) // 应用Softmax softmaxResult := fcResult.Softmax() fmt.Printf("Softmax输出 (概率分布):\n%s\n", softmaxResult.String()) // 创建目标标签 (猫=1, 狗=0) targetData := []float64{1.0, 0.0} target, err := gotensor.NewTensor(targetData, []int{1, 2}) if err != nil { panic(err) } // 计算交叉熵损失 loss := softmaxResult.CrossEntropy(target) lossVal, _ := loss.Data.Get(0) fmt.Printf("交叉熵损失: %f\n", lossVal) fmt.Println("=== 模型前向传播完成 ===") }