From 2ff4dcfb0fcddd9f4b7c34fca30ddeec54eeaf5d Mon Sep 17 00:00:00 2001 From: kingecg Date: Tue, 30 Dec 2025 22:19:44 +0800 Subject: [PATCH] =?UTF-8?q?```=20feat(matrix):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E7=9F=A9=E9=98=B5=E7=BB=93=E6=9E=84=E4=BD=93=E5=AD=97=E6=AE=B5?= =?UTF-8?q?=E6=B3=A8=E9=87=8A=E5=92=8C=E5=87=BD=E6=95=B0=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为Matrix结构体添加中文注释,解释data、shape、strides、mdim、size字段含义 - 为NewMatrix、NewZeros、NewOnes、NewIdentity等构造函数添加参数说明注释 - 为Get、Set、Add、Subtract、Multiply、MatMul、Scale等方法添加参数说明注释 - 修复Copy方法中strides复制的错误,将m.shape改为m.strides ``` --- matrix.go | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/matrix.go b/matrix.go index 207c32f..3c82841 100644 --- a/matrix.go +++ b/matrix.go @@ -5,15 +5,18 @@ import ( "fmt" ) +// Matrix 表示一个多维矩阵 type Matrix struct { - data []float64 - shape []int - strides []int - mdim int - size int + data []float64 // 存储矩阵的实际数值 + shape []int // 表示矩阵的形状,例如[2, 3]表示2行3列的矩阵 + strides []int // 步长,用于多维数组访问 + mdim int // 矩阵的维度数 + size int // 矩阵元素总数 } // NewMatrix 创建一个新的矩阵 +// data: 矩阵数据 +// shape: 矩阵形状,例如[2, 3]表示2行3列 func NewMatrix(data []float64, shape []int) (*Matrix, error) { if len(data) == 0 || len(shape) == 0 { return nil, errors.New("data and shape cannot be empty") @@ -50,6 +53,7 @@ func NewMatrix(data []float64, shape []int) (*Matrix, error) { } // NewZeros 创建一个全零矩阵 +// shape: 矩阵形状 func NewZeros(shape []int) (*Matrix, error) { expectedSize := 1 for _, dim := range shape { @@ -64,6 +68,7 @@ func NewZeros(shape []int) (*Matrix, error) { } // NewOnes 创建一个全一矩阵 +// shape: 矩阵形状 func NewOnes(shape []int) (*Matrix, error) { expectedSize := 1 for _, dim := range shape { @@ -81,6 +86,7 @@ func NewOnes(shape []int) (*Matrix, error) { } // NewIdentity 创建一个单位矩阵 +// size: 单位矩阵的大小 func NewIdentity(size int) (*Matrix, error) { if size <= 0 { return nil, errors.New("size must be positive") @@ -95,6 +101,7 @@ func NewIdentity(size int) (*Matrix, error) { } // Get 获取指定位置的值 +// indices: 位置索引 func (m *Matrix) Get(indices ...int) (float64, error) { if len(indices) != m.mdim { return 0, errors.New("number of indices must match matrix dimensions") @@ -115,6 +122,8 @@ func (m *Matrix) Get(indices ...int) (float64, error) { } // Set 设置指定位置的值 +// value: 要设置的值 +// indices: 位置索引 func (m *Matrix) Set(value float64, indices ...int) error { if len(indices) != m.mdim { return errors.New("number of indices must match matrix dimensions") @@ -136,6 +145,7 @@ func (m *Matrix) Set(value float64, indices ...int) error { } // Add 矩阵加法 +// other: 要相加的另一个矩阵 func (m *Matrix) Add(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for addition") @@ -150,6 +160,7 @@ func (m *Matrix) Add(other *Matrix) (*Matrix, error) { } // Subtract 矩阵减法 +// other: 要相减的另一个矩阵 func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for subtraction") @@ -164,6 +175,7 @@ func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) { } // Multiply 矩阵乘法(逐元素相乘) +// other: 要相乘的另一个矩阵 func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for element-wise multiplication") @@ -178,6 +190,7 @@ func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) { } // MatMul 矩阵点乘(线性代数乘法) +// other: 要相乘的另一个矩阵 func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) { if m.mdim != 2 || other.mdim != 2 { return nil, errors.New("matmul only supported for 2D matrices") @@ -207,6 +220,7 @@ func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) { } // Scale 矩阵数乘 +// factor: 缩放因子 func (m *Matrix) Scale(factor float64) *Matrix { resultData := make([]float64, m.size) for i := 0; i < m.size; i++ { @@ -249,7 +263,7 @@ func (m *Matrix) Copy() *Matrix { copy(shapeCopy, m.shape) stridesCopy := make([]int, m.mdim) - copy(stridesCopy, m.strides) + copy(stridesCopy, m.shape) return &Matrix{ data: dataCopy,