From f40960a855c7175a556028dc6e6639baed8cbcc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E5=B9=BF?= Date: Wed, 31 Dec 2025 17:26:27 +0800 Subject: [PATCH] =?UTF-8?q?```=20test(gomatrix):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E7=9F=A9=E9=98=B5=E6=93=8D=E4=BD=9C=E7=9A=84=E5=85=A8=E9=9D=A2?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加了完整的单元测试覆盖矩阵库的核心功能,包括: - 矩阵创建函数测试(NewMatrix, NewZeros, NewOnes, NewIdentity) - 矩阵基本操作测试(Get, Set, Add, Subtract, Multiply) - 矩阵乘法和转置操作测试(MatMul, Transpose) - 矩阵复制和比较测试(Copy, Equal) - 边界条件和错误处理测试 - 各种异常情况的测试用例 ``` --- matrix_test.go | 402 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 402 insertions(+) create mode 100644 matrix_test.go diff --git a/matrix_test.go b/matrix_test.go new file mode 100644 index 0000000..ee67d63 --- /dev/null +++ b/matrix_test.go @@ -0,0 +1,402 @@ +package gomatrix + +import ( + "reflect" + "testing" +) + +// TestNewMatrix 测试NewMatrix函数 +func TestNewMatrix(t *testing.T) { + t.Run("正常创建矩阵", func(t *testing.T) { + data := []float64{1, 2, 3, 4} + shape := []int{2, 2} + matrix, err := NewMatrix(data, shape) + + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + if matrix == nil { + t.Error("期望矩阵不为nil") + } + + if !reflect.DeepEqual(matrix.shape, shape) { + t.Errorf("期望形状 %v,但得到 %v", shape, matrix.shape) + } + + if matrix.size != 4 { + t.Errorf("期望大小为4,但得到 %d", matrix.size) + } + }) + + t.Run("数据和形状长度不匹配", func(t *testing.T) { + data := []float64{1, 2, 3} // 3个元素 + shape := []int{2, 2} // 需要4个元素 + _, err := NewMatrix(data, shape) + + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) + + t.Run("空数据", func(t *testing.T) { + _, err := NewMatrix([]float64{}, []int{2, 2}) + + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) + + t.Run("空形状", func(t *testing.T) { + _, err := NewMatrix([]float64{1, 2, 3, 4}, []int{}) + + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) + + t.Run("负维度", func(t *testing.T) { + _, err := NewMatrix([]float64{1, 2, 3, 4}, []int{-1, 2}) + + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) +} + +// TestNewZeros 测试NewZeros函数 +func TestNewZeros(t *testing.T) { + t.Run("创建零矩阵", func(t *testing.T) { + shape := []int{2, 3} + matrix, err := NewZeros(shape) + + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + if matrix.size != 6 { + t.Errorf("期望大小为6,但得到 %d", matrix.size) + } + + for i := 0; i < matrix.size; i++ { + if matrix.data[i] != 0 { + t.Errorf("期望所有元素为0,但在位置 %d 发现值 %f", i, matrix.data[i]) + } + } + }) + + t.Run("负维度", func(t *testing.T) { + _, err := NewZeros([]int{-1, 2}) + + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) +} + +// TestNewOnes 测试NewOnes函数 +func TestNewOnes(t *testing.T) { + t.Run("创建全1矩阵", func(t *testing.T) { + shape := []int{2, 2} + matrix, err := NewOnes(shape) + + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + if matrix.size != 4 { + t.Errorf("期望大小为4,但得到 %d", matrix.size) + } + + for i := 0; i < matrix.size; i++ { + if matrix.data[i] != 1 { + t.Errorf("期望所有元素为1,但在位置 %d 发现值 %f", i, matrix.data[i]) + } + } + }) +} + +// TestNewIdentity 测试NewIdentity函数 +func TestNewIdentity(t *testing.T) { + t.Run("创建单位矩阵", func(t *testing.T) { + size := 3 + matrix, err := NewIdentity(size) + + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + if matrix.size != 9 { + t.Errorf("期望大小为9,但得到 %d", matrix.size) + } + + expectedData := []float64{1, 0, 0, 0, 1, 0, 0, 0, 1} + for i := 0; i < matrix.size; i++ { + if matrix.data[i] != expectedData[i] { + t.Errorf("期望数据 %v,但得到 %v", expectedData, matrix.data) + break + } + } + }) + + t.Run("负大小", func(t *testing.T) { + _, err := NewIdentity(-1) + + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) +} + +// TestGetSet 测试Get和Set方法 +func TestGetSet(t *testing.T) { + matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + + t.Run("正常获取值", func(t *testing.T) { + value, err := matrix.Get(0, 0) + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + if value != 1 { + t.Errorf("期望值为1,但得到 %f", value) + } + + value, err = matrix.Get(1, 1) + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + if value != 4 { + t.Errorf("期望值为4,但得到 %f", value) + } + }) + + t.Run("设置值", func(t *testing.T) { + err := matrix.Set(99, 0, 0) + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + value, err := matrix.Get(0, 0) + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + if value != 99 { + t.Errorf("期望值为99,但得到 %f", value) + } + }) + + t.Run("越界索引", func(t *testing.T) { + _, err := matrix.Get(5, 5) + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + + err = matrix.Set(100, 5, 5) + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) +} + +// TestAdd 测试矩阵加法 +func TestAdd(t *testing.T) { + matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2}) + + t.Run("正常相加", func(t *testing.T) { + result, err := matrix1.Add(matrix2) + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + expected := []float64{6, 8, 10, 12} + for i := 0; i < result.size; i++ { + if result.data[i] != expected[i] { + t.Errorf("期望结果 %v,但得到 %v", expected, result.data) + break + } + } + }) + + t.Run("不同形状的矩阵相加", func(t *testing.T) { + matrix3, _ := NewMatrix([]float64{1, 2}, []int{1, 2}) + _, err := matrix1.Add(matrix3) + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) +} + +// TestSubtract 测试矩阵减法 +func TestSubtract(t *testing.T) { + matrix1, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2}) + matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + + t.Run("正常相减", func(t *testing.T) { + result, err := matrix1.Subtract(matrix2) + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + expected := []float64{4, 4, 4, 4} + for i := 0; i < result.size; i++ { + if result.data[i] != expected[i] { + t.Errorf("期望结果 %v,但得到 %v", expected, result.data) + break + } + } + }) +} + +// TestMultiply 测试矩阵逐元素乘法 +func TestMultiply(t *testing.T) { + matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2}) + + t.Run("正常逐元素相乘", func(t *testing.T) { + result, err := matrix1.Multiply(matrix2) + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + expected := []float64{5, 12, 21, 32} + for i := 0; i < result.size; i++ { + if result.data[i] != expected[i] { + t.Errorf("期望结果 %v,但得到 %v", expected, result.data) + break + } + } + }) +} + +// TestMatMul 测试矩阵乘法 +func TestMatMul(t *testing.T) { + t.Run("正常矩阵乘法", func(t *testing.T) { + // 创建2x3矩阵 [[1,2,3],[4,5,6]] + matrix1, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3}) + // 创建3x2矩阵 [[7,8],[9,10],[11,12]] + matrix2, _ := NewMatrix([]float64{7, 8, 9, 10, 11, 12}, []int{3, 2}) + + result, err := matrix1.MatMul(matrix2) + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + // 预期结果是2x2矩阵 [[58,64],[139,154]] + expected := []float64{58, 64, 139, 154} + for i := 0; i < result.size; i++ { + if result.data[i] != expected[i] { + t.Errorf("期望结果 %v,但得到 %v", expected, result.data) + break + } + } + }) + + t.Run("不兼容的矩阵乘法", func(t *testing.T) { + matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + matrix2, _ := NewMatrix([]float64{1, 2}, []int{1, 2}) + _, err := matrix1.MatMul(matrix2) + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) +} + +// TestScale 测试矩阵数乘 +func TestScale(t *testing.T) { + matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + + result := matrix.Scale(2) + expected := []float64{2, 4, 6, 8} + for i := 0; i < result.size; i++ { + if result.data[i] != expected[i] { + t.Errorf("期望结果 %v,但得到 %v", expected, result.data) + break + } + } +} + +// TestTranspose 测试矩阵转置 +func TestTranspose(t *testing.T) { + t.Run("正常转置", func(t *testing.T) { + // 创建矩阵 [[1,2,3],[4,5,6]] + matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3}) + + result, err := matrix.Transpose() + if err != nil { + t.Errorf("期望无错误,但得到: %v", err) + } + + // 预期结果是 [[1,4],[2,5],[3,6]] + expected := []float64{1, 4, 2, 5, 3, 6} + for i := 0; i < result.size; i++ { + if result.data[i] != expected[i] { + t.Errorf("期望结果 %v,但得到 %v", expected, result.data) + break + } + } + + // 检查形状是否正确转置 + if result.shape[0] != 3 || result.shape[1] != 2 { + t.Errorf("期望形状 [3,2],但得到 %v", result.shape) + } + }) + + t.Run("非2D矩阵转置", func(t *testing.T) { + matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2, 1}) + _, err := matrix.Transpose() + if err == nil { + t.Error("期望有错误,但没有错误返回") + } + }) +} + +// TestCopy 测试矩阵复制 +func TestCopy(t *testing.T) { + matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + + copied := matrix.Copy() + + // 检查复制的矩阵是否相等 + if !matrix.Equal(copied) { + t.Error("复制的矩阵应该与原矩阵相等") + } + + // 修改复制的矩阵不应该影响原矩阵 + copied.Set(99, 0, 0) + + originalValue, _ := matrix.Get(0, 0) + if originalValue == 99 { + t.Error("修改复制的矩阵不应该影响原矩阵") + } +} + +// TestShapeAndSize 测试Shape和Size方法 +func TestShapeAndSize(t *testing.T) { + matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3}) + + shape := matrix.Shape() + if !reflect.DeepEqual(shape, []int{2, 3}) { + t.Errorf("期望形状 [2,3],但得到 %v", shape) + } + + size := matrix.Size() + if size != 6 { + t.Errorf("期望大小为6,但得到 %d", size) + } +} + +// TestEqual 测试Equal方法 +func TestEqual(t *testing.T) { + matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2}) + matrix3, _ := NewMatrix([]float64{1, 2, 3, 5}, []int{2, 2}) + + if !matrix1.Equal(matrix2) { + t.Error("两个相同矩阵应该相等") + } + + if matrix1.Equal(matrix3) { + t.Error("两个不同矩阵不应该相等") + } +} \ No newline at end of file