192 lines
3.8 KiB
Go
192 lines
3.8 KiB
Go
package gotensor
|
||
|
||
import (
|
||
"testing"
|
||
"reflect"
|
||
)
|
||
|
||
func TestNewTensor(t *testing.T) {
|
||
data := []float64{1, 2, 3, 4}
|
||
shape := []int{2, 2}
|
||
|
||
tensor, err := NewTensor(data, shape)
|
||
|
||
if err != nil {
|
||
t.Errorf("创建Tensor时发生错误: %v", err)
|
||
}
|
||
|
||
if tensor == nil {
|
||
t.Error("创建的Tensor不应为nil")
|
||
}
|
||
|
||
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||
}
|
||
|
||
if tensor.Size() != 4 {
|
||
t.Errorf("大小不匹配,期望: 4, 实际: %d", tensor.Size())
|
||
}
|
||
}
|
||
|
||
func TestTensorAdd(t *testing.T) {
|
||
data1 := []float64{1, 2, 3, 4}
|
||
data2 := []float64{5, 6, 7, 8}
|
||
shape := []int{2, 2}
|
||
|
||
t1, err := NewTensor(data1, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
t2, err := NewTensor(data2, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
result, err := t1.Add(t2)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
expected_data := []float64{6, 8, 10, 12}
|
||
|
||
for i := 0; i < result.Size(); i++ {
|
||
expected_val := expected_data[i]
|
||
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if actual_val != expected_val {
|
||
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTensorMultiply(t *testing.T) {
|
||
data1 := []float64{1, 2, 3, 4}
|
||
data2 := []float64{2, 3, 4, 5}
|
||
shape := []int{2, 2}
|
||
|
||
t1, err := NewTensor(data1, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
t2, err := NewTensor(data2, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
result, err := t1.Multiply(t2)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
expected_data := []float64{2, 6, 12, 20} // 逐元素相乘
|
||
|
||
for i := 0; i < result.Size(); i++ {
|
||
expected_val := expected_data[i]
|
||
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if actual_val != expected_val {
|
||
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTensorScale(t *testing.T) {
|
||
data := []float64{1, 2, 3, 4}
|
||
shape := []int{2, 2}
|
||
|
||
tensor, err := NewTensor(data, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
factor := 3.0
|
||
result := tensor.Scale(factor)
|
||
|
||
expected_data := []float64{3, 6, 9, 12}
|
||
|
||
for i := 0; i < result.Size(); i++ {
|
||
expected_val := expected_data[i]
|
||
actual_val, err := result.Get(i/2, i%2) // 2x2矩阵
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if actual_val != expected_val {
|
||
t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", i/2, i%2, expected_val, actual_val)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTensorShapeAndSize(t *testing.T) {
|
||
data := []float64{1, 2, 3, 4, 5, 6}
|
||
shape := []int{2, 3}
|
||
|
||
tensor, err := NewTensor(data, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||
}
|
||
|
||
if tensor.Size() != 6 {
|
||
t.Errorf("大小不匹配,期望: 6, 实际: %d", tensor.Size())
|
||
}
|
||
}
|
||
|
||
func TestNewZeros(t *testing.T) {
|
||
shape := []int{2, 3}
|
||
|
||
tensor, err := NewZeros(shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||
}
|
||
|
||
for i := 0; i < tensor.Size(); i++ {
|
||
val, err := tensor.Get(i/3, i%3) // 2x3矩阵
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if val != 0.0 {
|
||
t.Errorf("位置[%d,%d]的值不为0,实际: %f", i/3, i%3, val)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestNewOnes(t *testing.T) {
|
||
shape := []int{2, 2}
|
||
|
||
tensor, err := NewOnes(shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if !reflect.DeepEqual(tensor.Shape(), shape) {
|
||
t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
|
||
}
|
||
|
||
for i := 0; i < tensor.Size(); i++ {
|
||
val, err := tensor.Get(i/2, i%2) // 2x2矩阵
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if val != 1.0 {
|
||
t.Errorf("位置[%d,%d]的值不为1,实际: %f", i/2, i%2, val)
|
||
}
|
||
}
|
||
} |