```
feat(examples): update simple_model_example.go to use the gomatrix package

- Import gomatrix alongside the existing gotensor functionality
- Initialize weight tensor gradients with gomatrix.NewZeros
- Make the must helper generic
- Reimplement the loss function with gomatrix operations
- Streamline output formatting to avoid redundant data access

refactor(model_test): switch test cases to gomatrix constructors

- TestSequential now uses gomatrix.NewMatrix and gomatrix.NewVector
- TestSaveLoadModel now uses the gomatrix constructors
- TestLinearLayer builds its weight matrix with NewTensor

refactor(trainer_test): make the Must helper generic

- Must is now a generic function that supports any type
```
commit 16c2277474
parent 9aa53cbd5c
simple_model_example.go

```diff
@@ -3,6 +3,7 @@ package main
 import (
 	"fmt"
 
+	"git.kingecg.top/kingecg/gomatrix"
 	"git.kingecg.top/kingecg/gotensor"
 )
 
@@ -38,7 +39,7 @@ func (l *LinearLayer) Forward(inputs *gotensor.Tensor) (*gotensor.Tensor, error)
 	// 创建转置后的权重张量
 	weightTransposedTensor := &gotensor.Tensor{
 		Data: weightTransposed,
-		Grad: must(gotensor.NewZeros(l.Weight.Shape())),
+		Grad: must(gomatrix.NewZeros(l.Weight.Shape())),
 	}
 
 	// 矩阵乘法
@@ -83,7 +84,7 @@ func (m *SimpleModel) ZeroGrad() {
 }
 
 // must 是一个辅助函数,用于处理可能的错误
-func must(t *gotensor.Tensor, err error) *gotensor.Tensor {
+func must[T any](t *T, err error) *T {
 	if err != nil {
 		panic(err)
 	}
```
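The generic must helper in this last hunk replaces a gotensor-only version, so one function can now unwrap any (*T, error) pair. Below is a minimal, self-contained sketch of that pattern; it uses standard-library calls as stand-ins rather than the repository's own types.

```go
package main

import (
	"fmt"
	"net/url"
	"os"
)

// must panics on error and otherwise returns the pointer unchanged.
// The single type parameter lets one helper serve any (*T, error) constructor.
func must[T any](t *T, err error) *T {
	if err != nil {
		panic(err)
	}
	return t
}

func main() {
	// Works with any call that returns a pointer and an error.
	u := must(url.Parse("https://example.com/path")) // *url.URL
	f := must(os.Create("example.txt"))              // *os.File
	defer f.Close()

	fmt.Println(u.Host, f.Name())
}
```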
```diff
@@ -122,12 +123,12 @@ func main() {
 	// 定义损失函数 (MSE)
 	lossFn := func(output, target *gotensor.Tensor) *gotensor.Tensor {
 		// 计算均方误差
-		diff, _ := output.Sub(target)
-		squared, _ := diff.Mul(diff)
-		sum, _ := squared.Sum()
+		diff, _ := output.Data.Subtract(target.Data)
+		squared, _ := diff.Multiply(diff)
+		sum := squared.Sum()
 		size := float64(output.Size())
-		result, _ := sum.DivScalar(size)
-		return result
+		result := sum / size
+		return must(gotensor.NewTensor([]float64{result}, []int{1}))
 	}
 
 	fmt.Println("开始训练模型...")
```
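For reference, the arithmetic the refactored lossFn performs is ordinary mean squared error: square the element-wise differences, sum them, and divide by the element count. The following plain-slice sketch shows the same steps without depending on gotensor or gomatrix; the function and variable names here are illustrative only.

```go
package main

import "fmt"

// mse computes mean squared error over two equal-length slices,
// mirroring the diff/square/sum/size steps in the refactored lossFn.
func mse(output, target []float64) float64 {
	var sum float64
	for i := range output {
		d := output[i] - target[i] // diff
		sum += d * d               // squared, accumulated
	}
	return sum / float64(len(output)) // divide by element count
}

func main() {
	out := []float64{0.9, 0.2}
	tgt := []float64{1.0, 0.0}
	fmt.Printf("MSE = %.4f\n", mse(out, tgt)) // (0.01 + 0.04) / 2 = 0.025
}
```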
```diff
@@ -155,10 +156,11 @@ func main() {
 		inputVal1, _ := input.Data.Get(1)
 		outputVal0, _ := output.Data.Get(0)
 		outputVal1, _ := output.Data.Get(1)
+		targetVal0, _ := trainTargets[i].Data.Get(0)
+		targetVal1, _ := trainTargets[i].Data.Get(1)
 		fmt.Printf("输入: [%.0f, %.0f] -> 输出: [%.3f, %.3f], 目标: [%.0f, %.0f]\n",
 			inputVal0, inputVal1, outputVal0, outputVal1,
-			trainTargets[i].Data.Get(0), trainTargets[i].Data.Get(1))
+			targetVal0, targetVal1)
 	}
 
 	// 保存模型
```
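The commit message frames this hunk as avoiding repeated data access. It also sidesteps a Go restriction: Data.Get appears to return a value and an error, and a multi-valued call can only be forwarded when it is the sole argument, so it cannot be mixed into a longer Printf argument list. A tiny illustration of that rule, where get is a hypothetical stand-in for Data.Get:

```go
package main

import "fmt"

// get is a hypothetical accessor with a (value, error) return,
// standing in for Data.Get in the example above.
func get(i int) (float64, error) {
	return float64(i) * 1.5, nil
}

func main() {
	// fmt.Println(get(0), get(1)) // does not compile: a multi-value call
	//                             // cannot be mixed with other arguments
	v0, _ := get(0) // unpack first ...
	v1, _ := get(1)
	fmt.Println(v0, v1) // ... then format
}
```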
go.sum

```diff
@@ -1,6 +1,2 @@
-git.kingecg.top/kingecg/gomatrix v0.0.0-20251230141944-2ff4dcfb0fcd h1:vn3LW38hQPGig0iqofIaIMYXVp3Uqb5QX6eH5B5lVxU=
-git.kingecg.top/kingecg/gomatrix v0.0.0-20251230141944-2ff4dcfb0fcd/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=
-git.kingecg.top/kingecg/gomatrix v0.0.0-20251231092627-f40960a855c7 h1:tutkcVKwpzNYxZRXkunhnkrDGRfMYgvwGAbCBCtO62c=
-git.kingecg.top/kingecg/gomatrix v0.0.0-20251231092627-f40960a855c7/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=
 git.kingecg.top/kingecg/gomatrix v0.0.0-20251231094846-bfcfba4e3f99 h1:sV3rEZIhYwU1TLmqFybT6Lwf6lA4oiITX/HC7i+JsiA=
 git.kingecg.top/kingecg/gomatrix v0.0.0-20251231094846-bfcfba4e3f99/go.mod h1:CHH1HkVvXrpsb+uDrsoyjx0lTwQ3oSSMbIRJmwvO6z8=
```
model_test.go

```diff
@@ -65,11 +65,11 @@ func TestSequential(t *testing.T) {
 	linearLayer := &Linear{
 		Weight: &Tensor{
 			Data: weight,
-			Grad: Must(NewMatrix([][]float64{{0, 0}, {0, 0}})),
+			Grad: Must(gomatrix.NewMatrix([]float64{0, 0, 0, 0}, []int{2, 2})),
 		},
 		Bias: &Tensor{
 			Data: bias,
-			Grad: Must(NewVector([]float64{0, 0})),
+			Grad: Must(gomatrix.NewVector([]float64{0, 0})),
 		},
 	}
```
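The change above is purely in how the zero matrix is constructed: the old helper took nested rows, while gomatrix.NewMatrix, as used throughout this diff, takes flat row-major data plus an explicit shape. A small sketch of converting between the two forms, written in plain Go with no gomatrix dependency:

```go
package main

import "fmt"

// flatten converts nested rows ([][]float64) into the flat, row-major
// data-plus-shape form used by the gomatrix constructors in this diff.
func flatten(rows [][]float64) (data []float64, shape []int) {
	for _, row := range rows {
		data = append(data, row...)
	}
	return data, []int{len(rows), len(rows[0])}
}

func main() {
	rows := [][]float64{{0, 0}, {0, 0}}
	data, shape := flatten(rows)
	fmt.Println(data, shape) // [0 0 0 0] [2 2]
}
```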
```diff
@@ -117,11 +117,11 @@ func TestSequential(t *testing.T) {
 
 // TestSaveLoadModel 测试模型保存和加载功能
 func TestSaveLoadModel(t *testing.T) {
-	weight, err := NewMatrix([][]float64{{1, 2}, {3, 4}})
+	weight, err := gomatrix.NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
 	if err != nil {
 		t.Fatalf("Failed to create weight matrix: %v", err)
 	}
-	bias, err := NewVector([]float64{0.5, 0.5})
+	bias, err := gomatrix.NewVector([]float64{0.5, 0.5})
 	if err != nil {
 		t.Fatalf("Failed to create bias vector: %v", err)
 	}
```
```diff
@@ -129,11 +129,11 @@ func TestSaveLoadModel(t *testing.T) {
 	linearLayer := &Linear{
 		Weight: &Tensor{
 			Data: weight,
-			Grad: Must(NewMatrix([][]float64{{0, 0}, {0, 0}})),
+			Grad: Must(gomatrix.NewMatrix([]float64{0, 0, 0, 0}, []int{2, 2})),
 		},
 		Bias: &Tensor{
 			Data: bias,
-			Grad: Must(NewVector([]float64{0, 0})),
+			Grad: Must(gomatrix.NewVector([]float64{0, 0})),
 		},
 	}
```
```diff
@@ -169,7 +169,7 @@ func TestSaveLoadModel(t *testing.T) {
 
 // TestLinearLayer 测试线性层功能
 func TestLinearLayer(t *testing.T) {
-	weight := Must(NewMatrix([][]float64{{2, 0}, {0, 3}}))
+	weight := Must(NewTensor([]float64{2, 0, 0, 3}, []int{2, 2}))
 	bias := Must(NewVector([]float64{0.5, 0.5}))
 
 	layer := NewLinear(weight, bias)
```
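With the fixtures TestLinearLayer sets up, a diagonal weight of {2, 3} and a bias of {0.5, 0.5}, the layer should map an input (x0, x1) to (2*x0 + 0.5, 3*x1 + 0.5). The standalone sketch below reproduces that expectation; it is the arithmetic the test exercises, not the repository's Linear implementation.

```go
package main

import "fmt"

// forward applies y = W·x + b for a 2x2 weight stored row-major,
// matching the flat layout of NewTensor([]float64{2, 0, 0, 3}, []int{2, 2}).
func forward(w []float64, b, x [2]float64) [2]float64 {
	return [2]float64{
		w[0]*x[0] + w[1]*x[1] + b[0],
		w[2]*x[0] + w[3]*x[1] + b[1],
	}
}

func main() {
	w := []float64{2, 0, 0, 3}
	b := [2]float64{0.5, 0.5}
	x := [2]float64{1, 1}
	fmt.Println(forward(w, b, x)) // [2.5 3.5]
}
```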
trainer_test.go

```diff
@@ -63,7 +63,7 @@ func NewVector(data []float64) (*Tensor, error) {
 
 }
 
-func Must(t *Tensor, err error) *Tensor {
+func Must[T any](t *T, err error) *T {
	if err != nil {
		panic(err)
	}
```