package gomatrix import ( "reflect" "testing" ) // TestNewMatrix 测试NewMatrix函数 func TestNewMatrix(t *testing.T) { t.Run("正常创建矩阵", func(t *testing.T) { data := []float64{1, 2, 3, 4} shape := []int{2, 2} matrix, err := NewMatrix(data, shape) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } if matrix == nil { t.Error("期望矩阵不为nil") } if !reflect.DeepEqual(matrix.shape, shape) { t.Errorf("期望形状 %v,但得到 %v", shape, matrix.shape) } if matrix.size != 4 { t.Errorf("期望大小为4,但得到 %d", matrix.size) } }) t.Run("数据和形状长度不匹配", func(t *testing.T) { data := []float64{1, 2, 3} // 3个元素 shape := []int{2, 2} // 需要4个元素 _, err := NewMatrix(data, shape) if err == nil { t.Error("期望有错误,但没有错误返回") } }) t.Run("空数据", func(t *testing.T) { _, err := NewMatrix([]float64{}, []int{2, 2}) if err == nil { t.Error("期望有错误,但没有错误返回") } }) t.Run("空形状", func(t *testing.T) { _, err := NewMatrix([]float64{1, 2, 3, 4}, []int{}) if err == nil { t.Error("期望有错误,但没有错误返回") } }) t.Run("负维度", func(t *testing.T) { _, err := NewMatrix([]float64{1, 2, 3, 4}, []int{-1, 2}) if err == nil { t.Error("期望有错误,但没有错误返回") } }) } // TestNewZeros 测试NewZeros函数 func TestNewZeros(t *testing.T) { t.Run("创建零矩阵", func(t *testing.T) { shape := []int{2, 3} matrix, err := NewZeros(shape) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } if matrix.size != 6 { t.Errorf("期望大小为6,但得到 %d", matrix.size) } for i := 0; i < matrix.size; i++ { if matrix.data[i] != 0 { t.Errorf("期望所有元素为0,但在位置 %d 发现值 %f", i, matrix.data[i]) } } }) t.Run("负维度", func(t *testing.T) { _, err := NewZeros([]int{-1, 2}) if err == nil { t.Error("期望有错误,但没有错误返回") } }) } // TestNewOnes 测试NewOnes函数 func TestNewOnes(t *testing.T) { t.Run("创建全1矩阵", func(t *testing.T) { shape := []int{2, 2} matrix, err := NewOnes(shape) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } if matrix.size != 4 { t.Errorf("期望大小为4,但得到 %d", matrix.size) } for i := 0; i < matrix.size; i++ { if matrix.data[i] != 1 { t.Errorf("期望所有元素为1,但在位置 %d 发现值 %f", i, matrix.data[i]) } } }) } // TestNewIdentity 测试NewIdentity函数 func TestNewIdentity(t *testing.T) { t.Run("创建单位矩阵", func(t *testing.T) { size := 3 matrix, err := NewIdentity(size) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } if matrix.size != 9 { t.Errorf("期望大小为9,但得到 %d", matrix.size) } expectedData := []float64{1, 0, 0, 0, 1, 0, 0, 0, 1} for i := 0; i < matrix.size; i++ { if matrix.data[i] != expectedData[i] { t.Errorf("期望数据 %v,但得到 %v", expectedData, matrix.data) break } } }) t.Run("负大小", func(t *testing.T) { _, err := NewIdentity(-1) if err == nil { t.Error("期望有错误,但没有错误返回") } }) } // TestGetSet 测试Get和Set方法 func TestGetSet(t *testing.T) { matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) t.Run("正常获取值", func(t *testing.T) { value, err := matrix.Get(0, 0) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } if value != 1 { t.Errorf("期望值为1,但得到 %f", value) } value, err = matrix.Get(1, 1) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } if value != 4 { t.Errorf("期望值为4,但得到 %f", value) } }) t.Run("设置值", func(t *testing.T) { err := matrix.Set(99, 0, 0) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } value, err := matrix.Get(0, 0) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } if value != 99 { t.Errorf("期望值为99,但得到 %f", value) } }) t.Run("越界索引", func(t *testing.T) { _, err := matrix.Get(5, 5) if err == nil { t.Error("期望有错误,但没有错误返回") } err = matrix.Set(100, 5, 5) if err == nil { t.Error("期望有错误,但没有错误返回") } }) } // TestAdd 测试矩阵加法 func TestAdd(t *testing.T) { matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2}) t.Run("正常相加", func(t *testing.T) { result, err := matrix1.Add(matrix2) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } expected := []float64{6, 8, 10, 12} for i := 0; i < result.size; i++ { if result.data[i] != expected[i] { t.Errorf("期望结果 %v,但得到 %v", expected, result.data) break } } }) t.Run("不同形状的矩阵相加", func(t *testing.T) { matrix3, _ := NewMatrix([]float64{1, 2}, []int{1, 2}) _, err := matrix1.Add(matrix3) if err == nil { t.Error("期望有错误,但没有错误返回") } }) } // TestSubtract 测试矩阵减法 func TestSubtract(t *testing.T) { matrix1, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2}) matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) t.Run("正常相减", func(t *testing.T) { result, err := matrix1.Subtract(matrix2) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } expected := []float64{4, 4, 4, 4} for i := 0; i < result.size; i++ { if result.data[i] != expected[i] { t.Errorf("期望结果 %v,但得到 %v", expected, result.data) break } } }) } // TestMultiply 测试矩阵逐元素乘法 func TestMultiply(t *testing.T) { matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2}) t.Run("正常逐元素相乘", func(t *testing.T) { result, err := matrix1.Multiply(matrix2) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } expected := []float64{5, 12, 21, 32} for i := 0; i < result.size; i++ { if result.data[i] != expected[i] { t.Errorf("期望结果 %v,但得到 %v", expected, result.data) break } } }) } // TestMatMul 测试矩阵乘法 func TestMatMul(t *testing.T) { t.Run("正常矩阵乘法", func(t *testing.T) { // 创建2x3矩阵 [[1,2,3],[4,5,6]] matrix1, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3}) // 创建3x2矩阵 [[7,8],[9,10],[11,12]] matrix2, _ := NewMatrix([]float64{7, 8, 9, 10, 11, 12}, []int{3, 2}) result, err := matrix1.MatMul(matrix2) if err != nil { t.Errorf("期望无错误,但得到: %v", err) } // 预期结果是2x2矩阵 [[58,64],[139,154]] expected := []float64{58, 64, 139, 154} for i := 0; i < result.size; i++ { if result.data[i] != expected[i] { t.Errorf("期望结果 %v,但得到 %v", expected, result.data) break } } }) t.Run("不兼容的矩阵乘法", func(t *testing.T) { matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) matrix2, _ := NewMatrix([]float64{1, 2}, []int{1, 2}) _, err := matrix1.MatMul(matrix2) if err == nil { t.Error("期望有错误,但没有错误返回") } }) } // TestScale 测试矩阵数乘 func TestScale(t *testing.T) { matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) result := matrix.Scale(2) expected := []float64{2, 4, 6, 8} for i := 0; i < result.size; i++ { if result.data[i] != expected[i] { t.Errorf("期望结果 %v,但得到 %v", expected, result.data) break } } } // TestTranspose 测试矩阵转置 func TestTranspose(t *testing.T) { t.Run("正常转置", func(t *testing.T) { // 创建矩阵 [[1,2,3],[4,5,6]] matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3}) result, err := matrix.Transpose() if err != nil { t.Errorf("期望无错误,但得到: %v", err) } // 预期结果是 [[1,4],[2,5],[3,6]] expected := []float64{1, 4, 2, 5, 3, 6} for i := 0; i < result.size; i++ { if result.data[i] != expected[i] { t.Errorf("期望结果 %v,但得到 %v", expected, result.data) break } } // 检查形状是否正确转置 if result.shape[0] != 3 || result.shape[1] != 2 { t.Errorf("期望形状 [3,2],但得到 %v", result.shape) } }) t.Run("非2D矩阵转置", func(t *testing.T) { matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2, 1}) _, err := matrix.Transpose() if err == nil { t.Error("期望有错误,但没有错误返回") } }) } // TestCopy 测试矩阵复制 func TestCopy(t *testing.T) { matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) copied := matrix.Copy() // 检查复制的矩阵是否相等 if !matrix.Equal(copied) { t.Error("复制的矩阵应该与原矩阵相等") } // 修改复制的矩阵不应该影响原矩阵 copied.Set(99, 0, 0) originalValue, _ := matrix.Get(0, 0) if originalValue == 99 { t.Error("修改复制的矩阵不应该影响原矩阵") } } // TestShapeAndSize 测试Shape和Size方法 func TestShapeAndSize(t *testing.T) { matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3}) shape := matrix.Shape() if !reflect.DeepEqual(shape, []int{2, 3}) { t.Errorf("期望形状 [2,3],但得到 %v", shape) } size := matrix.Size() if size != 6 { t.Errorf("期望大小为6,但得到 %d", size) } } // TestEqual 测试Equal方法 func TestEqual(t *testing.T) { matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) matrix3, _ := NewMatrix([]float64{1, 2, 3, 5}, []int{2, 2}) if !matrix1.Equal(matrix2) { t.Error("两个相同矩阵应该相等") } if matrix1.Equal(matrix3) { t.Error("两个不同矩阵不应该相等") } }