package gomatrix import ( "errors" "fmt" ) type Matrix struct { data []float64 shape []int strides []int mdim int size int } // NewMatrix 创建一个新的矩阵 func NewMatrix(data []float64, shape []int) (*Matrix, error) { if len(data) == 0 || len(shape) == 0 { return nil, errors.New("data and shape cannot be empty") } // 计算期望的大小 expectedSize := 1 for _, dim := range shape { if dim <= 0 { return nil, errors.New("all dimensions must be positive") } expectedSize *= dim } if len(data) != expectedSize { return nil, fmt.Errorf("data length %d does not match expected size %d for given shape", len(data), expectedSize) } // 计算步长 strides := make([]int, len(shape)) stride := 1 for i := len(shape) - 1; i >= 0; i-- { strides[i] = stride stride *= shape[i] } return &Matrix{ data: data, shape: shape, strides: strides, mdim: len(shape), size: expectedSize, }, nil } // NewZeros 创建一个全零矩阵 func NewZeros(shape []int) (*Matrix, error) { expectedSize := 1 for _, dim := range shape { if dim <= 0 { return nil, errors.New("all dimensions must be positive") } expectedSize *= dim } data := make([]float64, expectedSize) return NewMatrix(data, shape) } // NewOnes 创建一个全一矩阵 func NewOnes(shape []int) (*Matrix, error) { expectedSize := 1 for _, dim := range shape { if dim <= 0 { return nil, errors.New("all dimensions must be positive") } expectedSize *= dim } data := make([]float64, expectedSize) for i := range data { data[i] = 1.0 } return NewMatrix(data, shape) } // NewIdentity 创建一个单位矩阵 func NewIdentity(size int) (*Matrix, error) { if size <= 0 { return nil, errors.New("size must be positive") } data := make([]float64, size*size) for i := 0; i < size; i++ { data[i*size+i] = 1.0 } return NewMatrix(data, []int{size, size}) } // Get 获取指定位置的值 func (m *Matrix) Get(indices ...int) (float64, error) { if len(indices) != m.mdim { return 0, errors.New("number of indices must match matrix dimensions") } for i, idx := range indices { if idx < 0 || idx >= m.shape[i] { return 0, errors.New("index out of bounds") } } pos := 0 for i, idx := range indices { pos += idx * m.strides[i] } return m.data[pos], nil } // Set 设置指定位置的值 func (m *Matrix) Set(value float64, indices ...int) error { if len(indices) != m.mdim { return errors.New("number of indices must match matrix dimensions") } for i, idx := range indices { if idx < 0 || idx >= m.shape[i] { return errors.New("index out of bounds") } } pos := 0 for i, idx := range indices { pos += idx * m.strides[i] } m.data[pos] = value return nil } // Add 矩阵加法 func (m *Matrix) Add(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for addition") } resultData := make([]float64, m.size) for i := 0; i < m.size; i++ { resultData[i] = m.data[i] + other.data[i] } return NewMatrix(resultData, m.shape) } // Subtract 矩阵减法 func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for subtraction") } resultData := make([]float64, m.size) for i := 0; i < m.size; i++ { resultData[i] = m.data[i] - other.data[i] } return NewMatrix(resultData, m.shape) } // Multiply 矩阵乘法(逐元素相乘) func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) { if !m.hasSameShape(other) { return nil, errors.New("matrices must have the same shape for element-wise multiplication") } resultData := make([]float64, m.size) for i := 0; i < m.size; i++ { resultData[i] = m.data[i] * other.data[i] } return NewMatrix(resultData, m.shape) } // MatMul 矩阵点乘(线性代数乘法) func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) { if m.mdim != 2 || other.mdim != 2 { return nil, errors.New("matmul only supported for 2D matrices") } if m.shape[1] != other.shape[0] { return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]", m.shape[0], m.shape[1], other.shape[0], other.shape[1]) } rows, cols, inner := m.shape[0], other.shape[1], m.shape[1] resultData := make([]float64, rows*cols) for i := 0; i < rows; i++ { for j := 0; j < cols; j++ { sum := 0.0 for k := 0; k < inner; k++ { mIdx := i*m.shape[1] + k oIdx := k*other.shape[1] + j sum += m.data[mIdx] * other.data[oIdx] } resultData[i*cols+j] = sum } } return NewMatrix(resultData, []int{rows, cols}) } // Scale 矩阵数乘 func (m *Matrix) Scale(factor float64) *Matrix { resultData := make([]float64, m.size) for i := 0; i < m.size; i++ { resultData[i] = m.data[i] * factor } return &Matrix{ data: resultData, shape: m.shape, strides: m.strides, mdim: m.mdim, size: m.size, } } // Transpose 矩阵转置(仅支持2维矩阵) func (m *Matrix) Transpose() (*Matrix, error) { if m.mdim != 2 { return nil, errors.New("transpose only supported for 2D matrices") } rows, cols := m.shape[0], m.shape[1] resultData := make([]float64, m.size) for i := 0; i < rows; i++ { for j := 0; j < cols; j++ { resultData[j*rows+i] = m.data[i*cols+j] } } return NewMatrix(resultData, []int{cols, rows}) } // Copy 复制矩阵 func (m *Matrix) Copy() *Matrix { dataCopy := make([]float64, m.size) copy(dataCopy, m.data) shapeCopy := make([]int, m.mdim) copy(shapeCopy, m.shape) stridesCopy := make([]int, m.mdim) copy(stridesCopy, m.strides) return &Matrix{ data: dataCopy, shape: shapeCopy, strides: stridesCopy, mdim: m.mdim, size: m.size, } } // String 实现Stringer接口,用于打印矩阵 func (m *Matrix) String() string { if m.mdim == 2 { result := "[\n" rows, cols := m.shape[0], m.shape[1] for i := 0; i < rows; i++ { result += " [" for j := 0; j < cols; j++ { idx := i*cols + j result += fmt.Sprintf("%.2f", m.data[idx]) if j < cols-1 { result += ", " } } result += "]\n" } result += "]" return result } // 简单处理其他维度的情况 return fmt.Sprintf("Matrix{data:%v, shape:%v}", m.data, m.shape) } // Shape 返回矩阵的形状 func (m *Matrix) Shape() []int { shape := make([]int, m.mdim) copy(shape, m.shape) return shape } // Size 返回矩阵的大小 func (m *Matrix) Size() int { return m.size } // hasSameShape 检查两个矩阵是否具有相同的形状 func (m *Matrix) hasSameShape(other *Matrix) bool { if m.mdim != other.mdim { return false } for i := 0; i < m.mdim; i++ { if m.shape[i] != other.shape[i] { return false } } return true }