gomatrix/matrix.go

312 lines
6.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gomatrix
import (
	"errors"
	"fmt"
	"strings"
)
// Matrix is a dense n-dimensional array of float64 values kept in a single
// flat slice. Constructors in this package compute row-major strides, so the
// last dimension is contiguous in data.
type Matrix struct {
	// data holds every element; element (i0, i1, ...) lives at
	// sum(ik * strides[k]).
	data []float64
	// shape[k] is the extent of dimension k; strides[k] is how many slots of
	// data to skip to advance one step along dimension k.
	shape, strides []int
	// mdim is the number of dimensions (len(shape)); size is the total
	// element count (the product of the entries of shape).
	mdim, size int
}
// NewMatrix 创建一个新的矩阵
func NewMatrix(data []float64, shape []int) (*Matrix, error) {
if len(data) == 0 || len(shape) == 0 {
return nil, errors.New("data and shape cannot be empty")
}
// 计算期望的大小
expectedSize := 1
for _, dim := range shape {
if dim <= 0 {
return nil, errors.New("all dimensions must be positive")
}
expectedSize *= dim
}
if len(data) != expectedSize {
return nil, fmt.Errorf("data length %d does not match expected size %d for given shape", len(data), expectedSize)
}
// 计算步长
strides := make([]int, len(shape))
stride := 1
for i := len(shape) - 1; i >= 0; i-- {
strides[i] = stride
stride *= shape[i]
}
return &Matrix{
data: data,
shape: shape,
strides: strides,
mdim: len(shape),
size: expectedSize,
}, nil
}
// NewZeros 创建一个全零矩阵
func NewZeros(shape []int) (*Matrix, error) {
expectedSize := 1
for _, dim := range shape {
if dim <= 0 {
return nil, errors.New("all dimensions must be positive")
}
expectedSize *= dim
}
data := make([]float64, expectedSize)
return NewMatrix(data, shape)
}
// NewOnes 创建一个全一矩阵
func NewOnes(shape []int) (*Matrix, error) {
expectedSize := 1
for _, dim := range shape {
if dim <= 0 {
return nil, errors.New("all dimensions must be positive")
}
expectedSize *= dim
}
data := make([]float64, expectedSize)
for i := range data {
data[i] = 1.0
}
return NewMatrix(data, shape)
}
// NewIdentity 创建一个单位矩阵
func NewIdentity(size int) (*Matrix, error) {
if size <= 0 {
return nil, errors.New("size must be positive")
}
data := make([]float64, size*size)
for i := 0; i < size; i++ {
data[i*size+i] = 1.0
}
return NewMatrix(data, []int{size, size})
}
// Get returns the element at the given position. Exactly one index per
// dimension is required, and each index must lie in [0, extent of that
// dimension); otherwise an error is returned.
func (m *Matrix) Get(indices ...int) (float64, error) {
	if len(indices) != m.mdim {
		return 0, errors.New("number of indices must match matrix dimensions")
	}
	// Validate each index and accumulate its stride contribution in one pass.
	offset := 0
	for axis, idx := range indices {
		if idx < 0 || idx >= m.shape[axis] {
			return 0, errors.New("index out of bounds")
		}
		offset += idx * m.strides[axis]
	}
	return m.data[offset], nil
}
// Set stores value at the given position, modifying m in place. Exactly one
// index per dimension is required, and each index must lie in
// [0, extent of that dimension); otherwise an error is returned and m is
// left unchanged.
func (m *Matrix) Set(value float64, indices ...int) error {
	if len(indices) != m.mdim {
		return errors.New("number of indices must match matrix dimensions")
	}
	// Validate each index and accumulate its stride contribution in one pass;
	// nothing is written until every index has been checked.
	offset := 0
	for axis, idx := range indices {
		if idx < 0 || idx >= m.shape[axis] {
			return errors.New("index out of bounds")
		}
		offset += idx * m.strides[axis]
	}
	m.data[offset] = value
	return nil
}
// Add returns a new matrix holding the element-wise sum m + other.
// Both matrices must have the same shape; otherwise an error is returned.
func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
	return m.elementwiseOp(other, "addition", func(a, b float64) float64 { return a + b })
}

// Subtract returns a new matrix holding the element-wise difference
// m - other. Both matrices must have the same shape; otherwise an error is
// returned.
func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
	return m.elementwiseOp(other, "subtraction", func(a, b float64) float64 { return a - b })
}

// Multiply returns a new matrix holding the element-wise (Hadamard) product
// of m and other. Both matrices must have the same shape; otherwise an error
// is returned. For the linear-algebra matrix product, use MatMul.
func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
	return m.elementwiseOp(other, "element-wise multiplication", func(a, b float64) float64 { return a * b })
}

