test(gomatrix): 添加矩阵操作的全面测试用例

添加了完整的单元测试覆盖矩阵库的核心功能,包括:
- 矩阵创建函数测试(NewMatrix, NewZeros, NewOnes, NewIdentity)
- 矩阵基本操作测试(Get, Set, Add, Subtract, Multiply)
- 矩阵乘法和转置操作测试(MatMul, Transpose)
- 矩阵复制和比较测试(Copy, Equal)
- 边界条件和错误处理测试
- 各种异常情况的测试用例
```
This commit is contained in:
程广 2025-12-31 17:26:27 +08:00
parent c2ed416436
commit f40960a855
1 changed files with 402 additions and 0 deletions

402
matrix_test.go Normal file
View File

@ -0,0 +1,402 @@
package gomatrix
import (
"reflect"
"testing"
)
// TestNewMatrix 测试NewMatrix函数
func TestNewMatrix(t *testing.T) {
t.Run("正常创建矩阵", func(t *testing.T) {
data := []float64{1, 2, 3, 4}
shape := []int{2, 2}
matrix, err := NewMatrix(data, shape)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if matrix == nil {
t.Error("期望矩阵不为nil")
}
if !reflect.DeepEqual(matrix.shape, shape) {
t.Errorf("期望形状 %v但得到 %v", shape, matrix.shape)
}
if matrix.size != 4 {
t.Errorf("期望大小为4但得到 %d", matrix.size)
}
})
t.Run("数据和形状长度不匹配", func(t *testing.T) {
data := []float64{1, 2, 3} // 3个元素
shape := []int{2, 2} // 需要4个元素
_, err := NewMatrix(data, shape)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
t.Run("空数据", func(t *testing.T) {
_, err := NewMatrix([]float64{}, []int{2, 2})
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
t.Run("空形状", func(t *testing.T) {
_, err := NewMatrix([]float64{1, 2, 3, 4}, []int{})
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
t.Run("负维度", func(t *testing.T) {
_, err := NewMatrix([]float64{1, 2, 3, 4}, []int{-1, 2})
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestNewZeros 测试NewZeros函数
func TestNewZeros(t *testing.T) {
t.Run("创建零矩阵", func(t *testing.T) {
shape := []int{2, 3}
matrix, err := NewZeros(shape)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if matrix.size != 6 {
t.Errorf("期望大小为6但得到 %d", matrix.size)
}
for i := 0; i < matrix.size; i++ {
if matrix.data[i] != 0 {
t.Errorf("期望所有元素为0但在位置 %d 发现值 %f", i, matrix.data[i])
}
}
})
t.Run("负维度", func(t *testing.T) {
_, err := NewZeros([]int{-1, 2})
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestNewOnes 测试NewOnes函数
func TestNewOnes(t *testing.T) {
t.Run("创建全1矩阵", func(t *testing.T) {
shape := []int{2, 2}
matrix, err := NewOnes(shape)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if matrix.size != 4 {
t.Errorf("期望大小为4但得到 %d", matrix.size)
}
for i := 0; i < matrix.size; i++ {
if matrix.data[i] != 1 {
t.Errorf("期望所有元素为1但在位置 %d 发现值 %f", i, matrix.data[i])
}
}
})
}
// TestNewIdentity 测试NewIdentity函数
func TestNewIdentity(t *testing.T) {
t.Run("创建单位矩阵", func(t *testing.T) {
size := 3
matrix, err := NewIdentity(size)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if matrix.size != 9 {
t.Errorf("期望大小为9但得到 %d", matrix.size)
}
expectedData := []float64{1, 0, 0, 0, 1, 0, 0, 0, 1}
for i := 0; i < matrix.size; i++ {
if matrix.data[i] != expectedData[i] {
t.Errorf("期望数据 %v但得到 %v", expectedData, matrix.data)
break
}
}
})
t.Run("负大小", func(t *testing.T) {
_, err := NewIdentity(-1)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestGetSet 测试Get和Set方法
func TestGetSet(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
t.Run("正常获取值", func(t *testing.T) {
value, err := matrix.Get(0, 0)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if value != 1 {
t.Errorf("期望值为1但得到 %f", value)
}
value, err = matrix.Get(1, 1)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if value != 4 {
t.Errorf("期望值为4但得到 %f", value)
}
})
t.Run("设置值", func(t *testing.T) {
err := matrix.Set(99, 0, 0)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
value, err := matrix.Get(0, 0)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if value != 99 {
t.Errorf("期望值为99但得到 %f", value)
}
})
t.Run("越界索引", func(t *testing.T) {
_, err := matrix.Get(5, 5)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
err = matrix.Set(100, 5, 5)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestAdd 测试矩阵加法
func TestAdd(t *testing.T) {
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
t.Run("正常相加", func(t *testing.T) {
result, err := matrix1.Add(matrix2)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
expected := []float64{6, 8, 10, 12}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
})
t.Run("不同形状的矩阵相加", func(t *testing.T) {
matrix3, _ := NewMatrix([]float64{1, 2}, []int{1, 2})
_, err := matrix1.Add(matrix3)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestSubtract 测试矩阵减法
func TestSubtract(t *testing.T) {
matrix1, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
t.Run("正常相减", func(t *testing.T) {
result, err := matrix1.Subtract(matrix2)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
expected := []float64{4, 4, 4, 4}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
})
}
// TestMultiply 测试矩阵逐元素乘法
func TestMultiply(t *testing.T) {
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
t.Run("正常逐元素相乘", func(t *testing.T) {
result, err := matrix1.Multiply(matrix2)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
expected := []float64{5, 12, 21, 32}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
})
}
// TestMatMul 测试矩阵乘法
func TestMatMul(t *testing.T) {
t.Run("正常矩阵乘法", func(t *testing.T) {
// 创建2x3矩阵 [[1,2,3],[4,5,6]]
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
// 创建3x2矩阵 [[7,8],[9,10],[11,12]]
matrix2, _ := NewMatrix([]float64{7, 8, 9, 10, 11, 12}, []int{3, 2})
result, err := matrix1.MatMul(matrix2)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
// 预期结果是2x2矩阵 [[58,64],[139,154]]
expected := []float64{58, 64, 139, 154}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
})
t.Run("不兼容的矩阵乘法", func(t *testing.T) {
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{1, 2}, []int{1, 2})
_, err := matrix1.MatMul(matrix2)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestScale 测试矩阵数乘
func TestScale(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
result := matrix.Scale(2)
expected := []float64{2, 4, 6, 8}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
}
// TestTranspose 测试矩阵转置
func TestTranspose(t *testing.T) {
t.Run("正常转置", func(t *testing.T) {
// 创建矩阵 [[1,2,3],[4,5,6]]
matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
result, err := matrix.Transpose()
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
// 预期结果是 [[1,4],[2,5],[3,6]]
expected := []float64{1, 4, 2, 5, 3, 6}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
// 检查形状是否正确转置
if result.shape[0] != 3 || result.shape[1] != 2 {
t.Errorf("期望形状 [3,2],但得到 %v", result.shape)
}
})
t.Run("非2D矩阵转置", func(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2, 1})
_, err := matrix.Transpose()
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestCopy 测试矩阵复制
func TestCopy(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
copied := matrix.Copy()
// 检查复制的矩阵是否相等
if !matrix.Equal(copied) {
t.Error("复制的矩阵应该与原矩阵相等")
}
// 修改复制的矩阵不应该影响原矩阵
copied.Set(99, 0, 0)
originalValue, _ := matrix.Get(0, 0)
if originalValue == 99 {
t.Error("修改复制的矩阵不应该影响原矩阵")
}
}
// TestShapeAndSize 测试Shape和Size方法
func TestShapeAndSize(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
shape := matrix.Shape()
if !reflect.DeepEqual(shape, []int{2, 3}) {
t.Errorf("期望形状 [2,3],但得到 %v", shape)
}
size := matrix.Size()
if size != 6 {
t.Errorf("期望大小为6但得到 %d", size)
}
}
// TestEqual 测试Equal方法
func TestEqual(t *testing.T) {
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix3, _ := NewMatrix([]float64{1, 2, 3, 5}, []int{2, 2})
if !matrix1.Equal(matrix2) {
t.Error("两个相同矩阵应该相等")
}
if matrix1.Equal(matrix3) {
t.Error("两个不同矩阵不应该相等")
}
}