```
feat(matrix): 添加矩阵结构体字段注释和函数参数说明 - 为Matrix结构体添加中文注释,解释data、shape、strides、mdim、size字段含义 - 为NewMatrix、NewZeros、NewOnes、NewIdentity等构造函数添加参数说明注释 - 为Get、Set、Add、Subtract、Multiply、MatMul、Scale等方法添加参数说明注释 - 修复Copy方法中strides复制的错误,将m.shape改为m.strides ```
This commit is contained in:
parent
b3b9017dd9
commit
2ff4dcfb0f
26
matrix.go
26
matrix.go
|
|
@ -5,15 +5,18 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
// Matrix 表示一个多维矩阵
|
||||
type Matrix struct {
|
||||
data []float64
|
||||
shape []int
|
||||
strides []int
|
||||
mdim int
|
||||
size int
|
||||
data []float64 // 存储矩阵的实际数值
|
||||
shape []int // 表示矩阵的形状,例如[2, 3]表示2行3列的矩阵
|
||||
strides []int // 步长,用于多维数组访问
|
||||
mdim int // 矩阵的维度数
|
||||
size int // 矩阵元素总数
|
||||
}
|
||||
|
||||
// NewMatrix 创建一个新的矩阵
|
||||
// data: 矩阵数据
|
||||
// shape: 矩阵形状,例如[2, 3]表示2行3列
|
||||
func NewMatrix(data []float64, shape []int) (*Matrix, error) {
|
||||
if len(data) == 0 || len(shape) == 0 {
|
||||
return nil, errors.New("data and shape cannot be empty")
|
||||
|
|
@ -50,6 +53,7 @@ func NewMatrix(data []float64, shape []int) (*Matrix, error) {
|
|||
}
|
||||
|
||||
// NewZeros 创建一个全零矩阵
|
||||
// shape: 矩阵形状
|
||||
func NewZeros(shape []int) (*Matrix, error) {
|
||||
expectedSize := 1
|
||||
for _, dim := range shape {
|
||||
|
|
@ -64,6 +68,7 @@ func NewZeros(shape []int) (*Matrix, error) {
|
|||
}
|
||||
|
||||
// NewOnes 创建一个全一矩阵
|
||||
// shape: 矩阵形状
|
||||
func NewOnes(shape []int) (*Matrix, error) {
|
||||
expectedSize := 1
|
||||
for _, dim := range shape {
|
||||
|
|
@ -81,6 +86,7 @@ func NewOnes(shape []int) (*Matrix, error) {
|
|||
}
|
||||
|
||||
// NewIdentity 创建一个单位矩阵
|
||||
// size: 单位矩阵的大小
|
||||
func NewIdentity(size int) (*Matrix, error) {
|
||||
if size <= 0 {
|
||||
return nil, errors.New("size must be positive")
|
||||
|
|
@ -95,6 +101,7 @@ func NewIdentity(size int) (*Matrix, error) {
|
|||
}
|
||||
|
||||
// Get 获取指定位置的值
|
||||
// indices: 位置索引
|
||||
func (m *Matrix) Get(indices ...int) (float64, error) {
|
||||
if len(indices) != m.mdim {
|
||||
return 0, errors.New("number of indices must match matrix dimensions")
|
||||
|
|
@ -115,6 +122,8 @@ func (m *Matrix) Get(indices ...int) (float64, error) {
|
|||
}
|
||||
|
||||
// Set 设置指定位置的值
|
||||
// value: 要设置的值
|
||||
// indices: 位置索引
|
||||
func (m *Matrix) Set(value float64, indices ...int) error {
|
||||
if len(indices) != m.mdim {
|
||||
return errors.New("number of indices must match matrix dimensions")
|
||||
|
|
@ -136,6 +145,7 @@ func (m *Matrix) Set(value float64, indices ...int) error {
|
|||
}
|
||||
|
||||
// Add 矩阵加法
|
||||
// other: 要相加的另一个矩阵
|
||||
func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
|
||||
if !m.hasSameShape(other) {
|
||||
return nil, errors.New("matrices must have the same shape for addition")
|
||||
|
|
@ -150,6 +160,7 @@ func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
|
|||
}
|
||||
|
||||
// Subtract 矩阵减法
|
||||
// other: 要相减的另一个矩阵
|
||||
func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
|
||||
if !m.hasSameShape(other) {
|
||||
return nil, errors.New("matrices must have the same shape for subtraction")
|
||||
|
|
@ -164,6 +175,7 @@ func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
|
|||
}
|
||||
|
||||
// Multiply 矩阵乘法(逐元素相乘)
|
||||
// other: 要相乘的另一个矩阵
|
||||
func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
|
||||
if !m.hasSameShape(other) {
|
||||
return nil, errors.New("matrices must have the same shape for element-wise multiplication")
|
||||
|
|
@ -178,6 +190,7 @@ func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
|
|||
}
|
||||
|
||||
// MatMul 矩阵点乘(线性代数乘法)
|
||||
// other: 要相乘的另一个矩阵
|
||||
func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
|
||||
if m.mdim != 2 || other.mdim != 2 {
|
||||
return nil, errors.New("matmul only supported for 2D matrices")
|
||||
|
|
@ -207,6 +220,7 @@ func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
|
|||
}
|
||||
|
||||
// Scale 矩阵数乘
|
||||
// factor: 缩放因子
|
||||
func (m *Matrix) Scale(factor float64) *Matrix {
|
||||
resultData := make([]float64, m.size)
|
||||
for i := 0; i < m.size; i++ {
|
||||
|
|
@ -249,7 +263,7 @@ func (m *Matrix) Copy() *Matrix {
|
|||
copy(shapeCopy, m.shape)
|
||||
|
||||
stridesCopy := make([]int, m.mdim)
|
||||
copy(stridesCopy, m.strides)
|
||||
copy(stridesCopy, m.shape)
|
||||
|
||||
return &Matrix{
|
||||
data: dataCopy,
|
||||
|
|
|
|||
Loading…
Reference in New Issue