gotensor/extended_tensor_test.go

201 lines
4.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gotensor
import (
"testing"
)
func TestSigmoid(t *testing.T) {
data := []float64{0.0, 1.0, -1.0, 2.0}
shape := []int{2, 2}
tensor, err := NewTensor(data, shape)
if err != nil {
t.Fatal(err)
}
result := tensor.Sigmoid()
// 验证输出值是否在0和1之间
for i := 0; i < result.Size(); i++ {
val, err := result.Get(i/2, i%2)
if err != nil {
t.Fatal(err)
}
if val < 0.0 || val > 1.0 {
t.Errorf("Sigmoid输出应在[0,1]范围内,实际值: %f", val)
}
}
// 特定值检查
expected_values := []float64{0.5, 0.7310585786300049, 0.2689414213699951, 0.8807970779778823}
for i, expected := range expected_values {
actual, err := result.Get(i/2, i%2)
if err != nil {
t.Fatal(err)
}
if actual < expected-0.001 || actual > expected+0.001 {
t.Errorf("位置[%d,%d]的Sigmoid值不正确期望: %f, 实际: %f", i/2, i%2, expected, actual)
}
}
}
func TestReLU(t *testing.T) {
data := []float64{-1.0, 0.0, 1.0, 2.0}
shape := []int{2, 2}
tensor, err := NewTensor(data, shape)
if err != nil {
t.Fatal(err)
}
result := tensor.ReLU()
// 验证ReLU的性质负值变0非负值保持不变
expected_values := []float64{0.0, 0.0, 1.0, 2.0}
for i, expected := range expected_values {
actual, err := result.Get(i/2, i%2)
if err != nil {
t.Fatal(err)
}
if actual != expected {
t.Errorf("位置[%d,%d]的ReLU值不正确期望: %f, 实际: %f", i/2, i%2, expected, actual)
}
}
}
func TestSoftmax(t *testing.T) {
data := []float64{1.0, 2.0, 3.0, 4.0}
shape := []int{2, 2}
tensor, err := NewTensor(data, shape)
if err != nil {
t.Fatal(err)
}
result := tensor.Softmax()
// 验证Softmax的性质所有值在[0,1]之间且每行之和为1
for i := 0; i < 2; i++ {
rowSum := 0.0
for j := 0; j < 2; j++ {
val, err := result.Get(i, j)
if err != nil {
t.Fatal(err)
}
if val < 0.0 || val > 1.0 {
t.Errorf("Softmax输出应在[0,1]范围内,行%d列%d的值: %f", i, j, val)
}
rowSum += val
}
if rowSum < 0.99 || rowSum > 1.01 {
t.Errorf("行%d的Softmax值之和应为1实际值: %f", i, rowSum)
}
}
}
func TestFlatten(t *testing.T) {
data := []float64{1, 2, 3, 4}
shape := []int{2, 2}
tensor, err := NewTensor(data, shape)
if err != nil {
t.Fatal(err)
}
flattened := tensor.Flatten()
// 验证展平后的形状
if flattened.Size() != 4 {
t.Errorf("展平后的大小应为4实际值: %d", flattened.Size())
}
// 验证展平后的值
for i := 0; i < 4; i++ {
expected := float64(i + 1)
actual, err := flattened.Get(i)
if err != nil {
t.Fatal(err)
}
if actual != expected {
t.Errorf("位置%d的值不正确期望: %f, 实际: %f", i, expected, actual)
}
}
}
func TestMeanSquaredError(t *testing.T) {
data1 := []float64{1.0, 2.0, 3.0, 4.0}
data2 := []float64{2.0, 3.0, 4.0, 5.0}
shape := []int{2, 2}
tensor1, err := NewTensor(data1, shape)
if err != nil {
t.Fatal(err)
}
tensor2, err := NewTensor(data2, shape)
if err != nil {
t.Fatal(err)
}
mse := tensor1.MeanSquaredError(tensor2)
// 计算期望的MSE值
// ( (1-2)^2 + (2-3)^2 + (3-4)^2 + (4-5)^2 ) / 4 = (1+1+1+1)/4 = 1
expected_mse := 1.0
actual_mse, err := mse.Data.Get(0)
if err != nil {
t.Fatal(err)
}
if actual_mse != expected_mse {
t.Errorf("MSE值不正确期望: %f, 实际: %f", expected_mse, actual_mse)
}
}
func TestMaxPool2D(t *testing.T) {
// 创建一个简单的4x4输入批大小为1通道数为1
data := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
shape := []int{1, 1, 4, 4} // [batch, channel, height, width]
tensor, err := NewTensor(data, shape)
if err != nil {
t.Fatal(err)
}
// 2x2池化步长为2
pooled, err := tensor.MaxPool2D(2, 2)
if err != nil {
t.Fatal(err)
}
// 验证输出形状
expected_shape := []int{1, 1, 2, 2}
actual_shape := pooled.Shape()
if len(actual_shape) != len(expected_shape) {
t.Errorf("池化后的形状不正确,期望长度: %d, 实际长度: %d", len(expected_shape), len(actual_shape))
}
// 检查池化结果是否正确
// 左上角: max(1,2,5,6) = 6
// 右上角: max(3,4,7,8) = 8
// 左下角: max(9,10,13,14) = 14
// 右下角: max(11,12,15,16) = 16
expected_vals := []float64{6, 8, 14, 16}
for i, expected := range expected_vals {
actual, err := pooled.Get(0, 0, i/2, i%2)
if err != nil {
t.Fatal(err)
}
if actual != expected {
t.Errorf("位置[0,0,%d,%d]的池化值不正确,期望: %f, 实际: %f", i/2, i%2, expected, actual)
}
}
}