feat(examples): 更新CNN示例中的全连接层实现

- 修复权重矩阵大小计算,根据展平后的实际大小动态创建权重
- 添加张量重塑逻辑,将展平后的张量转换为正确的2D格式进行矩阵乘法
- 使用动态生成的权重值替代固定的权重数组
- 确保矩阵乘法维度匹配:(1, N) * (N, 2) = (1, 2)
```
This commit is contained in:
kingecg 2025-12-30 23:23:51 +08:00
parent 3536fdf8cf
commit fd232d65fa
1 changed files with 19 additions and 4 deletions

View File

@ -65,15 +65,30 @@ func main() {
fmt.Printf("展平后大小: %d\n", flattened.Size())
// 创建一些随机权重进行全连接层操作
weightsData := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6}
weightsShape := []int{flattened.Size(), 2} // 输出2类猫/狗)
// 由于池化后是2x2展平后应该是4个元素所以我们需要4x2的权重矩阵
flattenedSize := flattened.Size()
weightsData := make([]float64, flattenedSize * 2) // flattenedSize*2个权重值
for i := range weightsData {
weightsData[i] = 0.1 * float64(i+1) // 填充一些递增的值
}
weightsShape := []int{flattenedSize, 2} // 输出2类猫/狗)
weights, err := gotensor.NewTensor(weightsData, weightsShape)
if err != nil {
panic(err)
}
// 全连接层计算 (矩阵乘法)
fcResult, err := flattened.MatMul(weights)
// 重塑flattened张量为2D格式以进行矩阵乘法
reshapedFlattenedData := make([]float64, flattenedSize)
for i := 0; i < flattenedSize; i++ {
reshapedFlattenedData[i], _ = flattened.Data.Get(i)
}
reshapedFlattened, err := gotensor.NewTensor(reshapedFlattenedData, []int{1, flattenedSize}) // 作为1xN的矩阵
if err != nil {
panic(err)
}
// 全连接层计算 (矩阵乘法) - 现在是 (1, N) * (N, 2) = (1, 2)
fcResult, err := reshapedFlattened.MatMul(weights)
if err != nil {
panic(err)
}