feat(examples): 更新simple_model_example.go以使用gomatrix包

- 导入gomatrix包替代部分gotensor功能
- 修改权重张量梯度初始化使用gomatrix.NewZeros
- 更新must函数为泛型实现
- 重构损失函数实现使用gomatrix操作
- 优化输出格式化避免重复数据访问

refactor(model_test): 更新测试用例使用gomatrix构造函数

- 修改TestSequential测试使用gomatrix.NewMatrix和gomatrix.NewVector
- 更新TestSaveLoadModel测试使用gomatrix构造函数
- 修改TestLinearLayer测试使用NewTensor构造权重矩阵

refactor(trainer_test): 将Must函数改为泛型实现

- 更新Must函数为泛型版本支持任意类型
```
This commit is contained in:
kingecg 2026-01-01 15:14:10 +08:00
parent 9aa53cbd5c
commit 16c2277474
4 changed files with 19 additions and 21 deletions

View File

@ -3,6 +3,7 @@ package main
import (
"fmt"
"git.kingecg.top/kingecg/gomatrix"
"git.kingecg.top/kingecg/gotensor"
)
@ -38,7 +39,7 @@ func (l *LinearLayer) Forward(inputs *gotensor.Tensor) (*gotensor.Tensor, error)
// 创建转置后的权重张量
weightTransposedTensor := &gotensor.Tensor{
Data: weightTransposed,
Grad: must(gotensor.NewZeros(l.Weight.Shape())),
Grad: must(gomatrix.NewZeros(l.Weight.Shape())),
}
// 矩阵乘法
@ -83,7 +84,7 @@ func (m *SimpleModel) ZeroGrad() {
}
// must 是一个辅助函数,用于处理可能的错误
func must(t *gotensor.Tensor, err error) *gotensor.Tensor {
func must[T any](t *T, err error) *T {
if err != nil {
panic(err)
}
@ -122,12 +123,12 @@ func main() {
// 定义损失函数 (MSE)
lossFn := func(output, target *gotensor.Tensor) *gotensor.Tensor {
// 计算均方误差
diff, _ := output.Sub(target)
squared, _ := diff.Mul(diff)
sum, _ := squared.Sum()
diff, _ := output.Data.Subtract(target.Data)
squared, _ := diff.Multiply(diff)
sum := squared.Sum()
size := float64(output.Size())
result, _ := sum.DivScalar(size)
return result
result := sum / size
return must(gotensor.NewTensor([]float64{result}, []int{1}))
}
fmt.Println("开始训练模型...")
@ -155,10 +156,11 @@ func main() {
inputVal1, _ := input.Data.Get(1)
outputVal0, _ := output.Data.Get(0)
outputVal1, _ := output.Data.Get(1)
targetVal0, _ := trainTargets[i].Data.Get(0)
targetVal1, _ := trainTargets[i].Data.Get(1)
fmt.Printf("输入: [%.0f, %.0f] -> 输出: [%.3f, %.3f], 目标: [%.0f, %.0f]\n",
inputVal0, inputVal1, outputVal0, outputVal1,
trainTargets[i].Data.Get(0), trainTargets[i].Data.Get(1))
targetVal0, targetVal1)
}
// 保存模型

4
go.sum
View File

@ -1,6 +1,2 @@
git.kingecg.top/kingecg/gomatrix v0.0.0-20251230141944-2ff4dcfb0fcd h1:vn3LW38hQPGig0iqofIaIMYXVp3Uqb5QX6eH5B5lVxU=
git.kingecg.top/kingecg/gomatrix v0.0.0-20251230141944-2ff4dcfb0fcd/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=
git.kingecg.top/kingecg/gomatrix v0.0.0-20251231092627-f40960a855c7 h1:tutkcVKwpzNYxZRXkunhnkrDGRfMYgvwGAbCBCtO62c=
git.kingecg.top/kingecg/gomatrix v0.0.0-20251231092627-f40960a855c7/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=
git.kingecg.top/kingecg/gomatrix v0.0.0-20251231094846-bfcfba4e3f99 h1:sV3rEZIhYwU1TLmqFybT6Lwf6lA4oiITX/HC7i+JsiA=
git.kingecg.top/kingecg/gomatrix v0.0.0-20251231094846-bfcfba4e3f99/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=

View File

@ -65,11 +65,11 @@ func TestSequential(t *testing.T) {
linearLayer := &Linear{
Weight: &Tensor{
Data: weight,
Grad: Must(NewMatrix([][]float64{{0, 0}, {0, 0}})),
Grad: Must(gomatrix.NewMatrix([]float64{0, 0, 0, 0}, []int{2, 2})),
},
Bias: &Tensor{
Data: bias,
Grad: Must(NewVector([]float64{0, 0})),
Grad: Must(gomatrix.NewVector([]float64{0, 0})),
},
}
@ -117,11 +117,11 @@ func TestSequential(t *testing.T) {
// TestSaveLoadModel 测试模型保存和加载功能
func TestSaveLoadModel(t *testing.T) {
weight, err := NewMatrix([][]float64{{1, 2}, {3, 4}})
weight, err := gomatrix.NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
if err != nil {
t.Fatalf("Failed to create weight matrix: %v", err)
}
bias, err := NewVector([]float64{0.5, 0.5})
bias, err := gomatrix.NewVector([]float64{0.5, 0.5})
if err != nil {
t.Fatalf("Failed to create bias vector: %v", err)
}
@ -129,11 +129,11 @@ func TestSaveLoadModel(t *testing.T) {
linearLayer := &Linear{
Weight: &Tensor{
Data: weight,
Grad: Must(NewMatrix([][]float64{{0, 0}, {0, 0}})),
Grad: Must(gomatrix.NewMatrix([]float64{0, 0, 0, 0}, []int{2, 2})),
},
Bias: &Tensor{
Data: bias,
Grad: Must(NewVector([]float64{0, 0})),
Grad: Must(gomatrix.NewVector([]float64{0, 0})),
},
}
@ -169,7 +169,7 @@ func TestSaveLoadModel(t *testing.T) {
// TestLinearLayer 测试线性层功能
func TestLinearLayer(t *testing.T) {
weight := Must(NewMatrix([][]float64{{2, 0}, {0, 3}}))
weight := Must(NewTensor([]float64{2, 0, 0, 3}, []int{2, 2}))
bias := Must(NewVector([]float64{0.5, 0.5}))
layer := NewLinear(weight, bias)

View File

@ -63,7 +63,7 @@ func NewVector(data []float64) (*Tensor, error) {
}
func Must(t *Tensor, err error) *Tensor {
func Must[T any](t *T, err error) *T {
if err != nil {
panic(err)
}