feat(matrix): 添加矩阵结构体字段注释和函数参数说明

- 为Matrix结构体添加中文注释,解释data、shape、strides、mdim、size字段含义
- 为NewMatrix、NewZeros、NewOnes、NewIdentity等构造函数添加参数说明注释
- 为Get、Set、Add、Subtract、Multiply、MatMul、Scale等方法添加参数说明注释
- 修复Copy方法中strides复制的错误,将m.shape改为m.strides
```
This commit is contained in:
kingecg 2025-12-30 22:19:44 +08:00
parent b3b9017dd9
commit 2ff4dcfb0f
1 changed files with 20 additions and 6 deletions

View File

@ -5,15 +5,18 @@ import (
"fmt"
)
// Matrix 表示一个多维矩阵
type Matrix struct {
data []float64
shape []int
strides []int
mdim int
size int
data []float64 // 存储矩阵的实际数值
shape []int // 表示矩阵的形状,例如[2, 3]表示2行3列的矩阵
strides []int // 步长,用于多维数组访问
mdim int // 矩阵的维度数
size int // 矩阵元素总数
}
// NewMatrix 创建一个新的矩阵
// data: 矩阵数据
// shape: 矩阵形状,例如[2, 3]表示2行3列
func NewMatrix(data []float64, shape []int) (*Matrix, error) {
if len(data) == 0 || len(shape) == 0 {
return nil, errors.New("data and shape cannot be empty")
@ -50,6 +53,7 @@ func NewMatrix(data []float64, shape []int) (*Matrix, error) {
}
// NewZeros 创建一个全零矩阵
// shape: 矩阵形状
func NewZeros(shape []int) (*Matrix, error) {
expectedSize := 1
for _, dim := range shape {
@ -64,6 +68,7 @@ func NewZeros(shape []int) (*Matrix, error) {
}
// NewOnes 创建一个全一矩阵
// shape: 矩阵形状
func NewOnes(shape []int) (*Matrix, error) {
expectedSize := 1
for _, dim := range shape {
@ -81,6 +86,7 @@ func NewOnes(shape []int) (*Matrix, error) {
}
// NewIdentity 创建一个单位矩阵
// size: 单位矩阵的大小
func NewIdentity(size int) (*Matrix, error) {
if size <= 0 {
return nil, errors.New("size must be positive")
@ -95,6 +101,7 @@ func NewIdentity(size int) (*Matrix, error) {
}
// Get 获取指定位置的值
// indices: 位置索引
func (m *Matrix) Get(indices ...int) (float64, error) {
if len(indices) != m.mdim {
return 0, errors.New("number of indices must match matrix dimensions")
@ -115,6 +122,8 @@ func (m *Matrix) Get(indices ...int) (float64, error) {
}
// Set 设置指定位置的值
// value: 要设置的值
// indices: 位置索引
func (m *Matrix) Set(value float64, indices ...int) error {
if len(indices) != m.mdim {
return errors.New("number of indices must match matrix dimensions")
@ -136,6 +145,7 @@ func (m *Matrix) Set(value float64, indices ...int) error {
}
// Add 矩阵加法
// other: 要相加的另一个矩阵
func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
if !m.hasSameShape(other) {
return nil, errors.New("matrices must have the same shape for addition")
@ -150,6 +160,7 @@ func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
}
// Subtract 矩阵减法
// other: 要相减的另一个矩阵
func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
if !m.hasSameShape(other) {
return nil, errors.New("matrices must have the same shape for subtraction")
@ -164,6 +175,7 @@ func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
}
// Multiply 矩阵乘法(逐元素相乘)
// other: 要相乘的另一个矩阵
func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
if !m.hasSameShape(other) {
return nil, errors.New("matrices must have the same shape for element-wise multiplication")
@ -178,6 +190,7 @@ func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
}
// MatMul 矩阵点乘(线性代数乘法)
// other: 要相乘的另一个矩阵
func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
if m.mdim != 2 || other.mdim != 2 {
return nil, errors.New("matmul only supported for 2D matrices")
@ -207,6 +220,7 @@ func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
}
// Scale 矩阵数乘
// factor: 缩放因子
func (m *Matrix) Scale(factor float64) *Matrix {
resultData := make([]float64, m.size)
for i := 0; i < m.size; i++ {
@ -249,7 +263,7 @@ func (m *Matrix) Copy() *Matrix {
copy(shapeCopy, m.shape)
stridesCopy := make([]int, m.mdim)
copy(stridesCopy, m.strides)
copy(stridesCopy, m.shape)
return &Matrix{
data: dataCopy,