gotensor/README.md

138 lines
3.5 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# gotensor
gotensor 是一个用 Go 语言实现的张量计算库,专注于为 Go 开发者提供高效的数值计算能力,支持自动微分和反向传播,适用于构建轻量级机器学习模型。
## 功能特性
- 基本张量运算:加法、减法、乘法、矩阵乘法
- 数乘、转置等张量操作
- 自动微分与反向传播机制
- 激活函数Sigmoid、ReLU、Softmax
- 卷积与池化操作Conv2D、MaxPool2D、AvgPool2D
- 神经网络层与损失函数Flatten、CrossEntropy、MSE
- 多种初始化方式:零张量、单位矩阵、随机张量
- 模型定义和训练支持
- 模型保存和加载
- 多种优化器SGD、Adam
## 安装
```bash
go get git.kingecg.top/kingecg/gotensor
```
## 快速开始
```go
package main
import (
"fmt"
"git.kingecg.top/kingecg/gotensor"
)
func main() {
// 创建张量
tensor1, _ := gotensor.NewTensor([]float64{1, 2, 3}, []int{1, 3})
tensor2, _ := gotensor.NewTensor([]float64{4, 5, 6}, []int{3, 1})
// 执行矩阵乘法
result, _ := tensor1.MatMul(tensor2)
fmt.Println(result)
}
```
## 模型训练示例
```go
package main
import (
"fmt"
"git.kingecg.top/kingecg/gotensor"
)
func main() {
// 创建模型(例如:简单的线性回归模型)
// 这里可以使用Sequential模型或者自定义模型
// 定义一些示例数据
input, _ := gotensor.NewTensor([]float64{1, 2, 3, 4}, []int{2, 2})
target, _ := gotensor.NewTensor([]float64{5, 6, 7, 8}, []int{2, 2})
// 创建模型参数
weights, _ := gotensor.NewTensor([]float64{0.5, 0.3, 0.2, 0.4}, []int{2, 2})
// 定义模型(这里简化为单个张量,实际中会是更复杂的结构)
// ...
// 定义优化器
optimizer := gotensor.NewSGD([]*gotensor.Tensor{weights}, 0.01)
// 创建训练器
trainer := gotensor.NewTrainer(nil, optimizer) // 需要传入实际模型
// 开始训练
// trainer.Train(trainInputs, trainTargets, epochs, lossFn, true)
fmt.Println("Training example")
}
```
## API 文档
### 张量操作
- `NewTensor(data []float64, shape []int)`: 创建新张量
- `Add(other *Tensor)`: 张量加法
- `Subtract(other *Tensor)`: 张量减法
- `Multiply(other *Tensor)`: 张量逐元素乘法
- `MatMul(other *Tensor)`: 矩阵乘法
- `Scale(factor float64)`: 张量数乘
- `Sigmoid()`: Sigmoid激活函数
- `ReLU()`: ReLU激活函数
- `Softmax()`: Softmax函数
- `Backward()`: 反向传播
### 神经网络层
- `Flatten()`: 展平张量
- `CrossEntropy(target *Tensor)`: 交叉熵损失
- `MeanSquaredError(target *Tensor)`: 均方误差损失
### 模型定义
- `Sequential`: 序列模型
- `Model` 接口: 模型的基本接口
- `SaveModel(model Model, filepath string)`: 保存模型
- `LoadModel(model Model, filepath string)`: 加载模型
### 优化器
- `SGD`: 随机梯度下降
- `Adam`: Adam优化算法
### 训练器
- `Trainer`: 训练管理器
- `NewTrainer(model Model, optimizer Optimizer)`: 创建训练器
- `Train(...)`: 执行训练
- `Evaluate(...)`: 评估模型
## 示例
项目包含多个示例:
- `examples/basic_operation`: 基本张量运算示例
- `examples/autograd`: 自动微分示例
- `examples/linear_regression`: 线性回归示例
- `examples/cnn_example.go`: 卷积神经网络示例
## 贡献
欢迎提交 Issue 和 Pull Request 来帮助改进 gotensor
## 许可证
本项目使用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件。