gomatrix/matrix_test.go

402 lines
9.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gomatrix
import (
"reflect"
"testing"
)
// TestNewMatrix 测试NewMatrix函数
func TestNewMatrix(t *testing.T) {
t.Run("正常创建矩阵", func(t *testing.T) {
data := []float64{1, 2, 3, 4}
shape := []int{2, 2}
matrix, err := NewMatrix(data, shape)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if matrix == nil {
t.Error("期望矩阵不为nil")
}
if !reflect.DeepEqual(matrix.shape, shape) {
t.Errorf("期望形状 %v但得到 %v", shape, matrix.shape)
}
if matrix.size != 4 {
t.Errorf("期望大小为4但得到 %d", matrix.size)
}
})
t.Run("数据和形状长度不匹配", func(t *testing.T) {
data := []float64{1, 2, 3} // 3个元素
shape := []int{2, 2} // 需要4个元素
_, err := NewMatrix(data, shape)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
t.Run("空数据", func(t *testing.T) {
_, err := NewMatrix([]float64{}, []int{2, 2})
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
t.Run("空形状", func(t *testing.T) {
_, err := NewMatrix([]float64{1, 2, 3, 4}, []int{})
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
t.Run("负维度", func(t *testing.T) {
_, err := NewMatrix([]float64{1, 2, 3, 4}, []int{-1, 2})
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestNewZeros 测试NewZeros函数
func TestNewZeros(t *testing.T) {
t.Run("创建零矩阵", func(t *testing.T) {
shape := []int{2, 3}
matrix, err := NewZeros(shape)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if matrix.size != 6 {
t.Errorf("期望大小为6但得到 %d", matrix.size)
}
for i := 0; i < matrix.size; i++ {
if matrix.data[i] != 0 {
t.Errorf("期望所有元素为0但在位置 %d 发现值 %f", i, matrix.data[i])
}
}
})
t.Run("负维度", func(t *testing.T) {
_, err := NewZeros([]int{-1, 2})
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestNewOnes 测试NewOnes函数
func TestNewOnes(t *testing.T) {
t.Run("创建全1矩阵", func(t *testing.T) {
shape := []int{2, 2}
matrix, err := NewOnes(shape)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if matrix.size != 4 {
t.Errorf("期望大小为4但得到 %d", matrix.size)
}
for i := 0; i < matrix.size; i++ {
if matrix.data[i] != 1 {
t.Errorf("期望所有元素为1但在位置 %d 发现值 %f", i, matrix.data[i])
}
}
})
}
// TestNewIdentity 测试NewIdentity函数
func TestNewIdentity(t *testing.T) {
t.Run("创建单位矩阵", func(t *testing.T) {
size := 3
matrix, err := NewIdentity(size)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if matrix.size != 9 {
t.Errorf("期望大小为9但得到 %d", matrix.size)
}
expectedData := []float64{1, 0, 0, 0, 1, 0, 0, 0, 1}
for i := 0; i < matrix.size; i++ {
if matrix.data[i] != expectedData[i] {
t.Errorf("期望数据 %v但得到 %v", expectedData, matrix.data)
break
}
}
})
t.Run("负大小", func(t *testing.T) {
_, err := NewIdentity(-1)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestGetSet 测试Get和Set方法
func TestGetSet(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
t.Run("正常获取值", func(t *testing.T) {
value, err := matrix.Get(0, 0)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if value != 1 {
t.Errorf("期望值为1但得到 %f", value)
}
value, err = matrix.Get(1, 1)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if value != 4 {
t.Errorf("期望值为4但得到 %f", value)
}
})
t.Run("设置值", func(t *testing.T) {
err := matrix.Set(99, 0, 0)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
value, err := matrix.Get(0, 0)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
if value != 99 {
t.Errorf("期望值为99但得到 %f", value)
}
})
t.Run("越界索引", func(t *testing.T) {
_, err := matrix.Get(5, 5)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
err = matrix.Set(100, 5, 5)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestAdd 测试矩阵加法
func TestAdd(t *testing.T) {
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
t.Run("正常相加", func(t *testing.T) {
result, err := matrix1.Add(matrix2)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
expected := []float64{6, 8, 10, 12}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
})
t.Run("不同形状的矩阵相加", func(t *testing.T) {
matrix3, _ := NewMatrix([]float64{1, 2}, []int{1, 2})
_, err := matrix1.Add(matrix3)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestSubtract 测试矩阵减法
func TestSubtract(t *testing.T) {
matrix1, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
t.Run("正常相减", func(t *testing.T) {
result, err := matrix1.Subtract(matrix2)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
expected := []float64{4, 4, 4, 4}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
})
}
// TestMultiply 测试矩阵逐元素乘法
func TestMultiply(t *testing.T) {
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{5, 6, 7, 8}, []int{2, 2})
t.Run("正常逐元素相乘", func(t *testing.T) {
result, err := matrix1.Multiply(matrix2)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
expected := []float64{5, 12, 21, 32}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
})
}
// TestMatMul 测试矩阵乘法
func TestMatMul(t *testing.T) {
t.Run("正常矩阵乘法", func(t *testing.T) {
// 创建2x3矩阵 [[1,2,3],[4,5,6]]
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
// 创建3x2矩阵 [[7,8],[9,10],[11,12]]
matrix2, _ := NewMatrix([]float64{7, 8, 9, 10, 11, 12}, []int{3, 2})
result, err := matrix1.MatMul(matrix2)
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
// 预期结果是2x2矩阵 [[58,64],[139,154]]
expected := []float64{58, 64, 139, 154}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
})
t.Run("不兼容的矩阵乘法", func(t *testing.T) {
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{1, 2}, []int{1, 2})
_, err := matrix1.MatMul(matrix2)
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestScale 测试矩阵数乘
func TestScale(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
result := matrix.Scale(2)
expected := []float64{2, 4, 6, 8}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
}
// TestTranspose 测试矩阵转置
func TestTranspose(t *testing.T) {
t.Run("正常转置", func(t *testing.T) {
// 创建矩阵 [[1,2,3],[4,5,6]]
matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
result, err := matrix.Transpose()
if err != nil {
t.Errorf("期望无错误,但得到: %v", err)
}
// 预期结果是 [[1,4],[2,5],[3,6]]
expected := []float64{1, 4, 2, 5, 3, 6}
for i := 0; i < result.size; i++ {
if result.data[i] != expected[i] {
t.Errorf("期望结果 %v但得到 %v", expected, result.data)
break
}
}
// 检查形状是否正确转置
if result.shape[0] != 3 || result.shape[1] != 2 {
t.Errorf("期望形状 [3,2],但得到 %v", result.shape)
}
})
t.Run("非2D矩阵转置", func(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2, 1})
_, err := matrix.Transpose()
if err == nil {
t.Error("期望有错误,但没有错误返回")
}
})
}
// TestCopy 测试矩阵复制
func TestCopy(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
copied := matrix.Copy()
// 检查复制的矩阵是否相等
if !matrix.Equal(copied) {
t.Error("复制的矩阵应该与原矩阵相等")
}
// 修改复制的矩阵不应该影响原矩阵
copied.Set(99, 0, 0)
originalValue, _ := matrix.Get(0, 0)
if originalValue == 99 {
t.Error("修改复制的矩阵不应该影响原矩阵")
}
}
// TestShapeAndSize 测试Shape和Size方法
func TestShapeAndSize(t *testing.T) {
matrix, _ := NewMatrix([]float64{1, 2, 3, 4, 5, 6}, []int{2, 3})
shape := matrix.Shape()
if !reflect.DeepEqual(shape, []int{2, 3}) {
t.Errorf("期望形状 [2,3],但得到 %v", shape)
}
size := matrix.Size()
if size != 6 {
t.Errorf("期望大小为6但得到 %d", size)
}
}
// TestEqual 测试Equal方法
func TestEqual(t *testing.T) {
matrix1, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix2, _ := NewMatrix([]float64{1, 2, 3, 4}, []int{2, 2})
matrix3, _ := NewMatrix([]float64{1, 2, 3, 5}, []int{2, 2})
if !matrix1.Equal(matrix2) {
t.Error("两个相同矩阵应该相等")
}
if matrix1.Equal(matrix3) {
t.Error("两个不同矩阵不应该相等")
}
}