From fd232d65fa42219e2bee0491396bd52004975b4b Mon Sep 17 00:00:00 2001 From: kingecg Date: Tue, 30 Dec 2025 23:23:51 +0800 Subject: [PATCH] =?UTF-8?q?```=20feat(examples):=20=E6=9B=B4=E6=96=B0CNN?= =?UTF-8?q?=E7=A4=BA=E4=BE=8B=E4=B8=AD=E7=9A=84=E5=85=A8=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E5=B1=82=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复权重矩阵大小计算,根据展平后的实际大小动态创建权重 - 添加张量重塑逻辑,将展平后的张量转换为正确的2D格式进行矩阵乘法 - 使用动态生成的权重值替代固定的权重数组 - 确保矩阵乘法维度匹配:(1, N) * (N, 2) = (1, 2) ``` --- examples/cnn_example.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/cnn_example.go b/examples/cnn_example.go index f2f08e6..30a06f7 100644 --- a/examples/cnn_example.go +++ b/examples/cnn_example.go @@ -65,15 +65,30 @@ func main() { fmt.Printf("展平后大小: %d\n", flattened.Size()) // 创建一些随机权重进行全连接层操作 - weightsData := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6} - weightsShape := []int{flattened.Size(), 2} // 输出2类(猫/狗) + // 由于池化后是2x2,展平后应该是4个元素,所以我们需要4x2的权重矩阵 + flattenedSize := flattened.Size() + weightsData := make([]float64, flattenedSize * 2) // flattenedSize*2个权重值 + for i := range weightsData { + weightsData[i] = 0.1 * float64(i+1) // 填充一些递增的值 + } + weightsShape := []int{flattenedSize, 2} // 输出2类(猫/狗) weights, err := gotensor.NewTensor(weightsData, weightsShape) if err != nil { panic(err) } - // 全连接层计算 (矩阵乘法) - fcResult, err := flattened.MatMul(weights) + // 重塑flattened张量为2D格式以进行矩阵乘法 + reshapedFlattenedData := make([]float64, flattenedSize) + for i := 0; i < flattenedSize; i++ { + reshapedFlattenedData[i], _ = flattened.Data.Get(i) + } + reshapedFlattened, err := gotensor.NewTensor(reshapedFlattenedData, []int{1, flattenedSize}) // 作为1xN的矩阵 + if err != nil { + panic(err) + } + + // 全连接层计算 (矩阵乘法) - 现在是 (1, N) * (N, 2) = (1, 2) + fcResult, err := reshapedFlattened.MatMul(weights) if err != nil { panic(err) }