// Package gomatrix provides a small dense n-dimensional float64 matrix type
// stored in row-major order, with element-wise arithmetic and basic 2-D
// linear-algebra helpers.
package gomatrix

import (
	"errors"
	"fmt"
	"strings"
)

// Matrix is a dense n-dimensional array of float64 values stored in
// row-major order (the last dimension is contiguous in data).
type Matrix struct {
	data    []float64 // element values in row-major order
	shape   []int     // length of each dimension, e.g. [2, 3] is 2 rows by 3 columns
	strides []int     // per-dimension step (in elements) used to map indices to an offset in data
	mdim    int       // number of dimensions, len(shape)
	size    int       // total number of elements, product of shape
}

// shapeSize validates shape and returns the number of elements it describes.
// Every dimension must be strictly positive; an empty shape yields 1.
func shapeSize(shape []int) (int, error) {
	size := 1
	for _, dim := range shape {
		if dim <= 0 {
			return 0, errors.New("all dimensions must be positive")
		}
		size *= dim
	}
	return size, nil
}

// NewMatrix creates a matrix from data laid out in row-major order with the
// given shape (e.g. [2, 3] for 2 rows and 3 columns).
//
// The data slice is NOT copied: the returned matrix aliases it, so writes via
// either the slice or Set are visible through both. The shape slice IS copied,
// so the caller may reuse or modify it afterwards.
//
// It returns an error if data or shape is empty, if any dimension is not
// positive, or if len(data) differs from the product of the dimensions.
func NewMatrix(data []float64, shape []int) (*Matrix, error) {
	if len(data) == 0 || len(shape) == 0 {
		return nil, errors.New("data and shape cannot be empty")
	}
	expectedSize, err := shapeSize(shape)
	if err != nil {
		return nil, err
	}
	if len(data) != expectedSize {
		return nil, fmt.Errorf("data length %d does not match expected size %d for given shape", len(data), expectedSize)
	}

	// Own a private copy of shape so a caller mutating its slice cannot
	// desynchronise it from the strides and size computed below.
	shapeCopy := append([]int(nil), shape...)

	// Row-major strides: the last dimension has stride 1.
	strides := make([]int, len(shapeCopy))
	stride := 1
	for i := len(shapeCopy) - 1; i >= 0; i-- {
		strides[i] = stride
		stride *= shapeCopy[i]
	}

	return &Matrix{
		data:    data,
		shape:   shapeCopy,
		strides: strides,
		mdim:    len(shapeCopy),
		size:    expectedSize,
	}, nil
}

// NewZeros creates a matrix of the given shape with every element 0.
// It returns an error if shape is empty or any dimension is not positive.
func NewZeros(shape []int) (*Matrix, error) {
	size, err := shapeSize(shape)
	if err != nil {
		return nil, err
	}
	return NewMatrix(make([]float64, size), shape)
}

// NewOnes creates a matrix of the given shape with every element 1.
// It returns an error if shape is empty or any dimension is not positive.
func NewOnes(shape []int) (*Matrix, error) {
	size, err := shapeSize(shape)
	if err != nil {
		return nil, err
	}
	data := make([]float64, size)
	for i := range data {
		data[i] = 1.0
	}
	return NewMatrix(data, shape)
}

// NewVector creates a column vector of shape [len(data), 1]. Like NewMatrix,
// it aliases data rather than copying it.
func NewVector(data []float64) (*Matrix, error) {
	return NewMatrix(data, []int{len(data), 1})
}

// NewIdentity creates a size×size identity matrix.
// It returns an error if size is not positive.
func NewIdentity(size int) (*Matrix, error) {
	if size <= 0 {
		return nil, errors.New("size must be positive")
	}
	data := make([]float64, size*size)
	for i := 0; i < size; i++ {
		data[i*size+i] = 1.0
	}
	return NewMatrix(data, []int{size, size})
}

// offset validates indices against the matrix shape and converts them to a
// position in m.data using the strides.
func (m *Matrix) offset(indices []int) (int, error) {
	if len(indices) != m.mdim {
		return 0, errors.New("number of indices must match matrix dimensions")
	}
	pos := 0
	for i, idx := range indices {
		if idx < 0 || idx >= m.shape[i] {
			return 0, errors.New("index out of bounds")
		}
		pos += idx * m.strides[i]
	}
	return pos, nil
}

// Get returns the element at the given indices, one per dimension.
// It returns an error if the number of indices does not match the matrix
// rank or any index is out of range.
func (m *Matrix) Get(indices ...int) (float64, error) {
	pos, err := m.offset(indices)
	if err != nil {
		return 0, err
	}
	return m.data[pos], nil
}

// Set stores value at the given indices, one per dimension.
// It returns an error if the number of indices does not match the matrix
// rank or any index is out of range.
func (m *Matrix) Set(value float64, indices ...int) error {
	pos, err := m.offset(indices)
	if err != nil {
		return err
	}
	m.data[pos] = value
	return nil
}

// Add returns the element-wise sum m + other as a new matrix.
// It returns an error if other is nil or the shapes differ.
func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
	if !m.hasSameShape(other) {
		return nil, errors.New("matrices must have the same shape for addition")
	}
	resultData := make([]float64, m.size)
	for i := 0; i < m.size; i++ {
		resultData[i] = m.data[i] + other.data[i]
	}
	return NewMatrix(resultData, m.shape)
}

// Sum returns the sum of all elements of m.
func (m *Matrix) Sum() float64 {
	var sum float64
	for i := 0; i < m.size; i++ {
		sum += m.data[i]
	}
	return sum
}

