gotensor/examples/cnn_example.go

100 lines
2.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"fmt"
"git.kingecg.top/kingecg/gotensor"
)
func main() {
fmt.Println("=== 卷积神经网络示例 ===")
// 创建一个简单的4x4灰度图像批大小为1
imageData := []float64{
1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0,
}
// 形状为 [batch, channel, height, width]
imageShape := []int{1, 1, 4, 4}
imageTensor, err := gotensor.NewTensor(imageData, imageShape)
if err != nil {
panic(err)
}
fmt.Printf("输入图像:\n%s\n", imageTensor.String())
// 创建一个简单的卷积核 (3x3)
kernelData := []float64{
1.0, 0.0, -1.0,
1.0, 0.0, -1.0,
1.0, 0.0, -1.0,
}
// 形状为 [input_channels, output_channels, kernel_height, kernel_width]
kernelShape := []int{1, 1, 3, 3}
kernelTensor, err := gotensor.NewTensor(kernelData, kernelShape)
if err != nil {
panic(err)
}
fmt.Printf("卷积核:\n%s\n", kernelTensor.String())
// 执行卷积操作 (stride=1, padding=0)
convResult, err := imageTensor.Conv2D(kernelTensor, 1, 0)
if err != nil {
panic(err)
}
fmt.Printf("卷积结果:\n%s\n", convResult.String())
// 应用ReLU激活函数
reluResult := convResult.ReLU()
fmt.Printf("ReLU结果:\n%s\n", reluResult.String())
// 执行最大池化 (kernel_size=2, stride=2)
poolResult, err := reluResult.MaxPool2D(2, 2)
if err != nil {
panic(err)
}
fmt.Printf("池化结果:\n%s\n", poolResult.String())
// 展平操作
flattened := poolResult.Flatten()
fmt.Printf("展平后大小: %d\n", flattened.Size())
// 创建一些随机权重进行全连接层操作
weightsData := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6}
weightsShape := []int{flattened.Size(), 2} // 输出2类猫/狗)
weights, err := gotensor.NewTensor(weightsData, weightsShape)
if err != nil {
panic(err)
}
// 全连接层计算 (矩阵乘法)
fcResult, err := flattened.MatMul(weights)
if err != nil {
panic(err)
}
fmt.Printf("全连接输出:\n%s\n", fcResult.String())
// 应用Softmax
softmaxResult := fcResult.Softmax()
fmt.Printf("Softmax输出 (概率分布):\n%s\n", softmaxResult.String())
// 创建目标标签 (猫=1, 狗=0)
targetData := []float64{1.0, 0.0}
target, err := gotensor.NewTensor(targetData, []int{1, 2})
if err != nil {
panic(err)
}
// 计算交叉熵损失
loss := softmaxResult.CrossEntropy(target)
lossVal, _ := loss.Data.Get(0)
fmt.Printf("交叉熵损失: %f\n", lossVal)
fmt.Println("=== 模型前向传播完成 ===")
}