```
test(gomatrix): 添加矩阵操作的全面测试用例 添加了完整的单元测试覆盖矩阵库的核心功能,包括: - 矩阵创建函数测试(NewMatrix, NewZeros, NewOnes, NewIdentity) - 矩阵基本操作测试(Get, Set, Add, Subtract, Multiply) - 矩阵乘法和转置操作测试(MatMul, Transpose) - 矩阵复制和比较测试(Copy, Equal) - 边界条件和错误处理测试 - 各种异常情况的测试用例 ```
This commit is contained in:
parent
c2ed416436
commit
f40960a855
|
|
@ -0,0 +1,402 @@
|
|||
package gomatrix
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewMatrix 测试NewMatrix函数
|
||||
func TestNewMatrix(t *testing.T) {
|
||||
t.Run("正常创建矩阵", func(t *testing.T) {
|
||||
data := []float64{1, 2, 3, 4}
|
||||
shape := []int{2, 2}
|
||||
matrix, err := NewMatrix(data, shape)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
if matrix == nil {
|
||||
t.Error("期望矩阵不为nil")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(matrix.shape, shape) {
|
||||
t.Errorf("期望形状 %v,但得到 %v", shape, matrix.shape)
|
||||
}
|
||||
|
||||
if matrix.size != 4 {
|
||||
t.Errorf("期望大小为4,但得到 %d", matrix.size)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("数据和形状长度不匹配", func(t *testing.T) {
|
||||
data := []float64{1, 2, 3} // 3个元素
|
||||
shape := []int{2, 2} // 需要4个元素
|
||||
_, err := NewMatrix(data, shape)
|
||||
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("空数据", func(t *testing.T) {
|
||||
_, err := NewMatrix([]float64{}, []int{2, 2})
|
||||
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("空形状", func(t *testing.T) {
|
||||
_, err := NewMatrix([]float64{1, 2, 3, 4}, []int{})
|
||||
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("负维度", func(t *testing.T) {
|
||||
_, err := NewMatrix([]float64{1, 2, 3, 4}, []int{-1, 2})
|
||||
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewZeros 测试NewZeros函数
|
||||
func TestNewZeros(t *testing.T) {
|
||||
t.Run("创建零矩阵", func(t *testing.T) {
|
||||
shape := []int{2, 3}
|
||||
matrix, err := NewZeros(shape)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
if matrix.size != 6 {
|
||||
t.Errorf("期望大小为6,但得到 %d", matrix.size)
|
||||
}
|
||||
|
||||
for i := 0; i < matrix.size; i++ {
|
||||
if matrix.data[i] != 0 {
|
||||
t.Errorf("期望所有元素为0,但在位置 %d 发现值 %f", i, matrix.data[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("负维度", func(t *testing.T) {
|
||||
_, err := NewZeros([]int{-1, 2})
|
||||
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewOnes 测试NewOnes函数
|
||||
func TestNewOnes(t *testing.T) {
|
||||
t.Run("创建全1矩阵", func(t *testing.T) {
|
||||
shape := []int{2, 2}
|
||||
matrix, err := NewOnes(shape)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
if matrix.size != 4 {
|
||||
t.Errorf("期望大小为4,但得到 %d", matrix.size)
|
||||
}
|
||||
|
||||
for i := 0; i < matrix.size; i++ {
|
||||
if matrix.data[i] != 1 {
|
||||
t.Errorf("期望所有元素为1,但在位置 %d 发现值 %f", i, matrix.data[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewIdentity 测试NewIdentity函数
|
||||
func TestNewIdentity(t *testing.T) {
|
||||
t.Run("创建单位矩阵", func(t *testing.T) {
|
||||
size := 3
|
||||
matrix, err := NewIdentity(size)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
if matrix.size != 9 {
|
||||
t.Errorf("期望大小为9,但得到 %d", matrix.size)
|
||||
}
|
||||
|
||||
expectedData := []float64{1, 0, 0, 0, 1, 0, 0, 0, 1}
|
||||
for i := 0; i < matrix.size; i++ {
|
||||
if matrix.data[i] != expectedData[i] {
|
||||
t.Errorf("期望数据 %v,但得到 %v", expectedData, matrix.data)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("负大小", func(t *testing.T) {
|
||||
_, err := NewIdentity(-1)
|
||||
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetSet 测试Get和Set方法
|
||||
func TestGetSet(t *testing.T) {
|
||||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
|
||||
t.Run("正常获取值", func(t *testing.T) {
|
||||
value, err := matrix.Get(0, 0)
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
if value != 1 {
|
||||
t.Errorf("期望值为1,但得到 %f", value)
|
||||
}
|
||||
|
||||
value, err = matrix.Get(1, 1)
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
if value != 4 {
|
||||
t.Errorf("期望值为4,但得到 %f", value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("设置值", func(t *testing.T) {
|
||||
err := matrix.Set(99, 0, 0)
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
value, err := matrix.Get(0, 0)
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
if value != 99 {
|
||||
t.Errorf("期望值为99,但得到 %f", value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("越界索引", func(t *testing.T) {
|
||||
_, err := matrix.Get(5, 5)
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
|
||||
err = matrix.Set(100, 5, 5)
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAdd 测试矩阵加法
|
||||
func TestAdd(t *testing.T) {
|
||||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
|
||||
|
||||
t.Run("正常相加", func(t *testing.T) {
|
||||
result, err := matrix1.Add(matrix2)
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
expected := []float64{6, 8, 10, 12}
|
||||
for i := 0; i < result.size; i++ {
|
||||
if result.data[i] != expected[i] {
|
||||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("不同形状的矩阵相加", func(t *testing.T) {
|
||||
matrix3, _ := NewMatrix([]float64{1, 2}, []int{1, 2})
|
||||
_, err := matrix1.Add(matrix3)
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSubtract 测试矩阵减法
|
||||
func TestSubtract(t *testing.T) {
|
||||
matrix1, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
|
||||
matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
|
||||
t.Run("正常相减", func(t *testing.T) {
|
||||
result, err := matrix1.Subtract(matrix2)
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
expected := []float64{4, 4, 4, 4}
|
||||
for i := 0; i < result.size; i++ {
|
||||
if result.data[i] != expected[i] {
|
||||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiply 测试矩阵逐元素乘法
|
||||
func TestMultiply(t *testing.T) {
|
||||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
|
||||
|
||||
t.Run("正常逐元素相乘", func(t *testing.T) {
|
||||
result, err := matrix1.Multiply(matrix2)
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
expected := []float64{5, 12, 21, 32}
|
||||
for i := 0; i < result.size; i++ {
|
||||
if result.data[i] != expected[i] {
|
||||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMatMul 测试矩阵乘法
|
||||
func TestMatMul(t *testing.T) {
|
||||
t.Run("正常矩阵乘法", func(t *testing.T) {
|
||||
// 创建2x3矩阵 [[1,2,3],[4,5,6]]
|
||||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
|
||||
// 创建3x2矩阵 [[7,8],[9,10],[11,12]]
|
||||
matrix2, _ := NewMatrix([]float64{7, 8, 9, 10, 11, 12}, []int{3, 2})
|
||||
|
||||
result, err := matrix1.MatMul(matrix2)
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
// 预期结果是2x2矩阵 [[58,64],[139,154]]
|
||||
expected := []float64{58, 64, 139, 154}
|
||||
for i := 0; i < result.size; i++ {
|
||||
if result.data[i] != expected[i] {
|
||||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("不兼容的矩阵乘法", func(t *testing.T) {
|
||||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
matrix2, _ := NewMatrix([]float64{1, 2}, []int{1, 2})
|
||||
_, err := matrix1.MatMul(matrix2)
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestScale 测试矩阵数乘
|
||||
func TestScale(t *testing.T) {
|
||||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
|
||||
result := matrix.Scale(2)
|
||||
expected := []float64{2, 4, 6, 8}
|
||||
for i := 0; i < result.size; i++ {
|
||||
if result.data[i] != expected[i] {
|
||||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTranspose 测试矩阵转置
|
||||
func TestTranspose(t *testing.T) {
|
||||
t.Run("正常转置", func(t *testing.T) {
|
||||
// 创建矩阵 [[1,2,3],[4,5,6]]
|
||||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
|
||||
|
||||
result, err := matrix.Transpose()
|
||||
if err != nil {
|
||||
t.Errorf("期望无错误,但得到: %v", err)
|
||||
}
|
||||
|
||||
// 预期结果是 [[1,4],[2,5],[3,6]]
|
||||
expected := []float64{1, 4, 2, 5, 3, 6}
|
||||
for i := 0; i < result.size; i++ {
|
||||
if result.data[i] != expected[i] {
|
||||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 检查形状是否正确转置
|
||||
if result.shape[0] != 3 || result.shape[1] != 2 {
|
||||
t.Errorf("期望形状 [3,2],但得到 %v", result.shape)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("非2D矩阵转置", func(t *testing.T) {
|
||||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2, 1})
|
||||
_, err := matrix.Transpose()
|
||||
if err == nil {
|
||||
t.Error("期望有错误,但没有错误返回")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCopy 测试矩阵复制
|
||||
func TestCopy(t *testing.T) {
|
||||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
|
||||
copied := matrix.Copy()
|
||||
|
||||
// 检查复制的矩阵是否相等
|
||||
if !matrix.Equal(copied) {
|
||||
t.Error("复制的矩阵应该与原矩阵相等")
|
||||
}
|
||||
|
||||
// 修改复制的矩阵不应该影响原矩阵
|
||||
copied.Set(99, 0, 0)
|
||||
|
||||
originalValue, _ := matrix.Get(0, 0)
|
||||
if originalValue == 99 {
|
||||
t.Error("修改复制的矩阵不应该影响原矩阵")
|
||||
}
|
||||
}
|
||||
|
||||
// TestShapeAndSize 测试Shape和Size方法
|
||||
func TestShapeAndSize(t *testing.T) {
|
||||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
|
||||
|
||||
shape := matrix.Shape()
|
||||
if !reflect.DeepEqual(shape, []int{2, 3}) {
|
||||
t.Errorf("期望形状 [2,3],但得到 %v", shape)
|
||||
}
|
||||
|
||||
size := matrix.Size()
|
||||
if size != 6 {
|
||||
t.Errorf("期望大小为6,但得到 %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEqual 测试Equal方法
|
||||
func TestEqual(t *testing.T) {
|
||||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||||
matrix3, _ := NewMatrix([]float64{1, 2, 3, 5}, []int{2, 2})
|
||||
|
||||
if !matrix1.Equal(matrix2) {
|
||||
t.Error("两个相同矩阵应该相等")
|
||||
}
|
||||
|
||||
if matrix1.Equal(matrix3) {
|
||||
t.Error("两个不同矩阵不应该相等")
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue