// gotensor example: linear regression
// (gotensor/examples/linear_regression/linear_regression.go)
package main
import (
"fmt"
"git.kingecg.top/kingecg/gotensor"
"math"
)
// main trains a one-feature linear regression model (true relation
// y = 2*x + 3) with batch gradient descent on gotensor tensors,
// printing the loss every 20 epochs and the final fit.
func main() {
	fmt.Println("=== 线性回归示例 ===")

	// Training inputs: a single feature column, shape (5, 1).
	xData := []float64{1.0, 2.0, 3.0, 4.0, 5.0}
	X, err := gotensor.NewTensor(xData, []int{5, 1})
	if err != nil {
		panic(err)
	}

	// Design matrix with a leading bias column: [[1, x1], [1, x2], ...],
	// so the weight vector is [b, w].
	xWithBias, err := addBiasColumn(X)
	if err != nil {
		panic(err)
	}

	// Targets generated from y = 2*x + 3.
	yData := []float64{5.0, 7.0, 9.0, 11.0, 13.0} // 2*x + 3
	Y, err := gotensor.NewTensor(yData, []int{5, 1})
	if err != nil {
		panic(err)
	}

	// Weights [b, w] initialized to zero, shape (2, 1).
	wData := []float64{0.0, 0.0}
	W, err := gotensor.NewTensor(wData, []int{2, 1})
	if err != nil {
		panic(err)
	}

	fmt.Printf("训练数据 X (带偏置):\n%s\n", xWithBias.String())
	fmt.Printf("目标值 Y:\n%s\n", Y.String())
	fmt.Printf("初始权重 W:\n%s\n", W.String())

	learningRate := 0.01
	epochs := 100
	// Number of training samples (rows of the design matrix). The MSE
	// gradient is normalized by the sample count n, not by the total
	// element count of the design matrix.
	numSamples := float64(xWithBias.Shape()[0])

	for epoch := 0; epoch < epochs; epoch++ {
		// Forward pass: Y_pred = X * W.
		yPred, err := xWithBias.MatMul(W)
		if err != nil {
			panic(err)
		}

		// Residuals and element-wise squared error.
		diff, err := yPred.Subtract(Y)
		if err != nil {
			panic(err)
		}
		squaredDiff, err := diff.Multiply(diff)
		if err != nil {
			panic(err)
		}

		// MSE = (1/n) * sum((Y_pred - Y)^2). The library has no mean
		// reduction, so sum and divide by the element count.
		lossVal := sumTensor(squaredDiff) / float64(squaredDiff.Size())
		if epoch%20 == 0 {
			fmt.Printf("Epoch %d, Loss: %.6f\n", epoch, lossVal)
		}

		// Gradient of the MSE: dL/dW = (2/n) * X^T * (Y_pred - Y).
		xTransposed, err := xWithBias.Data.Transpose()
		if err != nil {
			panic(err)
		}
		xT := &gotensor.Tensor{Data: xTransposed}
		gradIntermediate, err := xT.MatMul(diff)
		if err != nil {
			panic(err)
		}
		// BUG FIX: this previously scaled by 2/Size() — the TOTAL
		// element count of the (5 x 2) design matrix, i.e. 10 —
		// instead of 2/n with n = 5 samples, which silently halved
		// the effective learning rate versus the formula above.
		gradW := gradIntermediate.Scale(2.0 / numSamples)

		// Gradient-descent step: W = W - lr * grad_W.
		W, err = W.Subtract(gradW.Scale(learningRate))
		if err != nil {
			panic(err)
		}
	}

	fmt.Printf("\n训练后的权重 W:\n%s\n", W.String())

	// Evaluate the trained model on the training inputs.
	yPred, err := xWithBias.MatMul(W)
	if err != nil {
		panic(err)
	}
	fmt.Printf("预测值:\n%s\n", yPred.String())
	fmt.Printf("真实值:\n%s\n", Y.String())

	// Final mean squared error over the training set.
	diff, err := yPred.Subtract(Y)
	if err != nil {
		panic(err)
	}
	squaredDiff, err := diff.Multiply(diff)
	if err != nil {
		panic(err)
	}
	finalLoss := sumTensor(squaredDiff) / float64(squaredDiff.Size())
	fmt.Printf("最终损失: %.6f\n", finalLoss)
}
// 辅助函数:为输入矩阵添加偏置列
// addBiasColumn builds a design matrix by prefixing a constant 1.0
// bias column to the first feature column of X, returning a (rows x 2)
// tensor of the form [[1, x1], [1, x2], ...]. The library provides no
// concatenation op, so the row-major backing data is assembled by hand.
//
// X must be a 2-D tensor; only its first column is read.
func addBiasColumn(X *gotensor.Tensor) (*gotensor.Tensor, error) {
	shape := X.Shape()
	if len(shape) != 2 {
		return nil, fmt.Errorf("addBiasColumn: expected a 2-D tensor, got shape %v", shape)
	}
	rows := shape[0]
	resultData := make([]float64, rows*2)
	for i := 0; i < rows; i++ {
		resultData[i*2] = 1.0 // bias term
		// Propagate read errors instead of silently discarding them
		// (the original ignored this error with `_`).
		val, err := X.Data.Get(i, 0)
		if err != nil {
			return nil, fmt.Errorf("addBiasColumn: reading element (%d, 0): %w", i, err)
		}
		resultData[i*2+1] = val
	}
	return gotensor.NewTensor(resultData, []int{rows, 2})
}
// 辅助函数:计算张量所有元素的和
// sumTensor accumulates every element of t and returns the absolute
// value of the total.
//
// NOTE(review): the final math.Abs means this returns |sum|, not the
// true sum. That is harmless here — the callers only pass squared-error
// tensors, whose elements are non-negative — but it would silently
// produce the wrong result for a tensor with a negative total; confirm
// before reusing this helper elsewhere.
func sumTensor(t *gotensor.Tensor) float64 {
	total := 0.0
	shape := t.Shape()
	switch len(shape) {
	case 2:
		// Matrix: walk rows then columns via 2-index Get.
		for r := 0; r < shape[0]; r++ {
			for c := 0; c < shape[1]; c++ {
				v, _ := t.Data.Get(r, c)
				total += v
			}
		}
	default:
		// Any other rank: fall back to flat 1-index access.
		for i := 0; i < t.Size(); i++ {
			v, _ := t.Data.Get(i)
			total += v
		}
	}
	return math.Abs(total)
}