gotensor/examples/autograd/autograd_example.go

82 lines
2.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"fmt"
"git.kingecg.top/kingecg/gotensor"
)
func main() {
fmt.Println("=== 自动微分和反向传播示例 ===")
// 创建一些张量,用于模拟简单的计算图
// 例如: z = (w * x) + b其中w, x, b是张量
w_data := []float64{2.0}
w, err := gotensor.NewTensor(w_data, []int{1})
if err != nil {
panic(err)
}
x_data := []float64{3.0}
x, err := gotensor.NewTensor(x_data, []int{1})
if err != nil {
panic(err)
}
b_data := []float64{1.0}
b, err := gotensor.NewTensor(b_data, []int{1})
if err != nil {
panic(err)
}
fmt.Printf("w = %s\n", w.String())
fmt.Printf("x = %s\n", x.String())
fmt.Printf("b = %s\n", b.String())
// 计算 y = w * x
y, err := w.Multiply(x)
if err != nil {
panic(err)
}
fmt.Printf("y = w * x = %s\n", y.String())
// 计算 z = y + b
z, err := y.Add(b)
if err != nil {
panic(err)
}
fmt.Printf("z = y + b = %s\n", z.String())
// 设置z的梯度为1 (通常这是损失函数所以梯度是1)
// 在实际应用中,这通常是损失函数的梯度
z.ZeroGrad() // 确保梯度初始化为零
// 在这里我们直接操作梯度将输出节点的梯度设为1
// 为了演示,我们手动设置输出梯度
// 由于z是最终输出我们设置其梯度为1
// 在实际自动微分中,我们会从最终输出开始反向传播
// 手动设置输出梯度为1模拟损失函数的梯度
one, _ := gotensor.NewOnes(z.Shape())
z.Grad = one.Data
// 执行反向传播
z.Backward()
fmt.Printf("反向传播后:\n")
fmt.Printf("w的梯度: %s\n", w.String()) // 这将显示w的梯度
fmt.Printf("x的梯度: %s\n", x.String()) // 这将显示x的梯度
fmt.Printf("b的梯度: %s\n", b.String()) // 这将显示b的梯度
// 验证梯度计算
// 对于 z = w*x + b:
// dz/dw = x = 3
// dz/dx = w = 2
// dz/db = 1
fmt.Println("\n=== 梯度验证 ===")
w_grad, _ := w.Grad.Get(0)
x_grad, _ := x.Grad.Get(0)
b_grad, _ := b.Grad.Get(0)
fmt.Printf("w的梯度 (预期x=3.0): %.2f\n", w_grad)
fmt.Printf("x的梯度 (预期w=2.0): %.2f\n", x_grad)
fmt.Printf("b的梯度 (预期1.0): %.2f\n", b_grad)
}