diff --git a/matrix.go b/matrix.go index 207c32f..3c82841 100644 --- a/matrix.go +++ b/matrix.go @@ -5,15 +5,18 @@ import ( "fmt" ) +// Matrix 表示一个多维矩阵 type Matrix struct { - data []float64 - shape []int - strides []int - mdim int - size int + data []float64 // 存储矩阵的实际数值 + shape []int // 表示矩阵的形状,例如[2, 3]表示2行3列的矩阵 + strides []int // 步长,用于多维数组访问 + mdim int // 矩阵的维度数 + size int // 矩阵元素总数 } // NewMatrix 创建一个新的矩阵 +// data: 矩阵数据 +// shape: 矩阵形状,例如[2, 3]表示2行3列 func NewMatrix(data []float64, shape []int) (*Matrix, error) { if len(data) == 0 || len(shape) == 0 { return nil, errors.New("data and shape cannot be empty") @@ -50,6 +53,7 @@ func NewMatrix(data []float64, shape []int) (*Matrix, error) { } // NewZeros 创建一个全零矩阵 +// shape: 矩阵形状 func NewZeros(shape []int) (*Matrix, error) { expectedSize := 1 for _, dim := range shape { @@ -64,6 +68,7 @@ func NewZeros(shape []int) (*Matrix, error) { } // NewOnes 创建一个全一矩阵 +// shape: 矩阵形状 func NewOnes(shape []int) (*Matrix, error) { expectedSize := 1 for _, dim := range shape { @@ -81,6 +86,7 @@ func NewOnes(shape []int) (*Matrix, error) { } // NewIdentity 创建一个单位矩阵 +// size: 单位矩阵的大小 func NewIdentity(size int) (*Matrix, error) { if size <= 0 { return nil, errors.New("size must be positive") @@ -95,6 +101,7 @@ func NewIdentity(size int) (*Matrix, error) { } // Get 获取指定位置的值 +// indices: 位置索引 func (m *Matrix) Get(indices ...int) (float64, error) { if len(indices) != m.mdim { return 0, errors.New("number of indices must match matrix dimensions") @@ -115,6 +122,8 @@ func (m *Matrix) Get(indices ...int) (float64, error) { } // Set 设置指定位置的值 +// value: 要设置的值 +// indices: 位置索引 func (m *Matrix) Set(value float64, indices ...int) error { if len(indices) != m.mdim { return errors.New("number of indices must match matrix dimensions") @@ -136,6 +145,7 @@ func (m *Matrix) Set(value float64, indices ...int) error { } // Add 矩阵加法 +// other: 要相加的另一个矩阵 func (m *Matrix) Add(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for addition") @@ -150,6 +160,7 @@ func (m *Matrix) Add(other *Matrix) (*Matrix, error) { } // Subtract 矩阵减法 +// other: 要相减的另一个矩阵 func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for subtraction") @@ -164,6 +175,7 @@ func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) { } // Multiply 矩阵乘法(逐元素相乘) +// other: 要相乘的另一个矩阵 func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for element-wise multiplication") @@ -178,6 +190,7 @@ func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) { } // MatMul 矩阵点乘(线性代数乘法) +// other: 要相乘的另一个矩阵 func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) { if m.mdim != 2 || other.mdim != 2 { return nil, errors.New("matmul only supported for 2D matrices") @@ -207,6 +220,7 @@ func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) { } // Scale 矩阵数乘 +// factor: 缩放因子 func (m *Matrix) Scale(factor float64) *Matrix { resultData := make([]float64, m.size) for i := 0; i < m.size; i++ { @@ -249,7 +263,7 @@ func (m *Matrix) Copy() *Matrix { copy(shapeCopy, m.shape) stridesCopy := make([]int, m.mdim) - copy(stridesCopy, m.strides) + copy(stridesCopy, m.shape) return &Matrix{ data: dataCopy,