package gotensor

import (
    "math"
    "testing"

    "git.kingecg.top/kingecg/gomatrix"
)

// NewMatrix builds a gomatrix.Matrix from a 2-D slice of rows by flattening
// it into a row-major buffer.
func NewMatrix(data [][]float64) (*gomatrix.Matrix, error) {
    c := make([]float64, len(data)*len(data[0]))
    for i := 0; i < len(c); i++ {
        c[i] = data[i/len(data[0])][i%len(data[0])]
    }
    return gomatrix.NewMatrix(c, []int{len(data), len(data[0])})
}

// TestSGD exercises the SGD optimizer.
func TestSGD(t *testing.T) {
    // Create some parameters for the test.
    weightData, _ := NewMatrix([][]float64{{1.0, 2.0}, {3.0, 4.0}})
    weightGrad, _ := NewMatrix([][]float64{{0.1, 0.2}, {0.3, 0.4}})

    params := []*Tensor{
        {
            Data: Must(gomatrix.NewMatrix([]float64{1.0, 2.0, 3.0}, []int{3, 1})),
            Grad: Must(gomatrix.NewMatrix([]float64{0.1, 0.2, 0.3}, []int{3, 1})),
        },
        {
            Data: weightData,
            Grad: weightGrad,
        },
    }

    // Create the SGD optimizer.
    lr := 0.1
    sgd := NewSGD(params, lr)

    // Save the original parameter values.
    origVec0, _ := params[0].Data.Get(0, 0)
    origMat00, _ := params[1].Data.Get(0, 0)

    // Run one optimization step.
    sgd.Step()

    // Check that the parameters were updated. Both inspected entries have
    // gradient 0.1, so each should have moved by -lr*0.1.
    newVec0, _ := params[0].Data.Get(0, 0)
    newMat00, _ := params[1].Data.Get(0, 0)

    expectedVec0 := origVec0 - lr*0.1
    expectedMat00 := origMat00 - lr*0.1

    if math.Abs(newVec0-expectedVec0) > 1e-9 {
        t.Errorf("Expected updated param[0][0] to be %v, got %v", expectedVec0, newVec0)
    }

    if math.Abs(newMat00-expectedMat00) > 1e-9 {
        t.Errorf("Expected updated param[1][0,0] to be %v, got %v", expectedMat00, newMat00)
    }

    // Test ZeroGrad: every gradient entry should be reset to zero.
    sgd.ZeroGrad()
    for _, param := range params {
        shape := param.Shape()
        for i := 0; i < param.Size(); i++ {
            var gradVal float64
            if len(shape) == 1 {
                gradVal, _ = param.Grad.Get(i)
            } else if len(shape) == 2 {
                cols := shape[1]
                gradVal, _ = param.Grad.Get(i/cols, i%cols)
            }
            if math.Abs(gradVal) > 1e-9 {
                t.Errorf("Expected gradient to be zero after ZeroGrad, got %v", gradVal)
            }
        }
    }
}
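
// sgdStepSketch is a minimal reference for the update rule TestSGD asserts:
// plain SGD moves each parameter against its gradient, θ ← θ − lr·∇θ. This
// flat-slice helper is a hypothetical sketch for documentation, not the
// optimizer's actual code path (that lives behind NewSGD/Step).
func sgdStepSketch(param, grad []float64, lr float64) {
    for i := range param {
        param[i] -= lr * grad[i]
    }
}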

// TestAdam exercises the Adam optimizer.
func TestAdam(t *testing.T) {
    // Create some parameters for the test.
    params := []*Tensor{
        {
            Data: Must(gomatrix.NewMatrix([]float64{1.0, 2.0}, []int{2, 1})),
            Grad: Must(gomatrix.NewMatrix([]float64{0.1, 0.2}, []int{2, 1})),
        },
    }

    // Create the Adam optimizer.
    lr := 0.001
    beta1 := 0.9
    beta2 := 0.999
    epsilon := 1e-8
    adam := NewAdam(params, lr, beta1, beta2, epsilon)

    // Save the original parameter value.
    origVec0, _ := params[0].Data.Get(0, 0)

    // Run a few optimization steps.
    for i := 0; i < 3; i++ {
        adam.Step()
    }

    // Check that the parameter was updated.
    newVec0, _ := params[0].Data.Get(0, 0)

    if math.Abs(newVec0-origVec0) < 1e-9 {
        t.Errorf("Expected parameter to be updated, but it wasn't. Original: %v, New: %v", origVec0, newVec0)
    }

    // Verify the internal state was created.
    if len(adam.M) != len(params) || len(adam.V) != len(params) {
        t.Error("Adam internal states M and V not properly initialized")
    }

    // Test ZeroGrad: every gradient entry should be reset to zero.
    adam.ZeroGrad()
    for _, param := range params {
        shape := param.Shape()
        for i := 0; i < param.Size(); i++ {
            var gradVal float64
            if len(shape) == 1 {
                gradVal, _ = param.Grad.Get(i)
            } else if len(shape) == 2 {
                cols := shape[1]
                gradVal, _ = param.Grad.Get(i/cols, i%cols)
            }
            if math.Abs(gradVal) > 1e-9 {
                t.Errorf("Expected gradient to be zero after ZeroGrad, got %v", gradVal)
            }
        }
    }
}
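
// adamStepSketch is a minimal reference for the update TestAdam exercises,
// assuming NewAdam implements standard Adam (Kingma & Ba, 2014):
//
//	m_t = β1·m_{t−1} + (1−β1)·g_t
//	v_t = β2·v_{t−1} + (1−β2)·g_t²
//	θ  ← θ − lr·m̂_t/(√v̂_t + ε),  with m̂_t = m_t/(1−β1^t), v̂_t = v_t/(1−β2^t)
//
// This flat-slice helper is a hypothetical sketch for documentation only; the
// real optimizer keeps its first/second-moment state in adam.M and adam.V.
// t is the 1-based step count used for bias correction.
func adamStepSketch(param, grad, m, v []float64, lr, beta1, beta2, eps float64, t int) {
    for i := range param {
        m[i] = beta1*m[i] + (1-beta1)*grad[i]
        v[i] = beta2*v[i] + (1-beta2)*grad[i]*grad[i]
        mHat := m[i] / (1 - math.Pow(beta1, float64(t)))
        vHat := v[i] / (1 - math.Pow(beta2, float64(t)))
        param[i] -= lr * mHat / (math.Sqrt(vHat) + eps)
    }
}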

// TestAdamWithMatrix exercises the Adam optimizer on matrix parameters.
func TestAdamWithMatrix(t *testing.T) {
    matrixData, _ := NewMatrix([][]float64{{1.0, 2.0}, {3.0, 4.0}})
    matrixGrad, _ := NewMatrix([][]float64{{0.1, 0.2}, {0.3, 0.4}})

    // Create a matrix parameter for the test.
    params := []*Tensor{
        {
            Data: matrixData,
            Grad: matrixGrad,
        },
    }

    // Create the Adam optimizer.
    lr := 0.001
    adam := NewAdam(params, lr, 0.9, 0.999, 1e-8)

    // Verify the internal state was created correctly.
    if len(adam.M) != len(params) || len(adam.V) != len(params) {
        t.Fatalf("Adam internal states M and V not properly initialized. Expected %d states, got M:%d, V:%d",
            len(params), len(adam.M), len(adam.V))
    }

    // Verify the internal state matrices have the same shape as the parameter.
    mShape := adam.M[0]["tensor"].Shape()
    vShape := adam.V[0]["tensor"].Shape()
    paramShape := params[0].Shape()
    if mShape[0] != paramShape[0] || mShape[1] != paramShape[1] ||
        vShape[0] != paramShape[0] || vShape[1] != paramShape[1] {
        t.Errorf("Adam internal state shapes don't match parameter shape. "+
            "Param: %v, M: %v, V: %v", paramShape, mShape, vShape)
    }

    // Save a copy of the original parameter values.
    originalData := make([][]float64, paramShape[0])
    for i := 0; i < paramShape[0]; i++ {
        originalData[i] = make([]float64, paramShape[1])
        for j := 0; j < paramShape[1]; j++ {
            originalData[i][j], _ = params[0].Data.Get(i, j)
        }
    }

    // Run a few optimization steps.
    for i := 0; i < 5; i++ {
        adam.Step()
    }

    // Check that at least one parameter entry was updated. Break out of both
    // loops on the first change so the direction check below still runs.
    updated := false
outer:
    for i := 0; i < paramShape[0]; i++ {
        for j := 0; j < paramShape[1]; j++ {
            newVal, _ := params[0].Data.Get(i, j)
            if math.Abs(newVal-originalData[i][j]) > 1e-9 {
                updated = true
                break outer
            }
        }
    }

    if !updated {
        t.Errorf("Expected parameters to be updated, but none were changed")
    }

    // Additionally verify the update is sensible: the parameter should move
    // in the direction opposite to its gradient.
    firstOrig := originalData[0][0]
    firstNew, _ := params[0].Data.Get(0, 0)
    firstGrad, _ := params[0].Grad.Get(0, 0)

    if (firstNew-firstOrig)*firstGrad > 0 {
        t.Errorf("Parameter updated in wrong direction. delta=%v, grad=%v",
            firstNew-firstOrig, firstGrad)
    }
}
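
// Why the direction check above is robust after only a few steps: with a
// constant gradient g, standard Adam's bias-corrected ratio m̂/(√v̂+ε) is
// approximately g/|g| = sign(g), so early updates are ≈ −lr·sign(g) whatever
// the gradient's magnitude (assuming NewAdam follows the standard algorithm).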