// elementwiseOp applies op to each pair of corresponding elements of m and
// other and returns the results as a new matrix with m's shape. opName is
// used only to build the shape-mismatch error message.
func (m *Matrix) elementwiseOp(other *Matrix, opName string, op func(a, b float64) float64) (*Matrix, error) {
	// Reject nil explicitly rather than panicking inside hasSameShape.
	if other == nil {
		return nil, errors.New("other matrix must not be nil")
	}
	if !m.hasSameShape(other) {
		return nil, errors.New("matrices must have the same shape for " + opName)
	}
	out := make([]float64, m.size)
	for i := range out {
		out[i] = op(m.data[i], other.data[i])
	}
	return NewMatrix(out, m.shape)
}
// MatMul returns the linear-algebra matrix product m × other as a new
// matrix. Both operands must be 2-D, and the number of columns of m must
// equal the number of rows of other. The result has shape
// [rows of m, columns of other].
func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
	if other == nil {
		return nil, errors.New("other matrix must not be nil")
	}
	if m.mdim != 2 || other.mdim != 2 {
		return nil, errors.New("matmul only supported for 2D matrices")
	}
	if m.shape[1] != other.shape[0] {
		return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]",
			m.shape[0], m.shape[1], other.shape[0], other.shape[1])
	}
	rows, inner, cols := m.shape[0], m.shape[1], other.shape[1]
	out := make([]float64, rows*cols)
	// i-k-j loop order: the innermost loop walks a row of other and a row of
	// out contiguously, instead of striding down a column of other as the
	// naive i-j-k order does. Each out[i][j] still starts at 0 and receives
	// its terms in increasing k, the same summation order as the naive loop.
	for i := 0; i < rows; i++ {
		outRow := out[i*cols : (i+1)*cols]
		aRow := m.data[i*inner : (i+1)*inner]
		for k, a := range aRow {
			bRow := other.data[k*cols : (k+1)*cols]
			for j, b := range bRow {
				outRow[j] += a * b
			}
		}
	}
	return NewMatrix(out, []int{rows, cols})
}
// Scale returns a new matrix whose elements are those of m multiplied by
// factor; m itself is not modified. The result gets freshly allocated data
// but shares m's shape and strides slices.
func (m *Matrix) Scale(factor float64) *Matrix {
	scaled := &Matrix{
		data:    make([]float64, m.size),
		shape:   m.shape,
		strides: m.strides,
		mdim:    m.mdim,
		size:    m.size,
	}
	for i := range scaled.data {
		scaled.data[i] = m.data[i] * factor
	}
	return scaled
}
// Transpose returns a new matrix that is the transpose of m, which must be
// 2-D. Element (i, j) of m becomes element (j, i) of the result, and the
// result's shape is [columns of m, rows of m].
func (m *Matrix) Transpose() (*Matrix, error) {
	if m.mdim != 2 {
		return nil, errors.New("transpose only supported for 2D matrices")
	}
	r, c := m.shape[0], m.shape[1]
	t := make([]float64, m.size)
	// Fill the destination row by row: destination row j gathers column j
	// of the source.
	for j := 0; j < c; j++ {
		dst := t[j*r : (j+1)*r]
		for i := range dst {
			dst[i] = m.data[i*c+j]
		}
	}
	return NewMatrix(t, []int{c, r})
}
// Copy returns a deep copy of m: data, shape and strides are all freshly
// allocated, so modifying the copy never affects m and vice versa.
func (m *Matrix) Copy() *Matrix {
	dup := &Matrix{
		data:    make([]float64, m.size),
		shape:   make([]int, m.mdim),
		strides: make([]int, m.mdim),
		mdim:    m.mdim,
		size:    m.size,
	}
	copy(dup.data, m.data)
	copy(dup.shape, m.shape)
	copy(dup.strides, m.strides)
	return dup
}
// String 实现Stringer接口用于打印矩阵
func (m *Matrix) String() string {
if m.mdim == 2 {
result := "[\n"
rows, cols := m.shape[0], m.shape[1]
for i := 0; i < rows; i++ {
result += " ["
for j := 0; j < cols; j++ {
idx := i*cols + j
result += fmt.Sprintf("%.2f", m.data[idx])
if j < cols-1 {
result += ", "
}
}
result += "]\n"
}
result += "]"
return result
}
// 简单处理其他维度的情况
return fmt.Sprintf("Matrix{data:%v, shape:%v}", m.data, m.shape)
}
// Shape returns the extent of each dimension. The returned slice is a fresh
// copy, so callers may modify it without affecting the matrix.
func (m *Matrix) Shape() []int {
	dims := make([]int, m.mdim)
	copy(dims, m.shape)
	return dims
}
// Size reports the total number of elements in the matrix, i.e. the
// product of its dimensions.
func (m *Matrix) Size() int {
	return m.size
}
// hasSameShape reports whether m and other have the same number of
// dimensions and the same extent along every dimension.
func (m *Matrix) hasSameShape(other *Matrix) bool {
	same := m.mdim == other.mdim
	for axis := 0; same && axis < m.mdim; axis++ {
		same = m.shape[axis] == other.shape[axis]
	}
	return same
}