402 lines
9.8 KiB
Go
402 lines
9.8 KiB
Go
package gomatrix
|
||
|
||
import (
|
||
"reflect"
|
||
"testing"
|
||
)
|
||
|
||
// TestNewMatrix 测试NewMatrix函数
|
||
func TestNewMatrix(t *testing.T) {
|
||
t.Run("正常创建矩阵", func(t *testing.T) {
|
||
data := []float64{1, 2, 3, 4}
|
||
shape := []int{2, 2}
|
||
matrix, err := NewMatrix(data, shape)
|
||
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
if matrix == nil {
|
||
t.Error("期望矩阵不为nil")
|
||
}
|
||
|
||
if !reflect.DeepEqual(matrix.shape, shape) {
|
||
t.Errorf("期望形状 %v,但得到 %v", shape, matrix.shape)
|
||
}
|
||
|
||
if matrix.size != 4 {
|
||
t.Errorf("期望大小为4,但得到 %d", matrix.size)
|
||
}
|
||
})
|
||
|
||
t.Run("数据和形状长度不匹配", func(t *testing.T) {
|
||
data := []float64{1, 2, 3} // 3个元素
|
||
shape := []int{2, 2} // 需要4个元素
|
||
_, err := NewMatrix(data, shape)
|
||
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
|
||
t.Run("空数据", func(t *testing.T) {
|
||
_, err := NewMatrix([]float64{}, []int{2, 2})
|
||
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
|
||
t.Run("空形状", func(t *testing.T) {
|
||
_, err := NewMatrix([]float64{1, 2, 3, 4}, []int{})
|
||
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
|
||
t.Run("负维度", func(t *testing.T) {
|
||
_, err := NewMatrix([]float64{1, 2, 3, 4}, []int{-1, 2})
|
||
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestNewZeros 测试NewZeros函数
|
||
func TestNewZeros(t *testing.T) {
|
||
t.Run("创建零矩阵", func(t *testing.T) {
|
||
shape := []int{2, 3}
|
||
matrix, err := NewZeros(shape)
|
||
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
if matrix.size != 6 {
|
||
t.Errorf("期望大小为6,但得到 %d", matrix.size)
|
||
}
|
||
|
||
for i := 0; i < matrix.size; i++ {
|
||
if matrix.data[i] != 0 {
|
||
t.Errorf("期望所有元素为0,但在位置 %d 发现值 %f", i, matrix.data[i])
|
||
}
|
||
}
|
||
})
|
||
|
||
t.Run("负维度", func(t *testing.T) {
|
||
_, err := NewZeros([]int{-1, 2})
|
||
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestNewOnes 测试NewOnes函数
|
||
func TestNewOnes(t *testing.T) {
|
||
t.Run("创建全1矩阵", func(t *testing.T) {
|
||
shape := []int{2, 2}
|
||
matrix, err := NewOnes(shape)
|
||
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
if matrix.size != 4 {
|
||
t.Errorf("期望大小为4,但得到 %d", matrix.size)
|
||
}
|
||
|
||
for i := 0; i < matrix.size; i++ {
|
||
if matrix.data[i] != 1 {
|
||
t.Errorf("期望所有元素为1,但在位置 %d 发现值 %f", i, matrix.data[i])
|
||
}
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestNewIdentity 测试NewIdentity函数
|
||
func TestNewIdentity(t *testing.T) {
|
||
t.Run("创建单位矩阵", func(t *testing.T) {
|
||
size := 3
|
||
matrix, err := NewIdentity(size)
|
||
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
if matrix.size != 9 {
|
||
t.Errorf("期望大小为9,但得到 %d", matrix.size)
|
||
}
|
||
|
||
expectedData := []float64{1, 0, 0, 0, 1, 0, 0, 0, 1}
|
||
for i := 0; i < matrix.size; i++ {
|
||
if matrix.data[i] != expectedData[i] {
|
||
t.Errorf("期望数据 %v,但得到 %v", expectedData, matrix.data)
|
||
break
|
||
}
|
||
}
|
||
})
|
||
|
||
t.Run("负大小", func(t *testing.T) {
|
||
_, err := NewIdentity(-1)
|
||
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestGetSet 测试Get和Set方法
|
||
func TestGetSet(t *testing.T) {
|
||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
|
||
t.Run("正常获取值", func(t *testing.T) {
|
||
value, err := matrix.Get(0, 0)
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
if value != 1 {
|
||
t.Errorf("期望值为1,但得到 %f", value)
|
||
}
|
||
|
||
value, err = matrix.Get(1, 1)
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
if value != 4 {
|
||
t.Errorf("期望值为4,但得到 %f", value)
|
||
}
|
||
})
|
||
|
||
t.Run("设置值", func(t *testing.T) {
|
||
err := matrix.Set(99, 0, 0)
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
value, err := matrix.Get(0, 0)
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
if value != 99 {
|
||
t.Errorf("期望值为99,但得到 %f", value)
|
||
}
|
||
})
|
||
|
||
t.Run("越界索引", func(t *testing.T) {
|
||
_, err := matrix.Get(5, 5)
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
|
||
err = matrix.Set(100, 5, 5)
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestAdd 测试矩阵加法
|
||
func TestAdd(t *testing.T) {
|
||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
|
||
|
||
t.Run("正常相加", func(t *testing.T) {
|
||
result, err := matrix1.Add(matrix2)
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
expected := []float64{6, 8, 10, 12}
|
||
for i := 0; i < result.size; i++ {
|
||
if result.data[i] != expected[i] {
|
||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||
break
|
||
}
|
||
}
|
||
})
|
||
|
||
t.Run("不同形状的矩阵相加", func(t *testing.T) {
|
||
matrix3, _ := NewMatrix([]float64{1, 2}, []int{1, 2})
|
||
_, err := matrix1.Add(matrix3)
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestSubtract 测试矩阵减法
|
||
func TestSubtract(t *testing.T) {
|
||
matrix1, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
|
||
matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
|
||
t.Run("正常相减", func(t *testing.T) {
|
||
result, err := matrix1.Subtract(matrix2)
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
expected := []float64{4, 4, 4, 4}
|
||
for i := 0; i < result.size; i++ {
|
||
if result.data[i] != expected[i] {
|
||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||
break
|
||
}
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestMultiply 测试矩阵逐元素乘法
|
||
func TestMultiply(t *testing.T) {
|
||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
|
||
|
||
t.Run("正常逐元素相乘", func(t *testing.T) {
|
||
result, err := matrix1.Multiply(matrix2)
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
expected := []float64{5, 12, 21, 32}
|
||
for i := 0; i < result.size; i++ {
|
||
if result.data[i] != expected[i] {
|
||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||
break
|
||
}
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestMatMul 测试矩阵乘法
|
||
func TestMatMul(t *testing.T) {
|
||
t.Run("正常矩阵乘法", func(t *testing.T) {
|
||
// 创建2x3矩阵 [[1,2,3],[4,5,6]]
|
||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
|
||
// 创建3x2矩阵 [[7,8],[9,10],[11,12]]
|
||
matrix2, _ := NewMatrix([]float64{7, 8, 9, 10, 11, 12}, []int{3, 2})
|
||
|
||
result, err := matrix1.MatMul(matrix2)
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
// 预期结果是2x2矩阵 [[58,64],[139,154]]
|
||
expected := []float64{58, 64, 139, 154}
|
||
for i := 0; i < result.size; i++ {
|
||
if result.data[i] != expected[i] {
|
||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||
break
|
||
}
|
||
}
|
||
})
|
||
|
||
t.Run("不兼容的矩阵乘法", func(t *testing.T) {
|
||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
matrix2, _ := NewMatrix([]float64{1, 2}, []int{1, 2})
|
||
_, err := matrix1.MatMul(matrix2)
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestScale 测试矩阵数乘
|
||
func TestScale(t *testing.T) {
|
||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
|
||
result := matrix.Scale(2)
|
||
expected := []float64{2, 4, 6, 8}
|
||
for i := 0; i < result.size; i++ {
|
||
if result.data[i] != expected[i] {
|
||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// TestTranspose 测试矩阵转置
|
||
func TestTranspose(t *testing.T) {
|
||
t.Run("正常转置", func(t *testing.T) {
|
||
// 创建矩阵 [[1,2,3],[4,5,6]]
|
||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
|
||
|
||
result, err := matrix.Transpose()
|
||
if err != nil {
|
||
t.Errorf("期望无错误,但得到: %v", err)
|
||
}
|
||
|
||
// 预期结果是 [[1,4],[2,5],[3,6]]
|
||
expected := []float64{1, 4, 2, 5, 3, 6}
|
||
for i := 0; i < result.size; i++ {
|
||
if result.data[i] != expected[i] {
|
||
t.Errorf("期望结果 %v,但得到 %v", expected, result.data)
|
||
break
|
||
}
|
||
}
|
||
|
||
// 检查形状是否正确转置
|
||
if result.shape[0] != 3 || result.shape[1] != 2 {
|
||
t.Errorf("期望形状 [3,2],但得到 %v", result.shape)
|
||
}
|
||
})
|
||
|
||
t.Run("非2D矩阵转置", func(t *testing.T) {
|
||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2, 1})
|
||
_, err := matrix.Transpose()
|
||
if err == nil {
|
||
t.Error("期望有错误,但没有错误返回")
|
||
}
|
||
})
|
||
}
|
||
|
||
// TestCopy 测试矩阵复制
|
||
func TestCopy(t *testing.T) {
|
||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
|
||
copied := matrix.Copy()
|
||
|
||
// 检查复制的矩阵是否相等
|
||
if !matrix.Equal(copied) {
|
||
t.Error("复制的矩阵应该与原矩阵相等")
|
||
}
|
||
|
||
// 修改复制的矩阵不应该影响原矩阵
|
||
copied.Set(99, 0, 0)
|
||
|
||
originalValue, _ := matrix.Get(0, 0)
|
||
if originalValue == 99 {
|
||
t.Error("修改复制的矩阵不应该影响原矩阵")
|
||
}
|
||
}
|
||
|
||
// TestShapeAndSize 测试Shape和Size方法
|
||
func TestShapeAndSize(t *testing.T) {
|
||
matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
|
||
|
||
shape := matrix.Shape()
|
||
if !reflect.DeepEqual(shape, []int{2, 3}) {
|
||
t.Errorf("期望形状 [2,3],但得到 %v", shape)
|
||
}
|
||
|
||
size := matrix.Size()
|
||
if size != 6 {
|
||
t.Errorf("期望大小为6,但得到 %d", size)
|
||
}
|
||
}
|
||
|
||
// TestEqual 测试Equal方法
|
||
func TestEqual(t *testing.T) {
|
||
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
|
||
matrix3, _ := NewMatrix([]float64{1, 2, 3, 5}, []int{2, 2})
|
||
|
||
if !matrix1.Equal(matrix2) {
|
||
t.Error("两个相同矩阵应该相等")
|
||
}
|
||
|
||
if matrix1.Equal(matrix3) {
|
||
t.Error("两个不同矩阵不应该相等")
|
||
}
|
||
} |