82 lines
2.1 KiB
Go
82 lines
2.1 KiB
Go
package main
|
||
|
||
import (
|
||
"fmt"
|
||
"git.kingecg.top/kingecg/gotensor"
|
||
)
|
||
|
||
func main() {
|
||
fmt.Println("=== 自动微分和反向传播示例 ===")
|
||
|
||
// 创建一些张量,用于模拟简单的计算图
|
||
// 例如: z = (w * x) + b,其中w, x, b是张量
|
||
w_data := []float64{2.0}
|
||
w, err := gotensor.NewTensor(w_data, []int{1})
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
x_data := []float64{3.0}
|
||
x, err := gotensor.NewTensor(x_data, []int{1})
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
b_data := []float64{1.0}
|
||
b, err := gotensor.NewTensor(b_data, []int{1})
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
fmt.Printf("w = %s\n", w.String())
|
||
fmt.Printf("x = %s\n", x.String())
|
||
fmt.Printf("b = %s\n", b.String())
|
||
|
||
// 计算 y = w * x
|
||
y, err := w.Multiply(x)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
fmt.Printf("y = w * x = %s\n", y.String())
|
||
|
||
// 计算 z = y + b
|
||
z, err := y.Add(b)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
fmt.Printf("z = y + b = %s\n", z.String())
|
||
|
||
// 设置z的梯度为1 (通常这是损失函数,所以梯度是1)
|
||
// 在实际应用中,这通常是损失函数的梯度
|
||
z.ZeroGrad() // 确保梯度初始化为零
|
||
// 在这里,我们直接操作梯度,将输出节点的梯度设为1
|
||
// 为了演示,我们手动设置输出梯度
|
||
// 由于z是最终输出,我们设置其梯度为1
|
||
|
||
// 在实际自动微分中,我们会从最终输出开始反向传播
|
||
// 手动设置输出梯度为1,模拟损失函数的梯度
|
||
one, _ := gotensor.NewOnes(z.Shape())
|
||
z.Grad = one.Data
|
||
|
||
// 执行反向传播
|
||
z.Backward()
|
||
|
||
fmt.Printf("反向传播后:\n")
|
||
fmt.Printf("w的梯度: %s\n", w.String()) // 这将显示w的梯度
|
||
fmt.Printf("x的梯度: %s\n", x.String()) // 这将显示x的梯度
|
||
fmt.Printf("b的梯度: %s\n", b.String()) // 这将显示b的梯度
|
||
|
||
// 验证梯度计算
|
||
// 对于 z = w*x + b:
|
||
// dz/dw = x = 3
|
||
// dz/dx = w = 2
|
||
// dz/db = 1
|
||
fmt.Println("\n=== 梯度验证 ===")
|
||
w_grad, _ := w.Grad.Get(0)
|
||
x_grad, _ := x.Grad.Get(0)
|
||
b_grad, _ := b.Grad.Get(0)
|
||
|
||
fmt.Printf("w的梯度 (预期x=3.0): %.2f\n", w_grad)
|
||
fmt.Printf("x的梯度 (预期w=2.0): %.2f\n", x_grad)
|
||
fmt.Printf("b的梯度 (预期1.0): %.2f\n", b_grad)
|
||
} |