```
docs: 添加项目文档和示例代码 添加了完整的项目文档,包括README.md文件,详细介绍gotensor库的功能特性、 安装方法和使用示例。同时添加了多个示例程序展示基本运算、自动微分和线性 回归功能,并完善了测试用例。 - 添加LICENSE文件(MIT许可证) - 添加README.md项目文档 - 添加基本运算示例 - 添加自动微分示例 - 添加线性回归示例 - 添加单元测试文件 ```
This commit is contained in:
parent
2da5bc6ece
commit
9d6d4bdf56
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025 kingecg
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
# gotensor
|
||||
|
||||
gotensor 是一个用 Go 语言编写的张量计算库,提供了基本的张量运算、自动微分和反向传播功能。该项目旨在为 Go 语言开发者提供一个高效、易用的张量计算工具。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 基本张量运算:加法、减法、乘法、矩阵乘法等
|
||||
- 张量操作:数乘、转置等
|
||||
- 自动微分和反向传播
|
||||
- 支持多种初始化方式:零张量、单位矩阵、随机张量等
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
go get git.kingecg.top/kingecg/gotensor
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.kingecg.top/kingecg/gotensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建两个2x2的张量
|
||||
t1_data := []float64{1, 2, 3, 4}
|
||||
t1_shape := []int{2, 2}
|
||||
t1, err := gotensor.NewTensor(t1_data, t1_shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t2_data := []float64{5, 6, 7, 8}
|
||||
t2, err := gotensor.NewTensor(t2_data, t1_shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 执行加法运算
|
||||
result, err := t1.Add(t2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("结果:\n%s\n", result.String())
|
||||
}
|
||||
```
|
||||
|
||||
## 示例
|
||||
|
||||
项目包含多个示例,展示如何使用 gotensor:
|
||||
|
||||
- [基本运算示例](examples/basic_operations.go):展示基本的张量运算
|
||||
- [自动微分示例](examples/autograd_example.go):演示自动微分和反向传播
|
||||
- [线性回归示例](examples/linear_regression.go):使用 gotensor 实现简单的线性回归
|
||||
|
||||
运行示例:
|
||||
|
||||
```bash
|
||||
# 基本运算示例
|
||||
go run examples/basic_operations.go
|
||||
|
||||
# 自动微分示例
|
||||
go run examples/autograd_example.go
|
||||
|
||||
# 线性回归示例
|
||||
go run examples/linear_regression.go
|
||||
```
|
||||
|
||||
## API 文档
|
||||
|
||||
### 创建张量
|
||||
|
||||
- `NewTensor(data []float64, shape []int)` - 创建新的张量
|
||||
- `NewZeros(shape []int)` - 创建零张量
|
||||
- `NewOnes(shape []int)` - 创建全一张量
|
||||
- `NewIdentity(size int)` - 创建单位矩阵
|
||||
|
||||
### 张量运算
|
||||
|
||||
- `Add(other *Tensor)` - 张量加法
|
||||
- `Subtract(other *Tensor)` - 张量减法
|
||||
- `Multiply(other *Tensor)` - 张量逐元素乘法
|
||||
- `MatMul(other *Tensor)` - 矩阵乘法
|
||||
- `Scale(factor float64)` - 数乘
|
||||
|
||||
### 其他方法
|
||||
|
||||
- `ZeroGrad()` - 将梯度置零
|
||||
- `Shape()` - 返回张量形状
|
||||
- `Size()` - 返回张量大小
|
||||
- `Get(indices ...int)` - 获取指定位置的值
|
||||
- `Set(value float64, indices ...int)` - 设置指定位置的值
|
||||
- `Backward()` - 执行反向传播
|
||||
|
||||
## 测试
|
||||
|
||||
运行测试:
|
||||
|
||||
```bash
|
||||
go test
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目使用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件。
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.kingecg.top/kingecg/gotensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("=== 自动微分和反向传播示例 ===")
|
||||
|
||||
// 创建一些张量,用于模拟简单的计算图
|
||||
// 例如: z = (w * x) + b,其中w, x, b是张量
|
||||
w_data := []float64{2.0}
|
||||
w, err := gotensor.NewTensor(w_data, []int{1})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
x_data := []float64{3.0}
|
||||
x, err := gotensor.NewTensor(x_data, []int{1})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
b_data := []float64{1.0}
|
||||
b, err := gotensor.NewTensor(b_data, []int{1})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("w = %s\n", w.String())
|
||||
fmt.Printf("x = %s\n", x.String())
|
||||
fmt.Printf("b = %s\n", b.String())
|
||||
|
||||
// 计算 y = w * x
|
||||
y, err := w.Multiply(x)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("y = w * x = %s\n", y.String())
|
||||
|
||||
// 计算 z = y + b
|
||||
z, err := y.Add(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("z = y + b = %s\n", z.String())
|
||||
|
||||
// 设置z的梯度为1 (通常这是损失函数,所以梯度是1)
|
||||
// 在实际应用中,这通常是损失函数的梯度
|
||||
z.ZeroGrad() // 确保梯度初始化为零
|
||||
// 在这里,我们直接操作梯度,将输出节点的梯度设为1
|
||||
// 为了演示,我们手动设置输出梯度
|
||||
// 由于z是最终输出,我们设置其梯度为1
|
||||
|
||||
// 在实际自动微分中,我们会从最终输出开始反向传播
|
||||
// 手动设置输出梯度为1,模拟损失函数的梯度
|
||||
one, _ := gotensor.NewOnes(z.Shape())
|
||||
z.Grad = one.Data
|
||||
|
||||
// 执行反向传播
|
||||
z.Backward()
|
||||
|
||||
fmt.Printf("反向传播后:\n")
|
||||
fmt.Printf("w的梯度: %s\n", w.String()) // 这将显示w的梯度
|
||||
fmt.Printf("x的梯度: %s\n", x.String()) // 这将显示x的梯度
|
||||
fmt.Printf("b的梯度: %s\n", b.String()) // 这将显示b的梯度
|
||||
|
||||
// 验证梯度计算
|
||||
// 对于 z = w*x + b:
|
||||
// dz/dw = x = 3
|
||||
// dz/dx = w = 2
|
||||
// dz/db = 1
|
||||
fmt.Println("\n=== 梯度验证 ===")
|
||||
w_grad, _ := w.Grad.Get(0)
|
||||
x_grad, _ := x.Grad.Get(0)
|
||||
b_grad, _ := b.Grad.Get(0)
|
||||
|
||||
fmt.Printf("w的梯度 (预期x=3.0): %.2f\n", w_grad)
|
||||
fmt.Printf("x的梯度 (预期w=2.0): %.2f\n", x_grad)
|
||||
fmt.Printf("b的梯度 (预期1.0): %.2f\n", b_grad)
|
||||
}
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.kingecg.top/kingecg/gotensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("=== 基本运算示例 ===")
|
||||
|
||||
// 创建两个2x2的张量
|
||||
t1_data := []float64{1, 2, 3, 4}
|
||||
t1_shape := []int{2, 2}
|
||||
t1, err := gotensor.NewTensor(t1_data, t1_shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t2_data := []float64{5, 6, 7, 8}
|
||||
t2, err := gotensor.NewTensor(t2_data, t1_shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("张量1:\n%s\n", t1.String())
|
||||
fmt.Printf("张量2:\n%s\n", t2.String())
|
||||
|
||||
// 加法运算
|
||||
add_result, err := t1.Add(t2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("加法结果 (t1 + t2):\n%s\n", add_result.String())
|
||||
|
||||
// 减法运算
|
||||
sub_result, err := t1.Subtract(t2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("减法结果 (t1 - t2):\n%s\n", sub_result.String())
|
||||
|
||||
// 逐元素乘法
|
||||
mul_result, err := t1.Multiply(t2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("逐元素乘法结果 (t1 * t2):\n%s\n", mul_result.String())
|
||||
|
||||
// 数乘
|
||||
scale_result := t1.Scale(2.0)
|
||||
fmt.Printf("数乘结果 (t1 * 2):\n%s\n", scale_result.String())
|
||||
|
||||
// 矩阵乘法
|
||||
matmul_result, err := t1.MatMul(t2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("矩阵乘法结果 (t1 @ t2):\n%s\n", matmul_result.String())
|
||||
|
||||
// 创建零张量和单位矩阵
|
||||
zeros, err := gotensor.NewZeros([]int{2, 3})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("2x3零张量:\n%s\n", zeros.String())
|
||||
|
||||
identity, err := gotensor.NewIdentity(3)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("3x3单位矩阵:\n%s\n", identity.String())
|
||||
}
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.kingecg.top/kingecg/gotensor"
|
||||
"math"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("=== 线性回归示例 ===")
|
||||
|
||||
// 创建简单的线性回归数据: y = 2*x + 3
|
||||
// X = [[1, x1], [1, x2], ...] (添加偏置项)
|
||||
// w = [b, w1] (偏置和权重)
|
||||
|
||||
// 准备训练数据
|
||||
x_data := []float64{1.0, 2.0, 3.0, 4.0, 5.0}
|
||||
X, err := gotensor.NewTensor(x_data, []int{5, 1})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 创建设计矩阵,添加偏置列
|
||||
X_with_bias, err := addBiasColumn(X)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 真实的y值: y = 2*x + 3
|
||||
y_data := []float64{5.0, 7.0, 9.0, 11.0, 13.0} // 2*x + 3
|
||||
Y, err := gotensor.NewTensor(y_data, []int{5, 1})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 初始化权重 [b, w] = [0, 0]
|
||||
w_data := []float64{0.0, 0.0}
|
||||
W, err := gotensor.NewTensor(w_data, []int{2, 1})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("训练数据 X (带偏置):\n%s\n", X_with_bias.String())
|
||||
fmt.Printf("目标值 Y:\n%s\n", Y.String())
|
||||
fmt.Printf("初始权重 W:\n%s\n", W.String())
|
||||
|
||||
// 训练参数
|
||||
learning_rate := 0.01
|
||||
epochs := 100
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
// 前向传播: Y_pred = X * W
|
||||
Y_pred, err := X_with_bias.MatMul(W)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 计算损失: MSE = (1/n) * sum((Y_pred - Y)^2)
|
||||
diff, err := Y_pred.Subtract(Y)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 计算平方差
|
||||
squared_diff, err := diff.Multiply(diff)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 计算平均值(简单地用一个常数缩放)
|
||||
// 在完整实现中,我们需要一个求平均的函数,这里简化处理
|
||||
loss_val := sumTensor(squared_diff) / float64(squared_diff.Size())
|
||||
|
||||
if epoch % 20 == 0 {
|
||||
fmt.Printf("Epoch %d, Loss: %.6f\n", epoch, loss_val)
|
||||
}
|
||||
|
||||
// 计算梯度
|
||||
// MSE梯度: dL/dW = (2/n) * X^T * (Y_pred - Y)
|
||||
X_transpose, err := X_with_bias.Data.Transpose()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
X_T := &gotensor.Tensor{Data: X_transpose}
|
||||
|
||||
grad_W_intermediate, err := X_T.MatMul(diff)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
grad_W := grad_W_intermediate.Scale(2.0 / float64(X_with_bias.Size()))
|
||||
|
||||
// 更新权重: W = W - lr * grad_W
|
||||
grad_update := grad_W.Scale(learning_rate)
|
||||
W, err = W.Subtract(grad_update)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("\n训练后的权重 W:\n%s\n", W.String())
|
||||
|
||||
// 预测
|
||||
Y_pred, err := X_with_bias.MatMul(W)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("预测值:\n%s\n", Y_pred.String())
|
||||
fmt.Printf("真实值:\n%s\n", Y.String())
|
||||
|
||||
// 计算最终误差
|
||||
diff, err := Y_pred.Subtract(Y)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
squared_diff, err := diff.Multiply(diff)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
final_loss := sumTensor(squared_diff) / float64(squared_diff.Size())
|
||||
fmt.Printf("最终损失: %.6f\n", final_loss)
|
||||
}
|
||||
|
||||
// 辅助函数:为输入矩阵添加偏置列
|
||||
func addBiasColumn(X *gotensor.Tensor) (*gotensor.Tensor, error) {
|
||||
shape := X.Shape()
|
||||
rows := shape[0]
|
||||
|
||||
// 这里需要实现列拼接,暂时使用简单的拼接方法
|
||||
// 实际上我们需要一个拼接函数,这里用一个近似的方法
|
||||
// 由于当前API没有提供拼接功能,我们手动构造
|
||||
result_data := make([]float64, rows*2)
|
||||
for i := 0; i < rows; i++ {
|
||||
result_data[i*2] = 1.0 // 偏置项
|
||||
X_val, _ := X.Data.Get(i, 0)
|
||||
result_data[i*2+1] = X_val
|
||||
}
|
||||
|
||||
return gotensor.NewTensor(result_data, []int{rows, 2})
|
||||
}
|
||||
|
||||
// 辅助函数:计算张量所有元素的和
|
||||
func sumTensor(t *gotensor.Tensor) float64 {
|
||||
sum := 0.0
|
||||
shape := t.Shape()
|
||||
if len(shape) == 2 {
|
||||
rows, cols := shape[0], shape[1]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < cols; j++ {
|
||||
val, _ := t.Data.Get(i, j)
|
||||
sum += val
|
||||
}
|
||||
}
|
||||
} else {
|
||||
size := t.Size()
|
||||
for i := 0; i < size; i++ {
|
||||
val, _ := t.Data.Get(i)
|
||||
sum += val
|
||||
}
|
||||
}
|
||||
return math.Abs(sum) // 使用绝对值避免负数和
|
||||
}
|
||||
|
|
@ -0,0 +1,192 @@
|
|||
package gotensor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func TestNewTensor(t *testing.T) {
|
||||
data := []float64{1, 2, 3, 4}
|
||||
shape := []int{2, 2}
|
||||
|
||||
tensor, err := NewTensor(data, shape)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("创建Tensor时发生错误: %v", err)
|
||||
}
|
||||
|
||||
if tensor == nil {
|
||||
t.Error("创建的Tensor不应为nil")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||||
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||||
}
|
||||
|
||||
if tensor.Size() != 4 {
|
||||
t.Errorf("大小不匹配,期望: 4, 实际: %d", tensor.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorAdd(t *testing.T) {
|
||||
data1 := []float64{1, 2, 3, 4}
|
||||
data2 := []float64{5, 6, 7, 8}
|
||||
shape := []int{2, 2}
|
||||
|
||||
t1, err := NewTensor(data1, shape)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t2, err := NewTensor(data2, shape)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := t1.Add(t2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected_data := []float64{6, 8, 10, 12}
|
||||
|
||||
for i := 0; i < result.Size(); i++ {
|
||||
expected_val := expected_data[i]
|
||||
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actual_val != expected_val {
|
||||
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorMultiply(t *testing.T) {
|
||||
data1 := []float64{1, 2, 3, 4}
|
||||
data2 := []float64{2, 3, 4, 5}
|
||||
shape := []int{2, 2}
|
||||
|
||||
t1, err := NewTensor(data1, shape)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t2, err := NewTensor(data2, shape)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := t1.Multiply(t2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected_data := []float64{2, 6, 12, 20} // 逐元素相乘
|
||||
|
||||
for i := 0; i < result.Size(); i++ {
|
||||
expected_val := expected_data[i]
|
||||
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actual_val != expected_val {
|
||||
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorScale(t *testing.T) {
|
||||
data := []float64{1, 2, 3, 4}
|
||||
shape := []int{2, 2}
|
||||
|
||||
tensor, err := NewTensor(data, shape)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
factor := 3.0
|
||||
result := tensor.Scale(factor)
|
||||
|
||||
expected_data := []float64{3, 6, 9, 12}
|
||||
|
||||
for i := 0; i < result.Size(); i++ {
|
||||
expected_val := expected_data[i]
|
||||
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actual_val != expected_val {
|
||||
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorShapeAndSize(t *testing.T) {
|
||||
data := []float64{1, 2, 3, 4, 5, 6}
|
||||
shape := []int{2, 3}
|
||||
|
||||
tensor, err := NewTensor(data, shape)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||||
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||||
}
|
||||
|
||||
if tensor.Size() != 6 {
|
||||
t.Errorf("大小不匹配,期望: 6, 实际: %d", tensor.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewZeros(t *testing.T) {
|
||||
shape := []int{2, 3}
|
||||
|
||||
tensor, err := NewZeros(shape)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||||
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||||
}
|
||||
|
||||
for i := 0; i < tensor.Size(); i++ {
|
||||
val, err := tensor.Get(i/3, i%3) // 2x3矩阵
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if val != 0.0 {
|
||||
t.Errorf("位置[%d,%d]的值不为0,实际: %f", i/3, i%3, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOnes(t *testing.T) {
|
||||
shape := []int{2, 2}
|
||||
|
||||
tensor, err := NewOnes(shape)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||||
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||||
}
|
||||
|
||||
for i := 0; i < tensor.Size(); i++ {
|
||||
val, err := tensor.Get(i/2, i%2) // 2x2矩阵
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if val != 1.0 {
|
||||
t.Errorf("位置[%d,%d]的值不为1,实际: %f", i/2, i%2, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue