```
docs: 添加项目文档和示例代码 添加了完整的项目文档,包括README.md文件,详细介绍gotensor库的功能特性、 安装方法和使用示例。同时添加了多个示例程序展示基本运算、自动微分和线性 回归功能,并完善了测试用例。 - 添加LICENSE文件(MIT许可证) - 添加README.md项目文档 - 添加基本运算示例 - 添加自动微分示例 - 添加线性回归示例 - 添加单元测试文件 ```
This commit is contained in:
parent
2da5bc6ece
commit
9d6d4bdf56
|
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 kingecg
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
# gotensor
|
||||||
|
|
||||||
|
gotensor 是一个用 Go 语言编写的张量计算库,提供了基本的张量运算、自动微分和反向传播功能。该项目旨在为 Go 语言开发者提供一个高效、易用的张量计算工具。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- 基本张量运算:加法、减法、乘法、矩阵乘法等
|
||||||
|
- 张量操作:数乘、转置等
|
||||||
|
- 自动微分和反向传播
|
||||||
|
- 支持多种初始化方式:零张量、单位矩阵、随机张量等
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get git.kingecg.top/kingecg/gotensor
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"git.kingecg.top/kingecg/gotensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 创建两个2x2的张量
|
||||||
|
t1_data := []float64{1, 2, 3, 4}
|
||||||
|
t1_shape := []int{2, 2}
|
||||||
|
t1, err := gotensor.NewTensor(t1_data, t1_shape)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t2_data := []float64{5, 6, 7, 8}
|
||||||
|
t2, err := gotensor.NewTensor(t2_data, t1_shape)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行加法运算
|
||||||
|
result, err := t1.Add(t2)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("结果:\n%s\n", result.String())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 示例
|
||||||
|
|
||||||
|
项目包含多个示例,展示如何使用 gotensor:
|
||||||
|
|
||||||
|
- [基本运算示例](examples/basic_operations.go):展示基本的张量运算
|
||||||
|
- [自动微分示例](examples/autograd_example.go):演示自动微分和反向传播
|
||||||
|
- [线性回归示例](examples/linear_regression.go):使用 gotensor 实现简单的线性回归
|
||||||
|
|
||||||
|
运行示例:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 基本运算示例
|
||||||
|
go run examples/basic_operations.go
|
||||||
|
|
||||||
|
# 自动微分示例
|
||||||
|
go run examples/autograd_example.go
|
||||||
|
|
||||||
|
# 线性回归示例
|
||||||
|
go run examples/linear_regression.go
|
||||||
|
```
|
||||||
|
|
||||||
|
## API 文档
|
||||||
|
|
||||||
|
### 创建张量
|
||||||
|
|
||||||
|
- `NewTensor(data []float64, shape []int)` - 创建新的张量
|
||||||
|
- `NewZeros(shape []int)` - 创建零张量
|
||||||
|
- `NewOnes(shape []int)` - 创建全一张量
|
||||||
|
- `NewIdentity(size int)` - 创建单位矩阵
|
||||||
|
|
||||||
|
### 张量运算
|
||||||
|
|
||||||
|
- `Add(other *Tensor)` - 张量加法
|
||||||
|
- `Subtract(other *Tensor)` - 张量减法
|
||||||
|
- `Multiply(other *Tensor)` - 张量逐元素乘法
|
||||||
|
- `MatMul(other *Tensor)` - 矩阵乘法
|
||||||
|
- `Scale(factor float64)` - 数乘
|
||||||
|
|
||||||
|
### 其他方法
|
||||||
|
|
||||||
|
- `ZeroGrad()` - 将梯度置零
|
||||||
|
- `Shape()` - 返回张量形状
|
||||||
|
- `Size()` - 返回张量大小
|
||||||
|
- `Get(indices ...int)` - 获取指定位置的值
|
||||||
|
- `Set(value float64, indices ...int)` - 设置指定位置的值
|
||||||
|
- `Backward()` - 执行反向传播
|
||||||
|
|
||||||
|
## 测试
|
||||||
|
|
||||||
|
运行测试:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go test
|
||||||
|
```
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
本项目使用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件。
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"git.kingecg.top/kingecg/gotensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("=== 自动微分和反向传播示例 ===")
|
||||||
|
|
||||||
|
// 创建一些张量,用于模拟简单的计算图
|
||||||
|
// 例如: z = (w * x) + b,其中w, x, b是张量
|
||||||
|
w_data := []float64{2.0}
|
||||||
|
w, err := gotensor.NewTensor(w_data, []int{1})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
x_data := []float64{3.0}
|
||||||
|
x, err := gotensor.NewTensor(x_data, []int{1})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b_data := []float64{1.0}
|
||||||
|
b, err := gotensor.NewTensor(b_data, []int{1})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("w = %s\n", w.String())
|
||||||
|
fmt.Printf("x = %s\n", x.String())
|
||||||
|
fmt.Printf("b = %s\n", b.String())
|
||||||
|
|
||||||
|
// 计算 y = w * x
|
||||||
|
y, err := w.Multiply(x)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("y = w * x = %s\n", y.String())
|
||||||
|
|
||||||
|
// 计算 z = y + b
|
||||||
|
z, err := y.Add(b)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("z = y + b = %s\n", z.String())
|
||||||
|
|
||||||
|
// 设置z的梯度为1 (通常这是损失函数,所以梯度是1)
|
||||||
|
// 在实际应用中,这通常是损失函数的梯度
|
||||||
|
z.ZeroGrad() // 确保梯度初始化为零
|
||||||
|
// 在这里,我们直接操作梯度,将输出节点的梯度设为1
|
||||||
|
// 为了演示,我们手动设置输出梯度
|
||||||
|
// 由于z是最终输出,我们设置其梯度为1
|
||||||
|
|
||||||
|
// 在实际自动微分中,我们会从最终输出开始反向传播
|
||||||
|
// 手动设置输出梯度为1,模拟损失函数的梯度
|
||||||
|
one, _ := gotensor.NewOnes(z.Shape())
|
||||||
|
z.Grad = one.Data
|
||||||
|
|
||||||
|
// 执行反向传播
|
||||||
|
z.Backward()
|
||||||
|
|
||||||
|
fmt.Printf("反向传播后:\n")
|
||||||
|
fmt.Printf("w的梯度: %s\n", w.String()) // 这将显示w的梯度
|
||||||
|
fmt.Printf("x的梯度: %s\n", x.String()) // 这将显示x的梯度
|
||||||
|
fmt.Printf("b的梯度: %s\n", b.String()) // 这将显示b的梯度
|
||||||
|
|
||||||
|
// 验证梯度计算
|
||||||
|
// 对于 z = w*x + b:
|
||||||
|
// dz/dw = x = 3
|
||||||
|
// dz/dx = w = 2
|
||||||
|
// dz/db = 1
|
||||||
|
fmt.Println("\n=== 梯度验证 ===")
|
||||||
|
w_grad, _ := w.Grad.Get(0)
|
||||||
|
x_grad, _ := x.Grad.Get(0)
|
||||||
|
b_grad, _ := b.Grad.Get(0)
|
||||||
|
|
||||||
|
fmt.Printf("w的梯度 (预期x=3.0): %.2f\n", w_grad)
|
||||||
|
fmt.Printf("x的梯度 (预期w=2.0): %.2f\n", x_grad)
|
||||||
|
fmt.Printf("b的梯度 (预期1.0): %.2f\n", b_grad)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"git.kingecg.top/kingecg/gotensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("=== 基本运算示例 ===")
|
||||||
|
|
||||||
|
// 创建两个2x2的张量
|
||||||
|
t1_data := []float64{1, 2, 3, 4}
|
||||||
|
t1_shape := []int{2, 2}
|
||||||
|
t1, err := gotensor.NewTensor(t1_data, t1_shape)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t2_data := []float64{5, 6, 7, 8}
|
||||||
|
t2, err := gotensor.NewTensor(t2_data, t1_shape)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("张量1:\n%s\n", t1.String())
|
||||||
|
fmt.Printf("张量2:\n%s\n", t2.String())
|
||||||
|
|
||||||
|
// 加法运算
|
||||||
|
add_result, err := t1.Add(t2)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("加法结果 (t1 + t2):\n%s\n", add_result.String())
|
||||||
|
|
||||||
|
// 减法运算
|
||||||
|
sub_result, err := t1.Subtract(t2)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("减法结果 (t1 - t2):\n%s\n", sub_result.String())
|
||||||
|
|
||||||
|
// 逐元素乘法
|
||||||
|
mul_result, err := t1.Multiply(t2)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("逐元素乘法结果 (t1 * t2):\n%s\n", mul_result.String())
|
||||||
|
|
||||||
|
// 数乘
|
||||||
|
scale_result := t1.Scale(2.0)
|
||||||
|
fmt.Printf("数乘结果 (t1 * 2):\n%s\n", scale_result.String())
|
||||||
|
|
||||||
|
// 矩阵乘法
|
||||||
|
matmul_result, err := t1.MatMul(t2)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("矩阵乘法结果 (t1 @ t2):\n%s\n", matmul_result.String())
|
||||||
|
|
||||||
|
// 创建零张量和单位矩阵
|
||||||
|
zeros, err := gotensor.NewZeros([]int{2, 3})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("2x3零张量:\n%s\n", zeros.String())
|
||||||
|
|
||||||
|
identity, err := gotensor.NewIdentity(3)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("3x3单位矩阵:\n%s\n", identity.String())
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,163 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"git.kingecg.top/kingecg/gotensor"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("=== 线性回归示例 ===")
|
||||||
|
|
||||||
|
// 创建简单的线性回归数据: y = 2*x + 3
|
||||||
|
// X = [[1, x1], [1, x2], ...] (添加偏置项)
|
||||||
|
// w = [b, w1] (偏置和权重)
|
||||||
|
|
||||||
|
// 准备训练数据
|
||||||
|
x_data := []float64{1.0, 2.0, 3.0, 4.0, 5.0}
|
||||||
|
X, err := gotensor.NewTensor(x_data, []int{5, 1})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建设计矩阵,添加偏置列
|
||||||
|
X_with_bias, err := addBiasColumn(X)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 真实的y值: y = 2*x + 3
|
||||||
|
y_data := []float64{5.0, 7.0, 9.0, 11.0, 13.0} // 2*x + 3
|
||||||
|
Y, err := gotensor.NewTensor(y_data, []int{5, 1})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化权重 [b, w] = [0, 0]
|
||||||
|
w_data := []float64{0.0, 0.0}
|
||||||
|
W, err := gotensor.NewTensor(w_data, []int{2, 1})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("训练数据 X (带偏置):\n%s\n", X_with_bias.String())
|
||||||
|
fmt.Printf("目标值 Y:\n%s\n", Y.String())
|
||||||
|
fmt.Printf("初始权重 W:\n%s\n", W.String())
|
||||||
|
|
||||||
|
// 训练参数
|
||||||
|
learning_rate := 0.01
|
||||||
|
epochs := 100
|
||||||
|
|
||||||
|
for epoch := 0; epoch < epochs; epoch++ {
|
||||||
|
// 前向传播: Y_pred = X * W
|
||||||
|
Y_pred, err := X_with_bias.MatMul(W)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算损失: MSE = (1/n) * sum((Y_pred - Y)^2)
|
||||||
|
diff, err := Y_pred.Subtract(Y)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算平方差
|
||||||
|
squared_diff, err := diff.Multiply(diff)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算平均值(简单地用一个常数缩放)
|
||||||
|
// 在完整实现中,我们需要一个求平均的函数,这里简化处理
|
||||||
|
loss_val := sumTensor(squared_diff) / float64(squared_diff.Size())
|
||||||
|
|
||||||
|
if epoch % 20 == 0 {
|
||||||
|
fmt.Printf("Epoch %d, Loss: %.6f\n", epoch, loss_val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算梯度
|
||||||
|
// MSE梯度: dL/dW = (2/n) * X^T * (Y_pred - Y)
|
||||||
|
X_transpose, err := X_with_bias.Data.Transpose()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
X_T := &gotensor.Tensor{Data: X_transpose}
|
||||||
|
|
||||||
|
grad_W_intermediate, err := X_T.MatMul(diff)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
grad_W := grad_W_intermediate.Scale(2.0 / float64(X_with_bias.Size()))
|
||||||
|
|
||||||
|
// 更新权重: W = W - lr * grad_W
|
||||||
|
grad_update := grad_W.Scale(learning_rate)
|
||||||
|
W, err = W.Subtract(grad_update)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\n训练后的权重 W:\n%s\n", W.String())
|
||||||
|
|
||||||
|
// 预测
|
||||||
|
Y_pred, err := X_with_bias.MatMul(W)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("预测值:\n%s\n", Y_pred.String())
|
||||||
|
fmt.Printf("真实值:\n%s\n", Y.String())
|
||||||
|
|
||||||
|
// 计算最终误差
|
||||||
|
diff, err := Y_pred.Subtract(Y)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
squared_diff, err := diff.Multiply(diff)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
final_loss := sumTensor(squared_diff) / float64(squared_diff.Size())
|
||||||
|
fmt.Printf("最终损失: %.6f\n", final_loss)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:为输入矩阵添加偏置列
|
||||||
|
func addBiasColumn(X *gotensor.Tensor) (*gotensor.Tensor, error) {
|
||||||
|
shape := X.Shape()
|
||||||
|
rows := shape[0]
|
||||||
|
|
||||||
|
// 这里需要实现列拼接,暂时使用简单的拼接方法
|
||||||
|
// 实际上我们需要一个拼接函数,这里用一个近似的方法
|
||||||
|
// 由于当前API没有提供拼接功能,我们手动构造
|
||||||
|
result_data := make([]float64, rows*2)
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
result_data[i*2] = 1.0 // 偏置项
|
||||||
|
X_val, _ := X.Data.Get(i, 0)
|
||||||
|
result_data[i*2+1] = X_val
|
||||||
|
}
|
||||||
|
|
||||||
|
return gotensor.NewTensor(result_data, []int{rows, 2})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:计算张量所有元素的和
|
||||||
|
func sumTensor(t *gotensor.Tensor) float64 {
|
||||||
|
sum := 0.0
|
||||||
|
shape := t.Shape()
|
||||||
|
if len(shape) == 2 {
|
||||||
|
rows, cols := shape[0], shape[1]
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
for j := 0; j < cols; j++ {
|
||||||
|
val, _ := t.Data.Get(i, j)
|
||||||
|
sum += val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
size := t.Size()
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
val, _ := t.Data.Get(i)
|
||||||
|
sum += val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return math.Abs(sum) // 使用绝对值避免负数和
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,192 @@
|
||||||
|
package gotensor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewTensor(t *testing.T) {
|
||||||
|
data := []float64{1, 2, 3, 4}
|
||||||
|
shape := []int{2, 2}
|
||||||
|
|
||||||
|
tensor, err := NewTensor(data, shape)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("创建Tensor时发生错误: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tensor == nil {
|
||||||
|
t.Error("创建的Tensor不应为nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||||||
|
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tensor.Size() != 4 {
|
||||||
|
t.Errorf("大小不匹配,期望: 4, 实际: %d", tensor.Size())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTensorAdd(t *testing.T) {
|
||||||
|
data1 := []float64{1, 2, 3, 4}
|
||||||
|
data2 := []float64{5, 6, 7, 8}
|
||||||
|
shape := []int{2, 2}
|
||||||
|
|
||||||
|
t1, err := NewTensor(data1, shape)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t2, err := NewTensor(data2, shape)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := t1.Add(t2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_data := []float64{6, 8, 10, 12}
|
||||||
|
|
||||||
|
for i := 0; i < result.Size(); i++ {
|
||||||
|
expected_val := expected_data[i]
|
||||||
|
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual_val != expected_val {
|
||||||
|
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTensorMultiply(t *testing.T) {
|
||||||
|
data1 := []float64{1, 2, 3, 4}
|
||||||
|
data2 := []float64{2, 3, 4, 5}
|
||||||
|
shape := []int{2, 2}
|
||||||
|
|
||||||
|
t1, err := NewTensor(data1, shape)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t2, err := NewTensor(data2, shape)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := t1.Multiply(t2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_data := []float64{2, 6, 12, 20} // 逐元素相乘
|
||||||
|
|
||||||
|
for i := 0; i < result.Size(); i++ {
|
||||||
|
expected_val := expected_data[i]
|
||||||
|
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual_val != expected_val {
|
||||||
|
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTensorScale(t *testing.T) {
|
||||||
|
data := []float64{1, 2, 3, 4}
|
||||||
|
shape := []int{2, 2}
|
||||||
|
|
||||||
|
tensor, err := NewTensor(data, shape)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
factor := 3.0
|
||||||
|
result := tensor.Scale(factor)
|
||||||
|
|
||||||
|
expected_data := []float64{3, 6, 9, 12}
|
||||||
|
|
||||||
|
for i := 0; i < result.Size(); i++ {
|
||||||
|
expected_val := expected_data[i]
|
||||||
|
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual_val != expected_val {
|
||||||
|
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTensorShapeAndSize(t *testing.T) {
|
||||||
|
data := []float64{1, 2, 3, 4, 5, 6}
|
||||||
|
shape := []int{2, 3}
|
||||||
|
|
||||||
|
tensor, err := NewTensor(data, shape)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||||||
|
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tensor.Size() != 6 {
|
||||||
|
t.Errorf("大小不匹配,期望: 6, 实际: %d", tensor.Size())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewZeros(t *testing.T) {
|
||||||
|
shape := []int{2, 3}
|
||||||
|
|
||||||
|
tensor, err := NewZeros(shape)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||||||
|
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < tensor.Size(); i++ {
|
||||||
|
val, err := tensor.Get(i/3, i%3) // 2x3矩阵
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val != 0.0 {
|
||||||
|
t.Errorf("位置[%d,%d]的值不为0,实际: %f", i/3, i%3, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewOnes(t *testing.T) {
|
||||||
|
shape := []int{2, 2}
|
||||||
|
|
||||||
|
tensor, err := NewOnes(shape)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||||||
|
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < tensor.Size(); i++ {
|
||||||
|
val, err := tensor.Get(i/2, i%2) // 2x2矩阵
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if val != 1.0 {
|
||||||
|
t.Errorf("位置[%d,%d]的值不为1,实际: %f", i/2, i%2, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue