125 lines
2.4 KiB
Go
125 lines
2.4 KiB
Go
package gotensor
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"os"
)
|
|
|
|
// Model 模型接口定义
|
|
type Model interface {
|
|
Forward(inputs *Tensor) (*Tensor, error)
|
|
Parameters() []*Tensor // 获取模型所有参数
|
|
ZeroGrad() // 将所有参数的梯度清零
|
|
}
|
|
|
|
// Sequential 序列模型,按顺序执行层
|
|
type Sequential struct {
|
|
Layers []Layer
|
|
}
|
|
|
|
// Layer 接口定义
|
|
type Layer interface {
|
|
Forward(inputs *Tensor) (*Tensor, error)
|
|
Parameters() []*Tensor
|
|
ZeroGrad()
|
|
}
|
|
|
|
// Forward 实现前向传播
|
|
func (s *Sequential) Forward(inputs *Tensor) (*Tensor, error) {
|
|
output := inputs
|
|
var err error
|
|
|
|
for _, layer := range s.Layers {
|
|
output, err = layer.Forward(output)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
// Parameters 获取模型所有参数
|
|
func (s *Sequential) Parameters() []*Tensor {
|
|
var params []*Tensor
|
|
for _, layer := range s.Layers {
|
|
params = append(params, layer.Parameters()...)
|
|
}
|
|
return params
|
|
}
|
|
|
|
// ZeroGrad 将所有参数梯度清零
|
|
func (s *Sequential) ZeroGrad() {
|
|
for _, layer := range s.Layers {
|
|
layer.ZeroGrad()
|
|
}
|
|
}
|
|
|
|
// SaveModel 保存模型参数到文件
|
|
func SaveModel(model Model, filepath string) error {
|
|
params := model.Parameters()
|
|
paramsData := make([][]float64, len(params))
|
|
|
|
for i, param := range params {
|
|
shape := param.Shape()
|
|
size := param.Size()
|
|
data := make([]float64, size)
|
|
|
|
for idx := 0; idx < size; idx++ {
|
|
if len(shape) == 1 {
|
|
data[idx], _ = param.Data.Get(idx)
|
|
} else if len(shape) == 2 {
|
|
cols := shape[1]
|
|
data[idx], _ = param.Data.Get(idx/cols, idx%cols)
|
|
}
|
|
}
|
|
paramsData[i] = data
|
|
}
|
|
|
|
file, err := os.Create(filepath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer file.Close()
|
|
|
|
return json.NewEncoder(file).Encode(paramsData)
|
|
}
|
|
|
|
// LoadModel 从文件加载模型参数
|
|
func LoadModel(model Model, filepath string) error {
|
|
file, err := os.Open(filepath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer file.Close()
|
|
|
|
var paramsData [][]float64
|
|
err = json.NewDecoder(file).Decode(¶msData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
params := model.Parameters()
|
|
if len(params) != len(paramsData) {
|
|
return nil // 参数数量不匹配,返回错误
|
|
}
|
|
|
|
for i, param := range params {
|
|
data := paramsData[i]
|
|
shape := param.Shape()
|
|
|
|
if len(shape) == 1 {
|
|
for idx, val := range data {
|
|
param.Data.Set(val, idx)
|
|
}
|
|
} else if len(shape) == 2 {
|
|
cols := shape[1]
|
|
for idx, val := range data {
|
|
param.Data.Set(val, idx/cols, idx%cols)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|