gotensor/examples/basic_operation/basic_operations.go

72 lines
1.5 KiB
Go

package main
import (
"fmt"
"git.kingecg.top/kingecg/gotensor"
)
func main() {
fmt.Println("=== 基本运算示例 ===")
// 创建两个2x2的张量
t1_data := []float64{1, 2, 3, 4}
t1_shape := []int{2, 2}
t1, err := gotensor.NewTensor(t1_data, t1_shape)
if err != nil {
panic(err)
}
t2_data := []float64{5, 6, 7, 8}
t2, err := gotensor.NewTensor(t2_data, t1_shape)
if err != nil {
panic(err)
}
fmt.Printf("张量1:\n%s\n", t1.String())
fmt.Printf("张量2:\n%s\n", t2.String())
// 加法运算
add_result, err := t1.Add(t2)
if err != nil {
panic(err)
}
fmt.Printf("加法结果 (t1 + t2):\n%s\n", add_result.String())
// 减法运算
sub_result, err := t1.Subtract(t2)
if err != nil {
panic(err)
}
fmt.Printf("减法结果 (t1 - t2):\n%s\n", sub_result.String())
// 逐元素乘法
mul_result, err := t1.Multiply(t2)
if err != nil {
panic(err)
}
fmt.Printf("逐元素乘法结果 (t1 * t2):\n%s\n", mul_result.String())
// 数乘
scale_result := t1.Scale(2.0)
fmt.Printf("数乘结果 (t1 * 2):\n%s\n", scale_result.String())
// 矩阵乘法
matmul_result, err := t1.MatMul(t2)
if err != nil {
panic(err)
}
fmt.Printf("矩阵乘法结果 (t1 @ t2):\n%s\n", matmul_result.String())
// 创建零张量和单位矩阵
zeros, err := gotensor.NewZeros([]int{2, 3})
if err != nil {
panic(err)
}
fmt.Printf("2x3零张量:\n%s\n", zeros.String())
identity, err := gotensor.NewIdentity(3)
if err != nil {
panic(err)
}
fmt.Printf("3x3单位矩阵:\n%s\n", identity.String())
}