```
feat(examples): 更新simple_model_example.go以使用gomatrix包 - 导入gomatrix包替代部分gotensor功能 - 修改权重张量梯度初始化使用gomatrix.NewZeros - 更新must函数为泛型实现 - 重构损失函数实现使用gomatrix操作 - 优化输出格式化避免重复数据访问 refactor(model_test): 更新测试用例使用gomatrix构造函数 - 修改TestSequential测试使用gomatrix.NewMatrix和gomatrix.NewVector - 更新TestSaveLoadModel测试使用gomatrix构造函数 - 修改TestLinearLayer测试使用NewTensor构造权重矩阵 refactor(trainer_test): 将Must函数改为泛型实现 - 更新Must函数为泛型版本支持任意类型 ```
This commit is contained in:
parent
9aa53cbd5c
commit
16c2277474
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"git.kingecg.top/kingecg/gomatrix"
|
||||
"git.kingecg.top/kingecg/gotensor"
|
||||
)
|
||||
|
||||
|
|
@ -38,7 +39,7 @@ func (l *LinearLayer) Forward(inputs *gotensor.Tensor) (*gotensor.Tensor, error)
|
|||
// 创建转置后的权重张量
|
||||
weightTransposedTensor := &gotensor.Tensor{
|
||||
Data: weightTransposed,
|
||||
Grad: must(gotensor.NewZeros(l.Weight.Shape())),
|
||||
Grad: must(gomatrix.NewZeros(l.Weight.Shape())),
|
||||
}
|
||||
|
||||
// 矩阵乘法
|
||||
|
|
@ -83,7 +84,7 @@ func (m *SimpleModel) ZeroGrad() {
|
|||
}
|
||||
|
||||
// must 是一个辅助函数,用于处理可能的错误
|
||||
func must(t *gotensor.Tensor, err error) *gotensor.Tensor {
|
||||
func must[T any](t *T, err error) *T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
@ -122,12 +123,12 @@ func main() {
|
|||
// 定义损失函数 (MSE)
|
||||
lossFn := func(output, target *gotensor.Tensor) *gotensor.Tensor {
|
||||
// 计算均方误差
|
||||
diff, _ := output.Sub(target)
|
||||
squared, _ := diff.Mul(diff)
|
||||
sum, _ := squared.Sum()
|
||||
diff, _ := output.Data.Subtract(target.Data)
|
||||
squared, _ := diff.Multiply(diff)
|
||||
sum := squared.Sum()
|
||||
size := float64(output.Size())
|
||||
result, _ := sum.DivScalar(size)
|
||||
return result
|
||||
result := sum / size
|
||||
return must(gotensor.NewTensor([]float64{result}, []int{1}))
|
||||
}
|
||||
|
||||
fmt.Println("开始训练模型...")
|
||||
|
|
@ -155,10 +156,11 @@ func main() {
|
|||
inputVal1, _ := input.Data.Get(1)
|
||||
outputVal0, _ := output.Data.Get(0)
|
||||
outputVal1, _ := output.Data.Get(1)
|
||||
|
||||
targetVal0, _ := trainTargets[i].Data.Get(0)
|
||||
targetVal1, _ := trainTargets[i].Data.Get(1)
|
||||
fmt.Printf("输入: [%.0f, %.0f] -> 输出: [%.3f, %.3f], 目标: [%.0f, %.0f]\n",
|
||||
inputVal0, inputVal1, outputVal0, outputVal1,
|
||||
trainTargets[i].Data.Get(0), trainTargets[i].Data.Get(1))
|
||||
targetVal0, targetVal1)
|
||||
}
|
||||
|
||||
// 保存模型
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -1,6 +1,2 @@
|
|||
git.kingecg.top/kingecg/gomatrix v0.0.0-20251230141944-2ff4dcfb0fcd h1:vn3LW38hQPGig0iqofIaIMYXVp3Uqb5QX6eH5B5lVxU=
|
||||
git.kingecg.top/kingecg/gomatrix v0.0.0-20251230141944-2ff4dcfb0fcd/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=
|
||||
git.kingecg.top/kingecg/gomatrix v0.0.0-20251231092627-f40960a855c7 h1:tutkcVKwpzNYxZRXkunhnkrDGRfMYgvwGAbCBCtO62c=
|
||||
git.kingecg.top/kingecg/gomatrix v0.0.0-20251231092627-f40960a855c7/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=
|
||||
git.kingecg.top/kingecg/gomatrix v0.0.0-20251231094846-bfcfba4e3f99 h1:sV3rEZIhYwU1TLmqFybT6Lwf6lA4oiITX/HC7i+JsiA=
|
||||
git.kingecg.top/kingecg/gomatrix v0.0.0-20251231094846-bfcfba4e3f99/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=
|
||||
|
|
|
|||
|
|
@ -65,11 +65,11 @@ func TestSequential(t *testing.T) {
|
|||
linearLayer := &Linear{
|
||||
Weight: &Tensor{
|
||||
Data: weight,
|
||||
Grad: Must(NewMatrix([][]float64{{0, 0}, {0, 0}})),
|
||||
Grad: Must(gomatrix.NewMatrix([]float64{0, 0, 0, 0}, []int{2, 2})),
|
||||
},
|
||||
Bias: &Tensor{
|
||||
Data: bias,
|
||||
Grad: Must(NewVector([]float64{0, 0})),
|
||||
Grad: Must(gomatrix.NewVector([]float64{0, 0})),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -117,11 +117,11 @@ func TestSequential(t *testing.T) {
|
|||
|
||||
// TestSaveLoadModel 测试模型保存和加载功能
|
||||
func TestSaveLoadModel(t *testing.T) {
|
||||
weight, err := NewMatrix([][]float64{{1, 2}, {3, 4}})
|
||||
weight, err := gomatrix.NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create weight matrix: %v", err)
|
||||
}
|
||||
bias, err := NewVector([]float64{0.5, 0.5})
|
||||
bias, err := gomatrix.NewVector([]float64{0.5, 0.5})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create bias vector: %v", err)
|
||||
}
|
||||
|
|
@ -129,11 +129,11 @@ func TestSaveLoadModel(t *testing.T) {
|
|||
linearLayer := &Linear{
|
||||
Weight: &Tensor{
|
||||
Data: weight,
|
||||
Grad: Must(NewMatrix([][]float64{{0, 0}, {0, 0}})),
|
||||
Grad: Must(gomatrix.NewMatrix([]float64{0, 0, 0, 0}, []int{2, 2})),
|
||||
},
|
||||
Bias: &Tensor{
|
||||
Data: bias,
|
||||
Grad: Must(NewVector([]float64{0, 0})),
|
||||
Grad: Must(gomatrix.NewVector([]float64{0, 0})),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -169,7 +169,7 @@ func TestSaveLoadModel(t *testing.T) {
|
|||
|
||||
// TestLinearLayer 测试线性层功能
|
||||
func TestLinearLayer(t *testing.T) {
|
||||
weight := Must(NewMatrix([][]float64{{2, 0}, {0, 3}}))
|
||||
weight := Must(NewTensor([]float64{2, 0, 0, 3}, []int{2, 2}))
|
||||
bias := Must(NewVector([]float64{0.5, 0.5}))
|
||||
|
||||
layer := NewLinear(weight, bias)
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func NewVector(data []float64) (*Tensor, error) {
|
|||
|
||||
}
|
||||
|
||||
func Must(t *Tensor, err error) *Tensor {
|
||||
func Must[T any](t *T, err error) *T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue