gotensor/optimizer_test.go

package gotensor

import (
	"fmt"
	"math"
	"testing"
)

// TestSGD tests the SGD optimizer.
func TestSGD(t *testing.T) {
	// Create some parameters for testing.
	weightData, _ := NewMatrix([][]float64{{1.0, 2.0}, {3.0, 4.0}})
	weightGrad, _ := NewMatrix([][]float64{{0.1, 0.2}, {0.3, 0.4}})
	params := []*Tensor{
		{
			Data: Must(NewVector([]float64{1.0, 2.0, 3.0})),
			Grad: Must(NewVector([]float64{0.1, 0.2, 0.3})),
		},
		{
			Data: weightData,
			Grad: weightGrad,
		},
	}

	// Create the SGD optimizer.
	lr := 0.1
	sgd := NewSGD(params, lr)

	// Save the original parameter values.
	origVec0, _ := params[0].Data.Get(0)
	origMat00, _ := params[1].Data.Get(0, 0)

	// Run one optimization step.
	sgd.Step()

	// Check that the parameters were updated.
	newVec0, _ := params[0].Data.Get(0)
	newMat00, _ := params[1].Data.Get(0, 0)
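
	// The expected values encode the plain SGD rule applied element-wise:
	// param = param - lr*grad. The first gradient entry is 0.1 in both tensors.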
	expectedVec0 := origVec0 - lr*0.1
	expectedMat00 := origMat00 - lr*0.1
	if math.Abs(newVec0-expectedVec0) > 1e-9 {
		t.Errorf("Expected updated param[0][0] to be %v, got %v", expectedVec0, newVec0)
	}
	if math.Abs(newMat00-expectedMat00) > 1e-9 {
		t.Errorf("Expected updated param[1][0,0] to be %v, got %v", expectedMat00, newMat00)
	}

	// Test ZeroGrad.
	sgd.ZeroGrad()
	for _, param := range params {
		shape := param.Shape()
		for i := 0; i < param.Size(); i++ {
			var gradVal float64
			if len(shape) == 1 {
				gradVal, _ = param.Grad.Get(i)
			} else if len(shape) == 2 {
				// Map the flat index i to row-major (row, col) coordinates.
				cols := shape[1]
				gradVal, _ = param.Grad.Get(i/cols, i%cols)
			}
			if math.Abs(gradVal) > 1e-9 {
				t.Errorf("Expected gradient to be zero after ZeroGrad, got %v", gradVal)
			}
		}
	}
}

// TestAdam tests the Adam optimizer.
func TestAdam(t *testing.T) {
	// Create a parameter for testing.
	params := []*Tensor{
		{
			Data: Must(NewVector([]float64{1.0, 2.0})),
			Grad: Must(NewVector([]float64{0.1, 0.2})),
		},
	}

	// Create the Adam optimizer.
	lr := 0.001
	beta1 := 0.9
	beta2 := 0.999
	epsilon := 1e-8
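	// beta1 and beta2 are the exponential decay rates for the first and second
	// moment estimates; epsilon guards the update against division by zero.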
	adam := NewAdam(params, lr, beta1, beta2, epsilon)

	// Save the original parameter value.
	origVec0, _ := params[0].Data.Get(0)

	// Run a few optimization steps.
	for i := 0; i < 3; i++ {
		adam.Step()
	}
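
	// Assuming the standard Adam rule (m = beta1*m + (1-beta1)*g,
	// v = beta2*v + (1-beta2)*g^2, param -= lr*mHat/(sqrt(vHat)+epsilon)),
	// the exact post-update value is tedious to compute by hand, so the test
	// only checks that the parameter moved away from its starting value.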
	// Check that the parameter was updated.
	newVec0, _ := params[0].Data.Get(0)
	if math.Abs(newVec0-origVec0) < 1e-9 {
		t.Errorf("Expected parameter to be updated, but it wasn't. Original: %v, New: %v", origVec0, newVec0)
	}

	// Verify that the internal state was created.
	if len(adam.M) != len(params) || len(adam.V) != len(params) {
		t.Error("Adam internal states M and V not properly initialized")
	}

	// Test ZeroGrad.
	adam.ZeroGrad()
	for _, param := range params {
		shape := param.Shape()
		for i := 0; i < param.Size(); i++ {
			var gradVal float64
			if len(shape) == 1 {
				gradVal, _ = param.Grad.Get(i)
			} else if len(shape) == 2 {
				cols := shape[1]
				gradVal, _ = param.Grad.Get(i/cols, i%cols)
			}
			if math.Abs(gradVal) > 1e-9 {
				t.Errorf("Expected gradient to be zero after ZeroGrad, got %v", gradVal)
			}
		}
	}
}

// TestAdamWithMatrix tests the Adam optimizer on a matrix parameter.
func TestAdamWithMatrix(t *testing.T) {
	matrixData, _ := NewMatrix([][]float64{{1.0, 2.0}, {3.0, 4.0}})
	matrixGrad, _ := NewMatrix([][]float64{{0.1, 0.2}, {0.3, 0.4}})

	// Create a matrix parameter for testing.
	params := []*Tensor{
		{
			Data: matrixData,
			Grad: matrixGrad,
		},
	}

	// Create the Adam optimizer.
	lr := 0.001
	adam := NewAdam(params, lr, 0.9, 0.999, 1e-8)

	// Save the original parameter value.
	origMat00, _ := params[0].Data.Get(0, 0)

	// Run a few optimization steps.
	for i := 0; i < 5; i++ {
		adam.Step()
	}

	// Check that the parameter was updated.
	newMat00, _ := params[0].Data.Get(0, 0)
	if math.Abs(newMat00-origMat00) < 1e-9 {
		t.Errorf("Expected parameter to be updated, but it wasn't. Original: %v, New: %v", origMat00, newMat00)
	}
}
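
// ExampleNewSGD is a minimal usage sketch, assuming SGD applies the plain
// update param = param - lr*grad (the same behaviour TestSGD asserts above).
// In real training the gradient would come from a backward pass; here it is
// set by hand.
func ExampleNewSGD() {
	w := &Tensor{
		Data: Must(NewVector([]float64{1.0})),
		Grad: Must(NewVector([]float64{0.5})),
	}
	opt := NewSGD([]*Tensor{w}, 0.1)
	opt.Step()     // w = 1.0 - 0.1*0.5 = 0.95
	opt.ZeroGrad() // clear the gradient before the next backward pass
	v, _ := w.Data.Get(0)
	fmt.Printf("%.2f\n", v)
	// Output: 0.95
}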