201 lines
4.5 KiB
Go
201 lines
4.5 KiB
Go
package gotensor
|
||
|
||
import (
|
||
"testing"
|
||
)
|
||
|
||
func TestSigmoid(t *testing.T) {
|
||
data := []float64{0.0, 1.0, -1.0, 2.0}
|
||
shape := []int{2, 2}
|
||
|
||
tensor, err := NewTensor(data, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
result := tensor.Sigmoid()
|
||
|
||
// 验证输出值是否在0和1之间
|
||
for i := 0; i < result.Size(); i++ {
|
||
val, err := result.Get(i/2, i%2)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if val < 0.0 || val > 1.0 {
|
||
t.Errorf("Sigmoid输出应在[0,1]范围内,实际值: %f", val)
|
||
}
|
||
}
|
||
|
||
// 特定值检查
|
||
expected_values := []float64{0.5, 0.7310585786300049, 0.2689414213699951, 0.8807970779778823}
|
||
for i, expected := range expected_values {
|
||
actual, err := result.Get(i/2, i%2)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if actual < expected-0.001 || actual > expected+0.001 {
|
||
t.Errorf("位置[%d,%d]的Sigmoid值不正确,期望: %f, 实际: %f", i/2, i%2, expected, actual)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestReLU(t *testing.T) {
|
||
data := []float64{-1.0, 0.0, 1.0, 2.0}
|
||
shape := []int{2, 2}
|
||
|
||
tensor, err := NewTensor(data, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
result := tensor.ReLU()
|
||
|
||
// 验证ReLU的性质:负值变0,非负值保持不变
|
||
expected_values := []float64{0.0, 0.0, 1.0, 2.0}
|
||
for i, expected := range expected_values {
|
||
actual, err := result.Get(i/2, i%2)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if actual != expected {
|
||
t.Errorf("位置[%d,%d]的ReLU值不正确,期望: %f, 实际: %f", i/2, i%2, expected, actual)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestSoftmax(t *testing.T) {
|
||
data := []float64{1.0, 2.0, 3.0, 4.0}
|
||
shape := []int{2, 2}
|
||
|
||
tensor, err := NewTensor(data, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
result := tensor.Softmax()
|
||
|
||
// 验证Softmax的性质:所有值在[0,1]之间,且每行之和为1
|
||
for i := 0; i < 2; i++ {
|
||
rowSum := 0.0
|
||
for j := 0; j < 2; j++ {
|
||
val, err := result.Get(i, j)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if val < 0.0 || val > 1.0 {
|
||
t.Errorf("Softmax输出应在[0,1]范围内,行%d列%d的值: %f", i, j, val)
|
||
}
|
||
|
||
rowSum += val
|
||
}
|
||
|
||
if rowSum < 0.99 || rowSum > 1.01 {
|
||
t.Errorf("行%d的Softmax值之和应为1,实际值: %f", i, rowSum)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestFlatten(t *testing.T) {
|
||
data := []float64{1, 2, 3, 4}
|
||
shape := []int{2, 2}
|
||
|
||
tensor, err := NewTensor(data, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
flattened := tensor.Flatten()
|
||
|
||
// 验证展平后的形状
|
||
if flattened.Size() != 4 {
|
||
t.Errorf("展平后的大小应为4,实际值: %d", flattened.Size())
|
||
}
|
||
|
||
// 验证展平后的值
|
||
for i := 0; i < 4; i++ {
|
||
expected := float64(i + 1)
|
||
actual, err := flattened.Get(i)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if actual != expected {
|
||
t.Errorf("位置%d的值不正确,期望: %f, 实际: %f", i, expected, actual)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestMeanSquaredError(t *testing.T) {
|
||
data1 := []float64{1.0, 2.0, 3.0, 4.0}
|
||
data2 := []float64{2.0, 3.0, 4.0, 5.0}
|
||
shape := []int{2, 2}
|
||
|
||
tensor1, err := NewTensor(data1, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
tensor2, err := NewTensor(data2, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
mse := tensor1.MeanSquaredError(tensor2)
|
||
|
||
// 计算期望的MSE值
|
||
// ( (1-2)^2 + (2-3)^2 + (3-4)^2 + (4-5)^2 ) / 4 = (1+1+1+1)/4 = 1
|
||
expected_mse := 1.0
|
||
actual_mse, err := mse.Data.Get(0)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if actual_mse != expected_mse {
|
||
t.Errorf("MSE值不正确,期望: %f, 实际: %f", expected_mse, actual_mse)
|
||
}
|
||
}
|
||
|
||
func TestMaxPool2D(t *testing.T) {
|
||
// 创建一个简单的4x4输入,批大小为1,通道数为1
|
||
data := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||
shape := []int{1, 1, 4, 4} // [batch, channel, height, width]
|
||
|
||
tensor, err := NewTensor(data, shape)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
// 2x2池化,步长为2
|
||
pooled, err := tensor.MaxPool2D(2, 2)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
// 验证输出形状
|
||
expected_shape := []int{1, 1, 2, 2}
|
||
actual_shape := pooled.Shape()
|
||
if len(actual_shape) != len(expected_shape) {
|
||
t.Errorf("池化后的形状不正确,期望长度: %d, 实际长度: %d", len(expected_shape), len(actual_shape))
|
||
}
|
||
|
||
// 检查池化结果是否正确
|
||
// 左上角: max(1,2,5,6) = 6
|
||
// 右上角: max(3,4,7,8) = 8
|
||
// 左下角: max(9,10,13,14) = 14
|
||
// 右下角: max(11,12,15,16) = 16
|
||
expected_vals := []float64{6, 8, 14, 16}
|
||
for i, expected := range expected_vals {
|
||
actual, err := pooled.Get(0, 0, i/2, i%2)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if actual != expected {
|
||
t.Errorf("位置[0,0,%d,%d]的池化值不正确,期望: %f, 实际: %f", i/2, i%2, expected, actual)
|
||
}
|
||
}
|
||
} |