// File: gotensor/examples/simple_model_example.go (185 lines, 4.6 KiB, Go)
//
// Note from the code-hosting viewer: this file contains Unicode characters
// that might be confused with other characters. If this is intentional, the
// warning can be safely ignored.

package main
import (
"fmt"
"git.kingecg.top/kingecg/gomatrix"
"git.kingecg.top/kingecg/gotensor"
)
// LinearLayer is a minimal linear (fully connected) layer made of a weight
// tensor and a bias tensor.
type LinearLayer struct {
	// Weight is the layer's weight tensor.
	Weight *gotensor.Tensor
	// Bias is the tensor added to the product of the inputs and the weights.
	Bias *gotensor.Tensor
}
// NewLinearLayer creates a linear layer whose weight tensor has shape
// [outputSize, inputSize] and whose bias tensor has shape [outputSize].
//
// The weight data is filled deterministically by cycling through the seed
// values 0.5, 0.1, 0.2, 0.4 in the order of the flat data slice, so a 2x2
// layer receives exactly those four values. The bias starts at zero.
//
// The constructor has no error return, so it panics when a size is not
// positive or when gotensor rejects the tensor data instead of handing back
// a layer with nil tensors.
func NewLinearLayer(inputSize, outputSize int) *LinearLayer {
	if inputSize <= 0 || outputSize <= 0 {
		panic(fmt.Sprintf("NewLinearLayer: sizes must be positive, got inputSize=%d outputSize=%d",
			inputSize, outputSize))
	}

	seed := [...]float64{0.5, 0.1, 0.2, 0.4}

	// Size the data from the requested dimensions rather than assuming 2x2.
	weightData := make([]float64, outputSize*inputSize)
	for i := range weightData {
		weightData[i] = seed[i%len(seed)]
	}

	weight, err := gotensor.NewTensor(weightData, []int{outputSize, inputSize})
	if err != nil {
		panic(fmt.Errorf("NewLinearLayer: creating weight tensor: %w", err))
	}

	// make zero-initializes the slice, which gives the all-zero bias.
	bias, err := gotensor.NewTensor(make([]float64, outputSize), []int{outputSize})
	if err != nil {
		panic(fmt.Errorf("NewLinearLayer: creating bias tensor: %w", err))
	}

	return &LinearLayer{
		Weight: weight,
		Bias:   bias,
	}
}
// Forward applies the linear transformation output = inputs * Weight^T + Bias.
//
// It returns the resulting tensor, or the first error reported while
// transposing the weight, allocating the gradient buffer, multiplying, or
// adding the bias. Each error is wrapped with the step that failed.
//
// NOTE(review): the transposed weight is wrapped in a freshly built Tensor;
// whether gradients computed for it ever reach l.Weight depends on how
// gotensor tracks its graph — confirm.
// NOTE(review): the gradient buffer is sized from the untransposed weight
// shape, which presumably only matches the transposed data for square
// weights — confirm.
func (l *LinearLayer) Forward(inputs *gotensor.Tensor) (*gotensor.Tensor, error) {
	// Transpose the weight so it can be right-multiplied by the inputs.
	weightT, err := l.Weight.Data.Transpose()
	if err != nil {
		return nil, fmt.Errorf("transposing weight: %w", err)
	}

	// Report allocation failures through the error return instead of panicking.
	grad, err := gomatrix.NewZeros(l.Weight.Shape())
	if err != nil {
		return nil, fmt.Errorf("allocating gradient buffer: %w", err)
	}

	// Wrap the transposed data in a tensor so MatMul can consume it.
	weightTTensor := &gotensor.Tensor{
		Data: weightT,
		Grad: grad,
	}

	product, err := inputs.MatMul(weightTTensor)
	if err != nil {
		return nil, fmt.Errorf("multiplying inputs by transposed weight: %w", err)
	}

	output, err := product.Add(l.Bias)
	if err != nil {
		return nil, fmt.Errorf("adding bias: %w", err)
	}
	return output, nil
}
// Parameters returns the layer's Weight and Bias tensors, in that order.
func (l *LinearLayer) Parameters() []*gotensor.Tensor {
	params := make([]*gotensor.Tensor, 0, 2)
	params = append(params, l.Weight, l.Bias)
	return params
}
// ZeroGrad calls ZeroGrad on the Weight tensor and then on the Bias tensor.
func (l *LinearLayer) ZeroGrad() {
	for _, param := range [...]*gotensor.Tensor{l.Weight, l.Bias} {
		param.ZeroGrad()
	}
}
// SimpleModel is a minimal model that wraps a single LinearLayer.
type SimpleModel struct {
	// Layer is the only layer of the model; every method delegates to it.
	Layer *LinearLayer
}
// Forward runs inputs through the model's single layer and returns that
// layer's output and error unchanged.
func (m *SimpleModel) Forward(inputs *gotensor.Tensor) (*gotensor.Tensor, error) {
	output, err := m.Layer.Forward(inputs)
	return output, err
}
// Parameters returns the parameters of the model's single layer.
func (m *SimpleModel) Parameters() []*gotensor.Tensor {
	layer := m.Layer
	return layer.Parameters()
}
// ZeroGrad delegates to the ZeroGrad method of the model's single layer.
func (m *SimpleModel) ZeroGrad() {
	layer := m.Layer
	layer.ZeroGrad()
}
// must unwraps a (pointer, error) pair: it returns t unchanged when err is
// nil and panics with err otherwise.
func must[T any](t *T, err error) *T {
	if err == nil {
		return t
	}
	panic(err)
}
// main trains a SimpleModel on a four-sample XOR-style dataset, prints the
// model's output for every training sample, then saves the model to a JSON
// file and loads it back into a freshly constructed model.
func main() {
	// modelPath is where the trained model is written and read back from.
	const modelPath = "/tmp/simple_model.json"
	const epochs = 100

	fmt.Println("Gotensor Simple Model Example")

	// Build the model: one linear layer with 2 inputs and 2 outputs.
	model := &SimpleModel{
		Layer: NewLinearLayer(2, 2),
	}

	// SGD optimizer over the model parameters; 0.01 is presumably the
	// learning rate — confirm against gotensor.NewSGD.
	optimizer := gotensor.NewSGD(model.Parameters(), 0.01)
	trainer := gotensor.NewTrainer(model, optimizer)

	// Training inputs: the four combinations of two binary values.
	trainInputs := []*gotensor.Tensor{
		must(gotensor.NewTensor([]float64{0, 0}, []int{2})),
		must(gotensor.NewTensor([]float64{0, 1}, []int{2})),
		must(gotensor.NewTensor([]float64{1, 0}, []int{2})),
		must(gotensor.NewTensor([]float64{1, 1}, []int{2})),
	}
	// Targets encode the XOR result with two components per sample:
	// [0, 1] stands for result 0 and [1, 0] stands for result 1.
	trainTargets := []*gotensor.Tensor{
		must(gotensor.NewTensor([]float64{0, 1}, []int{2})), // 0 XOR 0 = 0
		must(gotensor.NewTensor([]float64{1, 0}, []int{2})), // 0 XOR 1 = 1
		must(gotensor.NewTensor([]float64{1, 0}, []int{2})), // 1 XOR 0 = 1
		must(gotensor.NewTensor([]float64{0, 1}, []int{2})), // 1 XOR 1 = 0
	}

	// Mean squared error: the sum of the squared differences divided by the
	// number of output elements. The closure cannot return an error, so a
	// failing matrix operation panics with context instead of continuing
	// with an invalid intermediate value.
	lossFn := func(output, target *gotensor.Tensor) *gotensor.Tensor {
		diff, err := output.Data.Subtract(target.Data)
		if err != nil {
			panic(fmt.Errorf("mse: subtracting target from output: %w", err))
		}
		squared, err := diff.Multiply(diff)
		if err != nil {
			panic(fmt.Errorf("mse: squaring difference: %w", err))
		}
		mse := squared.Sum() / float64(output.Size())
		return must(gotensor.NewTensor([]float64{mse}, []int{1}))
	}

	fmt.Println("开始训练模型...")
	// The final argument is presumably a verbose/progress flag — confirm
	// against gotensor's Trainer.Train.
	if err := trainer.Train(trainInputs, trainTargets, epochs, lossFn, true); err != nil {
		fmt.Printf("训练过程中出现错误: %v\n", err)
		return
	}
	fmt.Println("训练完成!")

	// firstError returns the first non-nil error among errs, or nil.
	firstError := func(errs ...error) error {
		for _, e := range errs {
			if e != nil {
				return e
			}
		}
		return nil
	}

	// Evaluate the trained model on the training samples.
	fmt.Println("\n评估训练结果:")
	for i, input := range trainInputs {
		output, err := model.Forward(input)
		if err != nil {
			fmt.Printf("前向传播错误: %v\n", err)
			continue
		}

		target := trainTargets[i]
		inputVal0, errIn0 := input.Data.Get(0)
		inputVal1, errIn1 := input.Data.Get(1)
		outputVal0, errOut0 := output.Data.Get(0)
		outputVal1, errOut1 := output.Data.Get(1)
		targetVal0, errTgt0 := target.Data.Get(0)
		targetVal1, errTgt1 := target.Data.Get(1)
		// Skip the sample rather than print zero values for failed reads.
		if err := firstError(errIn0, errIn1, errOut0, errOut1, errTgt0, errTgt1); err != nil {
			fmt.Printf("读取张量元素错误: %v\n", err)
			continue
		}

		fmt.Printf("输入: [%.0f, %.0f] -> 输出: [%.3f, %.3f], 目标: [%.0f, %.0f]\n",
			inputVal0, inputVal1, outputVal0, outputVal1,
			targetVal0, targetVal1)
	}

	// Save the trained model.
	if err := gotensor.SaveModel(model, modelPath); err != nil {
		fmt.Printf("保存模型失败: %v\n", err)
	} else {
		fmt.Printf("模型已保存到 %s\n", modelPath)
	}

	// Load the saved model into a freshly constructed one.
	newModel := &SimpleModel{
		Layer: NewLinearLayer(2, 2),
	}
	if err := gotensor.LoadModel(newModel, modelPath); err != nil {
		fmt.Printf("加载模型失败: %v\n", err)
	} else {
		fmt.Printf("模型已从 %s 加载\n", modelPath)
	}
}