// gotensor/optimizer.go

package gotensor
import "math"

// Optimizer is the interface implemented by all optimizers.
type Optimizer interface {
	Step()     // update parameters from their gradients
	ZeroGrad() // reset all gradients to zero
}
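
// A typical training loop drives an Optimizer like this (a sketch only:
// "model", "Forward", "Backward", and "MSE" are hypothetical stand-ins,
// not part of this package):
//
//	opt := NewSGD(model.Parameters(), 0.01)
//	for epoch := 0; epoch < numEpochs; epoch++ {
//		loss := MSE(model.Forward(x), y)
//		opt.ZeroGrad()  // clear stale gradients
//		loss.Backward() // populate each param.Grad
//		opt.Step()      // apply the update
//	}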

// SGD implements stochastic gradient descent.
type SGD struct {
	Parameters []*Tensor
	LR         float64 // learning rate
}

// NewSGD creates a new SGD optimizer.
func NewSGD(parameters []*Tensor, lr float64) *SGD {
	return &SGD{
		Parameters: parameters,
		LR:         lr,
	}
}

// Step applies one SGD update to every parameter.
func (s *SGD) Step() {
	for _, param := range s.Parameters {
		grad := param.Grad
		if grad == nil {
			continue // skip parameters that have no gradient yet
		}
		shape := param.Data.Shape()
		// Update rule: param = param - lr * grad
		if len(shape) == 1 {
			for i := 0; i < shape[0]; i++ {
				paramVal, _ := param.Data.Get(i)
				gradVal, _ := grad.Get(i)
				newVal := paramVal - s.LR*gradVal
				param.Data.Set(newVal, i)
			}
		} else if len(shape) == 2 {
			rows, cols := shape[0], shape[1]
			for i := 0; i < rows; i++ {
				for j := 0; j < cols; j++ {
					paramVal, _ := param.Data.Get(i, j)
					gradVal, _ := grad.Get(i, j)
					newVal := paramVal - s.LR*gradVal
					param.Data.Set(newVal, i, j)
				}
			}
		}
	}
}
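
// Worked example of one SGD step: with LR = 0.1, a parameter value of 1.0,
// and a gradient of 0.5, the new value is 1.0 - 0.1*0.5 = 0.95.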

// ZeroGrad clears the gradients of all parameters.
func (s *SGD) ZeroGrad() {
	for _, param := range s.Parameters {
		param.ZeroGrad()
	}
}
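
// Note: ZeroGrad exists because gradients are typically accumulated across
// backward passes; if they are not cleared between steps, each update would
// be computed from stale, summed gradients. (This assumes the usual autograd
// accumulation behavior; the Tensor implementation defines the specifics.)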

// Adam implements the Adam optimizer (Kingma & Ba, 2015).
type Adam struct {
	Parameters []*Tensor
	LR         float64 // learning rate
	Beta1      float64 // exponential decay rate for the first-moment estimates
	Beta2      float64 // exponential decay rate for the second-moment estimates
	Epsilon    float64 // small constant that guards against division by zero
	T          int     // current time step
	// First-moment estimates, one per parameter.
	M []map[string]*Tensor
	// Second-moment estimates, one per parameter.
	V []map[string]*Tensor
}
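
// Note on the layout of M and V: each entry stores a single tensor under the
// fixed key "tensor", so a plain []*Tensor would suffice; the map form is
// kept here to preserve the existing exported fields.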

// NewAdam creates a new Adam optimizer.
func NewAdam(parameters []*Tensor, lr, beta1, beta2, epsilon float64) *Adam {
	adam := &Adam{
		Parameters: parameters,
		LR:         lr,
		Beta1:      beta1,
		Beta2:      beta2,
		Epsilon:    epsilon,
		T:          0,
		M:          make([]map[string]*Tensor, len(parameters)),
		V:          make([]map[string]*Tensor, len(parameters)),
	}
	// Initialize M and V with zero tensors of the same shape as each parameter.
	for i := range parameters {
		adam.M[i] = make(map[string]*Tensor)
		adam.V[i] = make(map[string]*Tensor)
		shape := parameters[i].Shape()
		m, _ := NewZeros(shape)
		v, _ := NewZeros(shape)
		adam.M[i]["tensor"] = m
		adam.V[i]["tensor"] = v
	}
	return adam
}
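
// NewAdamDefault is a convenience constructor (an addition, not in the
// original API) that uses the hyperparameters recommended in the Adam paper
// (Kingma & Ba, 2015): lr = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
func NewAdamDefault(parameters []*Tensor) *Adam {
	return NewAdam(parameters, 0.001, 0.9, 0.999, 1e-8)
}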

// Step applies one Adam update to every parameter.
func (a *Adam) Step() {
	a.T++
	// The bias-correction denominators depend only on the step count, so
	// compute them once per step instead of once per element.
	bc1 := 1 - math.Pow(a.Beta1, float64(a.T))
	bc2 := 1 - math.Pow(a.Beta2, float64(a.T))
	for i, param := range a.Parameters {
		grad := param.Grad
		if grad == nil {
			continue // skip parameters that have no gradient yet
		}
		shape := param.Data.Shape()

		// Update biased first-moment estimate: m = beta1*m + (1-beta1)*grad
		m := a.M[i]["tensor"]
		newMData := make([]float64, param.Size())
		if len(shape) == 1 {
			for idx := 0; idx < shape[0]; idx++ {
				mVal, _ := m.Data.Get(idx)
				gradVal, _ := grad.Get(idx)
				newMData[idx] = a.Beta1*mVal + (1-a.Beta1)*gradVal
			}
		} else if len(shape) == 2 {
			rows, cols := shape[0], shape[1]
			for r := 0; r < rows; r++ {
				for c := 0; c < cols; c++ {
					mVal, _ := m.Data.Get(r, c)
					gradVal, _ := grad.Get(r, c)
					newMData[r*cols+c] = a.Beta1*mVal + (1-a.Beta1)*gradVal
				}
			}
		}
		newM, _ := NewTensor(newMData, shape)
		a.M[i]["tensor"] = newM

		// Update biased second-moment estimate: v = beta2*v + (1-beta2)*grad^2
		v := a.V[i]["tensor"]
		newVData := make([]float64, param.Size())
		if len(shape) == 1 {
			for idx := 0; idx < shape[0]; idx++ {
				vVal, _ := v.Data.Get(idx)
				gradVal, _ := grad.Get(idx)
				newVData[idx] = a.Beta2*vVal + (1-a.Beta2)*gradVal*gradVal
			}
		} else if len(shape) == 2 {
			rows, cols := shape[0], shape[1]
			for r := 0; r < rows; r++ {
				for c := 0; c < cols; c++ {
					vVal, _ := v.Data.Get(r, c)
					gradVal, _ := grad.Get(r, c)
					newVData[r*cols+c] = a.Beta2*vVal + (1-a.Beta2)*gradVal*gradVal
				}
			}
		}
		newV, _ := NewTensor(newVData, shape)
		a.V[i]["tensor"] = newV

		// Bias-corrected first-moment estimate: mHat = m / (1 - beta1^T)
		mHatData := make([]float64, param.Size())
		if len(shape) == 1 {
			for idx := 0; idx < shape[0]; idx++ {
				mVal, _ := newM.Data.Get(idx)
				mHatData[idx] = mVal / bc1
			}
		} else if len(shape) == 2 {
			rows, cols := shape[0], shape[1]
			for r := 0; r < rows; r++ {
				for c := 0; c < cols; c++ {
					mVal, _ := newM.Data.Get(r, c)
					mHatData[r*cols+c] = mVal / bc1
				}
			}
		}
		mHat, _ := NewTensor(mHatData, shape)

		// Bias-corrected second-moment estimate: vHat = v / (1 - beta2^T)
		vHatData := make([]float64, param.Size())
		if len(shape) == 1 {
			for idx := 0; idx < shape[0]; idx++ {
				vVal, _ := newV.Data.Get(idx)
				vHatData[idx] = vVal / bc2
			}
		} else if len(shape) == 2 {
			rows, cols := shape[0], shape[1]
			for r := 0; r < rows; r++ {
				for c := 0; c < cols; c++ {
					vVal, _ := newV.Data.Get(r, c)
					vHatData[r*cols+c] = vVal / bc2
				}
			}
		}
		vHat, _ := NewTensor(vHatData, shape)

		// Parameter update: param = param - lr * mHat / (sqrt(vHat) + epsilon)
		if len(shape) == 1 {
			for idx := 0; idx < shape[0]; idx++ {
				paramVal, _ := param.Data.Get(idx)
				mHatVal, _ := mHat.Data.Get(idx)
				vHatVal, _ := vHat.Data.Get(idx)
				newVal := paramVal - a.LR*mHatVal/(math.Sqrt(vHatVal)+a.Epsilon)
				param.Data.Set(newVal, idx)
			}
		} else if len(shape) == 2 {
			rows, cols := shape[0], shape[1]
			for r := 0; r < rows; r++ {
				for c := 0; c < cols; c++ {
					paramVal, _ := param.Data.Get(r, c)
					mHatVal, _ := mHat.Data.Get(r, c)
					vHatVal, _ := vHat.Data.Get(r, c)
					newVal := paramVal - a.LR*mHatVal/(math.Sqrt(vHatVal)+a.Epsilon)
					param.Data.Set(newVal, r, c)
				}
			}
		}
	}
}
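
// Worked example of the bias correction: at the first step (T = 1) with
// Beta1 = 0.9 and m initialized to zero, the updated first moment is
// m = 0.9*0 + 0.1*grad = 0.1*grad, and mHat = m / (1 - 0.9^1) = grad, so the
// correction exactly undoes the bias toward zero early in training.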

// ZeroGrad clears the gradients of all parameters.
func (a *Adam) ZeroGrad() {
	for _, param := range a.Parameters {
		param.ZeroGrad()
	}
}