package gotensor

import (
	"fmt"
)

// Trainer bundles a model with the optimizer that updates it and drives the
// training and evaluation loops.
type Trainer struct {
	Model     Model
	Optimizer Optimizer
}

// NewTrainer returns a Trainer that trains model using optimizer.
func NewTrainer(model Model, optimizer Optimizer) *Trainer {
	return &Trainer{
		Model:     model,
		Optimizer: optimizer,
	}
}

// checkTrainerData reports an error when a dataset cannot be iterated safely:
// the loss function is nil, there are no samples, or the number of inputs and
// targets differ. Catching these up front avoids a nil call, a division by
// zero when averaging, and an index-out-of-range panic respectively.
func checkTrainerData(inputs, targets []*Tensor, lossFn func(*Tensor, *Tensor) *Tensor) error {
	switch {
	case lossFn == nil:
		return fmt.Errorf("loss function must not be nil")
	case len(inputs) == 0:
		return fmt.Errorf("dataset is empty")
	case len(inputs) != len(targets):
		return fmt.Errorf("got %d inputs but %d targets", len(inputs), len(targets))
	}
	return nil
}

// sampleLoss runs the forward pass for a single sample and evaluates lossFn on
// the model output and target. It returns the loss tensor together with the
// value stored at index 0 of its data. idx is used only for error messages.
func (t *Trainer) sampleLoss(idx int, input, target *Tensor, lossFn func(*Tensor, *Tensor) *Tensor) (*Tensor, float64, error) {
	output, err := t.Model.Forward(input)
	if err != nil {
		return nil, 0, fmt.Errorf("forward pass for sample %d: %w", idx, err)
	}

	loss := lossFn(output, target)
	if loss == nil {
		return nil, 0, fmt.Errorf("loss function returned nil for sample %d", idx)
	}

	// NOTE(review): the second result of Data.Get is discarded, as in the
	// original code. Its type is not visible from this file; if it is an
	// error it should be checked and returned here — confirm against Tensor.
	lossVal, _ := loss.Data.Get(0)
	return loss, lossVal, nil
}

// TrainEpoch performs one pass over the dataset, updating the model after
// every sample.
//
// For each (input, target) pair it runs the forward pass, evaluates lossFn,
// calls Backward on the loss, then Optimizer.Step followed by
// Optimizer.ZeroGrad. It returns the mean of the per-sample loss values.
//
// An error is returned if lossFn is nil, the dataset is empty, inputs and
// targets differ in length, the forward pass fails, or lossFn returns nil.
// On error the returned loss is 0; samples processed before the failure have
// already been applied to the model.
func (t *Trainer) TrainEpoch(inputs []*Tensor, targets []*Tensor, lossFn func(*Tensor, *Tensor) *Tensor) (float64, error) {
	if err := checkTrainerData(inputs, targets, lossFn); err != nil {
		return 0, fmt.Errorf("train epoch: %w", err)
	}

	var totalLoss float64
	for i, input := range inputs {
		loss, lossVal, err := t.sampleLoss(i, input, targets[i], lossFn)
		if err != nil {
			return 0, fmt.Errorf("train epoch: %w", err)
		}
		totalLoss += lossVal

		// Backpropagate, apply the update, then clear the gradients so they
		// are not carried over into the next sample.
		loss.Backward()
		t.Optimizer.Step()
		t.Optimizer.ZeroGrad()
	}

	return totalLoss / float64(len(inputs)), nil
}

// Train runs TrainEpoch epochs times over the given dataset.
//
// When verbose is true, the average loss of every epoch is printed to
// standard output. Training stops at the first epoch that returns an error,
// and that error is returned wrapped with the epoch number. A non-positive
// epochs value performs no work and returns nil.
func (t *Trainer) Train(
	trainInputs []*Tensor,
	trainTargets []*Tensor,
	epochs int,
	lossFn func(*Tensor, *Tensor) *Tensor,
	verbose bool,
) error {
	for epoch := 0; epoch < epochs; epoch++ {
		avgLoss, err := t.TrainEpoch(trainInputs, trainTargets, lossFn)
		if err != nil {
			return fmt.Errorf("epoch %d/%d: %w", epoch+1, epochs, err)
		}

		if verbose {
			fmt.Printf("Epoch [%d/%d], Loss: %.6f\n", epoch+1, epochs, avgLoss)
		}
	}

	return nil
}

// Evaluate computes the mean loss of the model over a test dataset.
//
// For each (input, target) pair it runs the forward pass and evaluates
// lossFn; it does not call Backward or touch the optimizer.
//
// An error is returned if lossFn is nil, the dataset is empty, inputs and
// targets differ in length, the forward pass fails, or lossFn returns nil.
// On error the returned loss is 0.
func (t *Trainer) Evaluate(testInputs []*Tensor, testTargets []*Tensor, lossFn func(*Tensor, *Tensor) *Tensor) (float64, error) {
	if err := checkTrainerData(testInputs, testTargets, lossFn); err != nil {
		return 0, fmt.Errorf("evaluate: %w", err)
	}

	var totalLoss float64
	for i, input := range testInputs {
		_, lossVal, err := t.sampleLoss(i, input, testTargets[i], lossFn)
		if err != nil {
			return 0, fmt.Errorf("evaluate: %w", err)
		}
		totalLoss += lossVal
	}

	return totalLoss / float64(len(testInputs)), nil
}