// Subtract returns the element-wise difference m - other as a new matrix.
// It returns an error if other is nil or the shapes differ.
func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
	if !m.hasSameShape(other) {
		return nil, errors.New("matrices must have the same shape for subtraction")
	}
	resultData := make([]float64, m.size)
	for i := 0; i < m.size; i++ {
		resultData[i] = m.data[i] - other.data[i]
	}
	return NewMatrix(resultData, m.shape)
}

// Multiply returns the element-wise (Hadamard) product of m and other as a
// new matrix. It returns an error if other is nil or the shapes differ.
// For the linear-algebra product see MatMul.
func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
	if !m.hasSameShape(other) {
		return nil, errors.New("matrices must have the same shape for element-wise multiplication")
	}
	resultData := make([]float64, m.size)
	for i := 0; i < m.size; i++ {
		resultData[i] = m.data[i] * other.data[i]
	}
	return NewMatrix(resultData, m.shape)
}

// MatMul returns the matrix product m × other. Both operands must be 2-D and
// the column count of m must equal the row count of other; the result has
// shape [rows(m), cols(other)].
func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
	if m.mdim != 2 || other.mdim != 2 {
		return nil, errors.New("matmul only supported for 2D matrices")
	}
	if m.shape[1] != other.shape[0] {
		return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]",
			m.shape[0], m.shape[1], other.shape[0], other.shape[1])
	}
	rows, cols, inner := m.shape[0], other.shape[1], m.shape[1]
	resultData := make([]float64, rows*cols)
	for i := 0; i < rows; i++ {
		for j := 0; j < cols; j++ {
			sum := 0.0
			for k := 0; k < inner; k++ {
				sum += m.data[i*inner+k] * other.data[k*cols+j]
			}
			resultData[i*cols+j] = sum
		}
	}
	return NewMatrix(resultData, []int{rows, cols})
}

// Scale returns a new matrix with every element of m multiplied by factor.
// The result gets fresh data but shares m's shape and strides slices, which
// no method in this file mutates.
func (m *Matrix) Scale(factor float64) *Matrix {
	resultData := make([]float64, m.size)
	for i := 0; i < m.size; i++ {
		resultData[i] = m.data[i] * factor
	}
	return &Matrix{
		data:    resultData,
		shape:   m.shape,
		strides: m.strides,
		mdim:    m.mdim,
		size:    m.size,
	}
}

// Equal reports whether m and other have identical shapes and exactly equal
// elements. Elements are compared with ==, so a matrix containing NaN is
// never equal to anything. A nil other is never equal.
func (m *Matrix) Equal(other *Matrix) bool {
	// Compare the full shape (rank and every dimension), not just the first
	// two dimensions, so 1-D and higher-rank matrices are handled safely.
	if !m.hasSameShape(other) {
		return false
	}
	for i := 0; i < m.size; i++ {
		if m.data[i] != other.data[i] {
			return false
		}
	}
	return true
}

// Transpose returns the transpose of a 2-D matrix as a new matrix.
// It returns an error for matrices of any other rank.
func (m *Matrix) Transpose() (*Matrix, error) {
	if m.mdim != 2 {
		return nil, errors.New("transpose only supported for 2D matrices")
	}
	rows, cols := m.shape[0], m.shape[1]
	resultData := make([]float64, m.size)
	for i := 0; i < rows; i++ {
		for j := 0; j < cols; j++ {
			resultData[j*rows+i] = m.data[i*cols+j]
		}
	}
	return NewMatrix(resultData, []int{cols, rows})
}

// Copy returns a deep copy of m: data, shape and strides are all duplicated,
// so the copy can be modified without affecting m.
func (m *Matrix) Copy() *Matrix {
	dataCopy := make([]float64, m.size)
	copy(dataCopy, m.data)
	shapeCopy := make([]int, m.mdim)
	copy(shapeCopy, m.shape)
	stridesCopy := make([]int, m.mdim)
	copy(stridesCopy, m.strides)
	return &Matrix{
		data:    dataCopy,
		shape:   shapeCopy,
		strides: stridesCopy,
		mdim:    m.mdim,
		size:    m.size,
	}
}

// String implements fmt.Stringer. 2-D matrices are rendered one row per line
// with two decimal places; other ranks fall back to a compact field dump.
func (m *Matrix) String() string {
	if m.mdim != 2 {
		return fmt.Sprintf("Matrix{data:%v, shape:%v}", m.data, m.shape)
	}
	rows, cols := m.shape[0], m.shape[1]
	var b strings.Builder
	b.WriteString("[\n")
	for i := 0; i < rows; i++ {
		b.WriteString(" [")
		for j := 0; j < cols; j++ {
			if j > 0 {
				b.WriteString(", ")
			}
			fmt.Fprintf(&b, "%.2f", m.data[i*cols+j])
		}
		b.WriteString("]\n")
	}
	b.WriteString("]")
	return b.String()
}

// Shape returns a copy of the matrix dimensions; modifying the result does
// not affect m.
func (m *Matrix) Shape() []int {
	shape := make([]int, m.mdim)
	copy(shape, m.shape)
	return shape
}

// Size returns the total number of elements in m.
func (m *Matrix) Size() int {
	return m.size
}

// hasSameShape reports whether other is non-nil and has exactly the same
// rank and dimensions as m.
func (m *Matrix) hasSameShape(other *Matrix) bool {
	if other == nil || m.mdim != other.mdim {
		return false
	}
	for i := 0; i < m.mdim; i++ {
		if m.shape[i] != other.shape[i] {
			return false
		}
	}
	return true
}