100 lines
2.5 KiB
Go
100 lines
2.5 KiB
Go
package main
|
||
|
||
import (
|
||
"fmt"
|
||
"git.kingecg.top/kingecg/gotensor"
|
||
)
|
||
|
||
func main() {
|
||
fmt.Println("=== 卷积神经网络示例 ===")
|
||
|
||
// 创建一个简单的4x4灰度图像(批大小为1)
|
||
imageData := []float64{
|
||
1.0, 2.0, 3.0, 4.0,
|
||
5.0, 6.0, 7.0, 8.0,
|
||
9.0, 10.0, 11.0, 12.0,
|
||
13.0, 14.0, 15.0, 16.0,
|
||
}
|
||
|
||
// 形状为 [batch, channel, height, width]
|
||
imageShape := []int{1, 1, 4, 4}
|
||
imageTensor, err := gotensor.NewTensor(imageData, imageShape)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
fmt.Printf("输入图像:\n%s\n", imageTensor.String())
|
||
|
||
// 创建一个简单的卷积核 (3x3)
|
||
kernelData := []float64{
|
||
1.0, 0.0, -1.0,
|
||
1.0, 0.0, -1.0,
|
||
1.0, 0.0, -1.0,
|
||
}
|
||
// 形状为 [input_channels, output_channels, kernel_height, kernel_width]
|
||
kernelShape := []int{1, 1, 3, 3}
|
||
kernelTensor, err := gotensor.NewTensor(kernelData, kernelShape)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
fmt.Printf("卷积核:\n%s\n", kernelTensor.String())
|
||
|
||
// 执行卷积操作 (stride=1, padding=0)
|
||
convResult, err := imageTensor.Conv2D(kernelTensor, 1, 0)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
fmt.Printf("卷积结果:\n%s\n", convResult.String())
|
||
|
||
// 应用ReLU激活函数
|
||
reluResult := convResult.ReLU()
|
||
fmt.Printf("ReLU结果:\n%s\n", reluResult.String())
|
||
|
||
// 执行最大池化 (kernel_size=2, stride=2)
|
||
poolResult, err := reluResult.MaxPool2D(2, 2)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
fmt.Printf("池化结果:\n%s\n", poolResult.String())
|
||
|
||
// 展平操作
|
||
flattened := poolResult.Flatten()
|
||
fmt.Printf("展平后大小: %d\n", flattened.Size())
|
||
|
||
// 创建一些随机权重进行全连接层操作
|
||
weightsData := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6}
|
||
weightsShape := []int{flattened.Size(), 2} // 输出2类(猫/狗)
|
||
weights, err := gotensor.NewTensor(weightsData, weightsShape)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
// 全连接层计算 (矩阵乘法)
|
||
fcResult, err := flattened.MatMul(weights)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
fmt.Printf("全连接输出:\n%s\n", fcResult.String())
|
||
|
||
// 应用Softmax
|
||
softmaxResult := fcResult.Softmax()
|
||
fmt.Printf("Softmax输出 (概率分布):\n%s\n", softmaxResult.String())
|
||
|
||
// 创建目标标签 (猫=1, 狗=0)
|
||
targetData := []float64{1.0, 0.0}
|
||
target, err := gotensor.NewTensor(targetData, []int{1, 2})
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
// 计算交叉熵损失
|
||
loss := softmaxResult.CrossEntropy(target)
|
||
lossVal, _ := loss.Data.Get(0)
|
||
fmt.Printf("交叉熵损失: %f\n", lossVal)
|
||
|
||
fmt.Println("=== 模型前向传播完成 ===")
|
||
} |