92 lines
1.9 KiB
Go
92 lines
1.9 KiB
Go
package gotensor
|
|
|
|
import (
|
|
"fmt"
|
|
)
|
|
|
|
// Trainer 训练器结构,管理整个训练过程
|
|
type Trainer struct {
|
|
Model Model
|
|
Optimizer Optimizer
|
|
}
|
|
|
|
// NewTrainer 创建新的训练器
|
|
func NewTrainer(model Model, optimizer Optimizer) *Trainer {
|
|
return &Trainer{
|
|
Model: model,
|
|
Optimizer: optimizer,
|
|
}
|
|
}
|
|
|
|
// TrainEpoch 训练一个epoch
|
|
func (t *Trainer) TrainEpoch(inputs []*Tensor, targets []*Tensor, lossFn func(*Tensor, *Tensor) *Tensor) (float64, error) {
|
|
var totalLoss float64
|
|
|
|
for i := 0; i < len(inputs); i++ {
|
|
// 前向传播
|
|
output, err := t.Model.Forward(inputs[i])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// 计算损失
|
|
loss := lossFn(output, targets[i])
|
|
lossVal, _ := loss.Data.Get(0)
|
|
totalLoss += lossVal
|
|
|
|
// 反向传播
|
|
loss.Backward()
|
|
|
|
// 更新参数
|
|
t.Optimizer.Step()
|
|
|
|
// 清空梯度
|
|
t.Optimizer.ZeroGrad()
|
|
}
|
|
|
|
avgLoss := totalLoss / float64(len(inputs))
|
|
return avgLoss, nil
|
|
}
|
|
|
|
// Train 完整的训练过程
|
|
func (t *Trainer) Train(
|
|
trainInputs []*Tensor,
|
|
trainTargets []*Tensor,
|
|
epochs int,
|
|
lossFn func(*Tensor, *Tensor) *Tensor,
|
|
verbose bool,
|
|
) error {
|
|
for epoch := 0; epoch < epochs; epoch++ {
|
|
avgLoss, err := t.TrainEpoch(trainInputs, trainTargets, lossFn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if verbose {
|
|
fmt.Printf("Epoch [%d/%d], Loss: %.6f\n", epoch+1, epochs, avgLoss)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Evaluate 评估模型性能
|
|
func (t *Trainer) Evaluate(testInputs []*Tensor, testTargets []*Tensor, lossFn func(*Tensor, *Tensor) *Tensor) (float64, error) {
|
|
var totalLoss float64
|
|
|
|
for i := 0; i < len(testInputs); i++ {
|
|
// 前向传播
|
|
output, err := t.Model.Forward(testInputs[i])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// 计算损失
|
|
loss := lossFn(output, testTargets[i])
|
|
lossVal, _ := loss.Data.Get(0)
|
|
totalLoss += lossVal
|
|
}
|
|
|
|
avgLoss := totalLoss / float64(len(testInputs))
|
|
return avgLoss, nil
|
|
} |