gotensor/trainer.go

92 lines
1.9 KiB
Go

package gotensor
import (
"fmt"
)
// Trainer pairs a model with the optimizer that updates it and owns the
// training and evaluation loops defined on this type.
type Trainer struct {
// Model produces the output for each sample through its Forward method.
Model Model
// Optimizer is stepped and then has its gradients cleared after every
// training sample.
Optimizer Optimizer
}
// NewTrainer returns a pointer to a freshly allocated Trainer holding the
// given model and optimizer. Neither argument is validated or copied.
func NewTrainer(model Model, optimizer Optimizer) *Trainer {
	trainer := new(Trainer)
	trainer.Model = model
	trainer.Optimizer = optimizer
	return trainer
}
// TrainEpoch runs one pass over the given samples and returns the mean loss.
//
// For every sample, in order, it runs the model forward, evaluates lossFn
// against the target at the same index, back-propagates from the loss, steps
// the optimizer and then clears the gradients.
//
// An error is returned before any sample is processed when lossFn is nil,
// when inputs is empty (the mean would otherwise be NaN), or when targets is
// shorter than inputs (indexing targets would otherwise panic). Targets
// beyond len(inputs) are ignored. A Forward failure, or a nil tensor returned
// by lossFn, stops the epoch and is reported with the sample index; parameter
// updates already applied for earlier samples are not rolled back. The
// returned loss is 0 whenever the error is non-nil.
func (t *Trainer) TrainEpoch(inputs []*Tensor, targets []*Tensor, lossFn func(*Tensor, *Tensor) *Tensor) (float64, error) {
	if lossFn == nil {
		return 0, fmt.Errorf("train epoch: loss function is nil")
	}
	if len(inputs) == 0 {
		return 0, fmt.Errorf("train epoch: no input samples")
	}
	if len(targets) < len(inputs) {
		return 0, fmt.Errorf("train epoch: %d targets for %d inputs", len(targets), len(inputs))
	}

	var totalLoss float64
	for i, input := range inputs {
		// Forward pass.
		output, err := t.Model.Forward(input)
		if err != nil {
			return 0, fmt.Errorf("train epoch: forward pass for sample %d: %w", i, err)
		}

		// Loss for this sample; a nil tensor cannot be read or back-propagated.
		loss := lossFn(output, targets[i])
		if loss == nil {
			return 0, fmt.Errorf("train epoch: loss function returned nil for sample %d", i)
		}
		// NOTE(review): the second result of Get is still discarded, exactly
		// as before; its type is not visible from this file. Confirm whether
		// it is an error that should be propagated.
		lossVal, _ := loss.Data.Get(0)
		totalLoss += lossVal

		// Backward pass, parameter update, then reset gradients so they do
		// not leak into the next sample.
		loss.Backward()
		t.Optimizer.Step()
		t.Optimizer.ZeroGrad()
	}

	return totalLoss / float64(len(inputs)), nil
}
// Train calls TrainEpoch on the training data once per epoch, epochs times.
//
// When verbose is true, the average loss reported by each epoch is printed to
// standard output together with the 1-based epoch number. The first error
// returned by TrainEpoch aborts training and is returned unchanged; nil is
// returned once every epoch has completed. A non-positive epochs value runs
// no epochs and returns nil.
func (t *Trainer) Train(
	trainInputs []*Tensor,
	trainTargets []*Tensor,
	epochs int,
	lossFn func(*Tensor, *Tensor) *Tensor,
	verbose bool,
) error {
	for remaining := epochs; remaining > 0; remaining-- {
		meanLoss, err := t.TrainEpoch(trainInputs, trainTargets, lossFn)
		switch {
		case err != nil:
			return err
		case verbose:
			current := epochs - remaining + 1
			fmt.Printf("Epoch [%d/%d], Loss: %.6f\n", current, epochs, meanLoss)
		}
	}
	return nil
}
// Evaluate returns the mean loss of the model over the given samples without
// back-propagating or touching the optimizer.
//
// For every sample, in order, it runs the model forward and evaluates lossFn
// against the target at the same index.
//
// An error is returned before any sample is processed when lossFn is nil,
// when testInputs is empty (the mean would otherwise be NaN), or when
// testTargets is shorter than testInputs (indexing it would otherwise panic).
// Targets beyond len(testInputs) are ignored. A Forward failure, or a nil
// tensor returned by lossFn, stops the evaluation and is reported with the
// sample index. The returned loss is 0 whenever the error is non-nil.
func (t *Trainer) Evaluate(testInputs []*Tensor, testTargets []*Tensor, lossFn func(*Tensor, *Tensor) *Tensor) (float64, error) {
	if lossFn == nil {
		return 0, fmt.Errorf("evaluate: loss function is nil")
	}
	if len(testInputs) == 0 {
		return 0, fmt.Errorf("evaluate: no input samples")
	}
	if len(testTargets) < len(testInputs) {
		return 0, fmt.Errorf("evaluate: %d targets for %d inputs", len(testTargets), len(testInputs))
	}

	var totalLoss float64
	for i, input := range testInputs {
		// Forward pass.
		output, err := t.Model.Forward(input)
		if err != nil {
			return 0, fmt.Errorf("evaluate: forward pass for sample %d: %w", i, err)
		}

		// Loss for this sample; a nil tensor cannot be read.
		loss := lossFn(output, testTargets[i])
		if loss == nil {
			return 0, fmt.Errorf("evaluate: loss function returned nil for sample %d", i)
		}
		// NOTE(review): the second result of Get is still discarded, exactly
		// as before; its type is not visible from this file. Confirm whether
		// it is an error that should be propagated.
		lossVal, _ := loss.Data.Get(0)
		totalLoss += lossVal
	}

	return totalLoss / float64(len(testInputs)), nil
}