```
feat(examples): 更新CNN示例中的全连接层实现 - 修复权重矩阵大小计算,根据展平后的实际大小动态创建权重 - 添加张量重塑逻辑,将展平后的张量转换为正确的2D格式进行矩阵乘法 - 使用动态生成的权重值替代固定的权重数组 - 确保矩阵乘法维度匹配:(1, N) * (N, 2) = (1, 2) ```
This commit is contained in:
parent
3536fdf8cf
commit
fd232d65fa
|
|
@ -65,15 +65,30 @@ func main() {
|
|||
fmt.Printf("展平后大小: %d\n", flattened.Size())
|
||||
|
||||
// 创建一些随机权重进行全连接层操作
|
||||
weightsData := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6}
|
||||
weightsShape := []int{flattened.Size(), 2} // 输出2类(猫/狗)
|
||||
// 由于池化后是2x2,展平后应该是4个元素,所以我们需要4x2的权重矩阵
|
||||
flattenedSize := flattened.Size()
|
||||
weightsData := make([]float64, flattenedSize * 2) // flattenedSize*2个权重值
|
||||
for i := range weightsData {
|
||||
weightsData[i] = 0.1 * float64(i+1) // 填充一些递增的值
|
||||
}
|
||||
weightsShape := []int{flattenedSize, 2} // 输出2类(猫/狗)
|
||||
weights, err := gotensor.NewTensor(weightsData, weightsShape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 全连接层计算 (矩阵乘法)
|
||||
fcResult, err := flattened.MatMul(weights)
|
||||
// 重塑flattened张量为2D格式以进行矩阵乘法
|
||||
reshapedFlattenedData := make([]float64, flattenedSize)
|
||||
for i := 0; i < flattenedSize; i++ {
|
||||
reshapedFlattenedData[i], _ = flattened.Data.Get(i)
|
||||
}
|
||||
reshapedFlattened, err := gotensor.NewTensor(reshapedFlattenedData, []int{1, flattenedSize}) // 作为1xN的矩阵
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 全连接层计算 (矩阵乘法) - 现在是 (1, N) * (N, 2) = (1, 2)
|
||||
fcResult, err := reshapedFlattened.MatMul(weights)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue