package gotensor

import (
	"reflect"
	"testing"
)

// assertMatrixElements checks that a 2-D tensor holds want in row-major
// order, i.e. that the element at (i/cols, i%cols) equals want[i].
//
// size is the element count the tensor reports, cols is its number of
// columns, and get reads the element at (row, col). get is a closure so the
// helper does not depend on the concrete tensor type or on the exact
// signature of its Get method.
func assertMatrixElements(t *testing.T, want []float64, size, cols int, get func(row, col int) (float64, error)) {
	t.Helper()
	// Compare sizes up front: looping only up to the reported size would let
	// a tensor with too few elements pass without checking anything.
	if size != len(want) {
		t.Fatalf("大小不匹配,期望: %d, 实际: %d", len(want), size)
	}
	for i, wantVal := range want {
		row, col := i/cols, i%cols
		got, err := get(row, col)
		if err != nil {
			t.Fatal(err)
		}
		if got != wantVal {
			t.Errorf("位置[%d,%d]的值不匹配,期望: %f, 实际: %f", row, col, wantVal, got)
		}
	}
}

// TestNewTensor checks that NewTensor records the given shape, size and data.
func TestNewTensor(t *testing.T) {
	data := []float64{1, 2, 3, 4}
	shape := []int{2, 2}
	tensor, err := NewTensor(data, shape)
	// Fatal rather than Error: every check below dereferences tensor.
	if err != nil {
		t.Fatalf("创建Tensor时发生错误: %v", err)
	}
	if tensor == nil {
		t.Fatal("创建的Tensor不应为nil")
	}
	if !reflect.DeepEqual(tensor.Shape(), shape) {
		t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
	}
	if tensor.Size() != 4 {
		t.Errorf("大小不匹配,期望: 4, 实际: %d", tensor.Size())
	}
	assertMatrixElements(t, data, tensor.Size(), shape[1], func(row, col int) (float64, error) {
		return tensor.Get(row, col)
	})
}

// TestTensorAdd checks element-wise addition of two tensors of equal shape.
func TestTensorAdd(t *testing.T) {
	data1 := []float64{1, 2, 3, 4}
	data2 := []float64{5, 6, 7, 8}
	shape := []int{2, 2}
	t1, err := NewTensor(data1, shape)
	if err != nil {
		t.Fatal(err)
	}
	t2, err := NewTensor(data2, shape)
	if err != nil {
		t.Fatal(err)
	}
	result, err := t1.Add(t2)
	if err != nil {
		t.Fatal(err)
	}
	want := []float64{6, 8, 10, 12}
	assertMatrixElements(t, want, result.Size(), shape[1], func(row, col int) (float64, error) {
		return result.Get(row, col)
	})
}

// TestTensorMultiply checks that Multiply is an element-wise product.
func TestTensorMultiply(t *testing.T) {
	data1 := []float64{1, 2, 3, 4}
	data2 := []float64{2, 3, 4, 5}
	shape := []int{2, 2}
	t1, err := NewTensor(data1, shape)
	if err != nil {
		t.Fatal(err)
	}
	t2, err := NewTensor(data2, shape)
	if err != nil {
		t.Fatal(err)
	}
	result, err := t1.Multiply(t2)
	if err != nil {
		t.Fatal(err)
	}
	// Element-wise product, not a matrix product.
	want := []float64{2, 6, 12, 20}
	assertMatrixElements(t, want, result.Size(), shape[1], func(row, col int) (float64, error) {
		return result.Get(row, col)
	})
}

// TestTensorScale checks that Scale multiplies every element by the factor.
func TestTensorScale(t *testing.T) {
	data := []float64{1, 2, 3, 4}
	shape := []int{2, 2}
	tensor, err := NewTensor(data, shape)
	if err != nil {
		t.Fatal(err)
	}
	factor := 3.0
	result := tensor.Scale(factor)
	want := []float64{3, 6, 9, 12}
	assertMatrixElements(t, want, result.Size(), shape[1], func(row, col int) (float64, error) {
		return result.Get(row, col)
	})
}

// TestTensorShapeAndSize checks Shape and Size on a non-square tensor.
func TestTensorShapeAndSize(t *testing.T) {
	data := []float64{1, 2, 3, 4, 5, 6}
	shape := []int{2, 3}
	tensor, err := NewTensor(data, shape)
	if err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(tensor.Shape(), shape) {
		t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
	}
	if tensor.Size() != 6 {
		t.Errorf("大小不匹配,期望: 6, 实际: %d", tensor.Size())
	}
}

// TestNewZeros checks that NewZeros builds a tensor of the given shape whose
// elements are all 0.
func TestNewZeros(t *testing.T) {
	shape := []int{2, 3}
	tensor, err := NewZeros(shape)
	if err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(tensor.Shape(), shape) {
		t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
	}
	// Bound the loops by the requested shape, not tensor.Size(), so a tensor
	// that reports too few elements cannot pass vacuously.
	rows, cols := shape[0], shape[1]
	if size := tensor.Size(); size != rows*cols {
		t.Fatalf("大小不匹配,期望: %d, 实际: %d", rows*cols, size)
	}
	for row := 0; row < rows; row++ {
		for col := 0; col < cols; col++ {
			val, err := tensor.Get(row, col)
			if err != nil {
				t.Fatal(err)
			}
			if val != 0.0 {
				t.Errorf("位置[%d,%d]的值不为0,实际: %f", row, col, val)
			}
		}
	}
}

// TestNewOnes checks that NewOnes builds a tensor of the given shape whose
// elements are all 1.
func TestNewOnes(t *testing.T) {
	shape := []int{2, 2}
	tensor, err := NewOnes(shape)
	if err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(tensor.Shape(), shape) {
		t.Errorf("形状不匹配,期望: %v, 实际: %v", shape, tensor.Shape())
	}
	// Bound the loops by the requested shape, not tensor.Size(), so a tensor
	// that reports too few elements cannot pass vacuously.
	rows, cols := shape[0], shape[1]
	if size := tensor.Size(); size != rows*cols {
		t.Fatalf("大小不匹配,期望: %d, 实际: %d", rows*cols, size)
	}
	for row := 0; row < rows; row++ {
		for col := 0; col < cols; col++ {
			val, err := tensor.Get(row, col)
			if err != nil {
				t.Fatal(err)
			}
			if val != 1.0 {
				t.Errorf("位置[%d,%d]的值不为1,实际: %f", row, col, val)
			}
		}
	}
}