package gotensor

import (
	"encoding/json"
	"fmt"
	"os"
)

// Model is the interface a trainable model must satisfy.
type Model interface {
	// Forward runs the model on inputs and returns its output.
	Forward(inputs *Tensor) (*Tensor, error)
	// Parameters returns all parameters of the model.
	Parameters() []*Tensor
	// ZeroGrad resets the gradients of all parameters to zero.
	ZeroGrad()
}

// Sequential is a model that applies its layers one after another, in
// slice order.
type Sequential struct {
	Layers []Layer
}

// Layer is the interface a single layer of a Sequential model must satisfy.
type Layer interface {
	Forward(inputs *Tensor) (*Tensor, error)
	Parameters() []*Tensor
	ZeroGrad()
}

// Forward feeds inputs through every layer in order, passing each layer's
// output to the next one. It stops at the first layer that fails and returns
// that layer's error with a nil tensor. With no layers, inputs is returned
// unchanged.
func (s *Sequential) Forward(inputs *Tensor) (*Tensor, error) {
	output := inputs
	var err error
	for _, layer := range s.Layers {
		output, err = layer.Forward(output)
		if err != nil {
			return nil, err
		}
	}
	return output, nil
}

// Parameters returns the parameters of all layers, concatenated in layer
// order.
func (s *Sequential) Parameters() []*Tensor {
	var params []*Tensor
	for _, layer := range s.Layers {
		params = append(params, layer.Parameters()...)
	}
	return params
}

// ZeroGrad calls ZeroGrad on every layer.
func (s *Sequential) ZeroGrad() {
	for _, layer := range s.Layers {
		layer.ZeroGrad()
	}
}

// SaveModel writes the parameters of model to filepath as a JSON array
// holding one flat, row-major array of numbers per parameter, in the order
// returned by model.Parameters.
//
// Only rank-1 and rank-2 parameters are supported; any other rank, or a
// failure to read an element, yields an error before the file is created.
// Errors from creating, writing, or closing the file are returned as well.
func SaveModel(model Model, filepath string) error {
	params := model.Parameters()
	paramsData := make([][]float64, len(params))
	for i, param := range params {
		data, err := flattenParam(param)
		if err != nil {
			return fmt.Errorf("saving parameter %d: %w", i, err)
		}
		paramsData[i] = data
	}

	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	if err := json.NewEncoder(file).Encode(paramsData); err != nil {
		// The encode error is the one worth reporting; a close error on
		// this path would only obscure it.
		_ = file.Close()
		return fmt.Errorf("encoding model parameters to %q: %w", filepath, err)
	}
	// Close can report a delayed write failure, so its error must not be
	// dropped for a file that was just written.
	return file.Close()
}

// LoadModel reads parameters previously written by SaveModel from filepath
// and stores them into the parameters of model, matching them by position.
//
// It returns an error if the file cannot be opened or decoded, if the number
// of stored parameters differs from len(model.Parameters()), if a parameter
// has a rank other than 1 or 2, or if a stored array's length differs from
// the parameter's Size. All checks run before any value is written, so a
// rejected file leaves the model's parameters untouched.
func LoadModel(model Model, filepath string) error {
	file, err := os.Open(filepath)
	if err != nil {
		return err
	}
	defer file.Close()

	var paramsData [][]float64
	if err := json.NewDecoder(file).Decode(&paramsData); err != nil {
		return fmt.Errorf("decoding model parameters from %q: %w", filepath, err)
	}

	params := model.Parameters()
	if len(params) != len(paramsData) {
		return fmt.Errorf("parameter count mismatch: model has %d, %q has %d",
			len(params), filepath, len(paramsData))
	}

	// Validate everything first so that a bad file cannot leave the model
	// half-loaded.
	for i, param := range params {
		if err := checkParamData(param, paramsData[i]); err != nil {
			return fmt.Errorf("loading parameter %d: %w", i, err)
		}
	}
	for i, param := range params {
		fillParam(param, paramsData[i])
	}
	return nil
}

// flattenParam copies the elements of a rank-1 or rank-2 tensor into a new
// flat slice in row-major order.
func flattenParam(param *Tensor) ([]float64, error) {
	shape := param.Shape()
	data := make([]float64, param.Size())
	switch len(shape) {
	case 1:
		for idx := range data {
			v, err := param.Data.Get(idx)
			if err != nil {
				return nil, fmt.Errorf("reading element %d: %w", idx, err)
			}
			data[idx] = v
		}
	case 2:
		cols := shape[1]
		for idx := range data {
			v, err := param.Data.Get(idx/cols, idx%cols)
			if err != nil {
				return nil, fmt.Errorf("reading element (%d, %d): %w", idx/cols, idx%cols, err)
			}
			data[idx] = v
		}
	default:
		return nil, fmt.Errorf("unsupported tensor rank %d (only rank 1 and 2 are supported)", len(shape))
	}
	return data, nil
}

// checkParamData reports whether data can be stored into param: the rank
// must be 1 or 2 and the element counts must agree.
func checkParamData(param *Tensor, data []float64) error {
	shape := param.Shape()
	if rank := len(shape); rank != 1 && rank != 2 {
		return fmt.Errorf("unsupported tensor rank %d (only rank 1 and 2 are supported)", rank)
	}
	if size := param.Size(); len(data) != size {
		return fmt.Errorf("size mismatch: tensor has %d elements, file has %d", size, len(data))
	}
	// Guards the idx/cols division in fillParam against an inconsistent
	// tensor that reports elements but zero columns.
	if len(shape) == 2 && shape[1] <= 0 && len(data) > 0 {
		return fmt.Errorf("invalid column count %d for a non-empty rank-2 tensor", shape[1])
	}
	return nil
}

// fillParam writes data, laid out in row-major order, into a rank-1 or
// rank-2 tensor. The caller must have validated the pair with
// checkParamData.
//
// NOTE(review): Set is called as a plain statement, as in the original code;
// its signature is not visible here, so any result it returns is not
// inspected — confirm against the Data type.
func fillParam(param *Tensor, data []float64) {
	shape := param.Shape()
	switch len(shape) {
	case 1:
		for idx, val := range data {
			param.Data.Set(val, idx)
		}
	case 2:
		cols := shape[1]
		for idx, val := range data {
			param.Data.Set(val, idx/cols, idx%cols)
		}
	}
}