// Package main demonstrates a single forward pass through a tiny
// convolutional network built on the gotensor library:
// conv -> ReLU -> max-pool -> flatten -> fully-connected -> softmax -> loss.
package main

import (
	"fmt"

	"git.kingecg.top/kingecg/gotensor"
)

// main runs the demo end to end, panicking on any tensor-library error
// (acceptable for a standalone example program).
func main() {
	fmt.Println("=== 卷积神经网络示例 ===")

	// A single 4x4 grayscale image (batch size 1), laid out as
	// [batch, channel, height, width].
	pixels := []float64{
		1.0, 2.0, 3.0, 4.0,
		5.0, 6.0, 7.0, 8.0,
		9.0, 10.0, 11.0, 12.0,
		13.0, 14.0, 15.0, 16.0,
	}
	image, err := gotensor.NewTensor(pixels, []int{1, 1, 4, 4})
	if err != nil {
		panic(err)
	}
	fmt.Printf("输入图像:\n%s\n", image.String())

	// A simple 3x3 kernel, shaped
	// [input_channels, output_channels, kernel_height, kernel_width].
	kernelVals := []float64{
		1.0, 0.0, -1.0,
		1.0, 0.0, -1.0,
		1.0, 0.0, -1.0,
	}
	kernel, err := gotensor.NewTensor(kernelVals, []int{1, 1, 3, 3})
	if err != nil {
		panic(err)
	}
	fmt.Printf("卷积核:\n%s\n", kernel.String())

	// Convolution with stride 1 and no padding.
	conv, err := image.Conv2D(kernel, 1, 0)
	if err != nil {
		panic(err)
	}
	fmt.Printf("卷积结果:\n%s\n", conv.String())

	// ReLU activation.
	activated := conv.ReLU()
	fmt.Printf("ReLU结果:\n%s\n", activated.String())

	// Max pooling with a 2x2 window and stride 2.
	pooled, err := activated.MaxPool2D(2, 2)
	if err != nil {
		panic(err)
	}
	fmt.Printf("池化结果:\n%s\n", pooled.String())

	// Flatten ahead of the fully-connected layer. After 2x2 pooling the
	// feature map is 2x2, so this should yield 4 elements.
	flat := pooled.Flatten()
	fmt.Printf("展平后大小: %d\n", flat.Size())

	// Build an N x 2 weight matrix (two output classes: cat/dog) filled
	// with deterministic increasing values 0.1, 0.2, ...
	n := flat.Size()
	wVals := make([]float64, n*2)
	for i := range wVals {
		wVals[i] = 0.1 * float64(i+1)
	}
	weights, err := gotensor.NewTensor(wVals, []int{n, 2})
	if err != nil {
		panic(err)
	}

	// Copy the flattened values into a fresh 1xN tensor so MatMul sees a
	// proper row vector.
	row := make([]float64, n)
	for i := 0; i < n; i++ {
		row[i], _ = flat.Data.Get(i)
	}
	input, err := gotensor.NewTensor(row, []int{1, n})
	if err != nil {
		panic(err)
	}

	// Fully-connected layer as a matrix product: (1, N) * (N, 2) = (1, 2).
	logits, err := input.MatMul(weights)
	if err != nil {
		panic(err)
	}
	fmt.Printf("全连接输出:\n%s\n", logits.String())

	// Softmax turns the logits into a probability distribution.
	probs := logits.Softmax()
	fmt.Printf("Softmax输出 (概率分布):\n%s\n", probs.String())

	// One-hot target label (cat=1, dog=0 per the original example).
	target, err := gotensor.NewTensor([]float64{1.0, 0.0}, []int{1, 2})
	if err != nil {
		panic(err)
	}

	// Cross-entropy loss between the predicted distribution and the target.
	loss := probs.CrossEntropy(target)
	lossVal, _ := loss.Data.Get(0)
	fmt.Printf("交叉熵损失: %f\n", lossVal)

	fmt.Println("=== 模型前向传播完成 ===")
}