gomatrix/matrix.go

354 lines
7.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gomatrix
import (
	"errors"
	"fmt"
	"strings"
)
// Matrix is a dense n-dimensional array of float64 values.
type Matrix struct {
	// data holds every element; its layout is described by strides.
	data []float64
	// shape is the size of each dimension (e.g. [2, 3] is 2 rows by 3
	// columns); strides is the step in data taken when the index along the
	// corresponding dimension increases by one.
	shape, strides []int
	// mdim is the number of dimensions; size is the total element count.
	mdim, size int
}
// NewMatrix creates a matrix backed by data with the given shape.
//
// data must contain exactly the product of the dimensions in shape. It is
// stored without copying, so later writes to the caller's slice are visible
// through the matrix. shape lists the size of each dimension (e.g. []int{2, 3}
// is 2 rows by 3 columns); it is copied, so the caller may reuse it freely.
//
// An error is returned if data or shape is empty, if any dimension is not
// positive, if the element count would overflow int, or if len(data) does not
// equal the product of the dimensions.
func NewMatrix(data []float64, shape []int) (*Matrix, error) {
	const maxInt = int(^uint(0) >> 1)

	if len(data) == 0 || len(shape) == 0 {
		return nil, errors.New("data and shape cannot be empty")
	}

	expectedSize := 1
	for _, dim := range shape {
		if dim <= 0 {
			return nil, errors.New("all dimensions must be positive")
		}
		// Check before multiplying: a wrapped product could coincidentally
		// equal len(data) and let an impossible shape through.
		if expectedSize > maxInt/dim {
			return nil, errors.New("shape size overflows int")
		}
		expectedSize *= dim
	}
	if len(data) != expectedSize {
		return nil, fmt.Errorf("data length %d does not match expected size %d for given shape", len(data), expectedSize)
	}

	// Keep a private copy of shape. Methods such as Add pass m.shape straight
	// back into NewMatrix, so storing the caller's slice would make several
	// matrices share (and be able to corrupt) one shape.
	shapeCopy := make([]int, len(shape))
	copy(shapeCopy, shape)

	// Row-major strides: the last dimension is contiguous (stride 1) and each
	// earlier stride is the product of all later dimension sizes.
	strides := make([]int, len(shapeCopy))
	stride := 1
	for i := len(shapeCopy) - 1; i >= 0; i-- {
		strides[i] = stride
		stride *= shapeCopy[i]
	}

	return &Matrix{
		data:    data,
		shape:   shapeCopy,
		strides: strides,
		mdim:    len(shapeCopy),
		size:    expectedSize,
	}, nil
}
// NewZeros returns a matrix of the given shape with every element set to 0.
//
// An error is returned if shape is empty, if any dimension is not positive,
// or if the total element count would overflow int.
func NewZeros(shape []int) (*Matrix, error) {
	const maxInt = int(^uint(0) >> 1)

	size := 1
	for _, dim := range shape {
		if dim <= 0 {
			return nil, errors.New("all dimensions must be positive")
		}
		// Guard the product before multiplying: a wrapped (possibly negative)
		// size would make the make() call below panic.
		if size > maxInt/dim {
			return nil, errors.New("shape size overflows int")
		}
		size *= dim
	}
	// make zero-initialises the slice, so no fill loop is needed. NewMatrix
	// reports the empty-shape case.
	return NewMatrix(make([]float64, size), shape)
}
// NewOnes returns a matrix of the given shape with every element set to 1.
//
// An error is returned if shape is empty, if any dimension is not positive,
// or if the total element count would overflow int.
func NewOnes(shape []int) (*Matrix, error) {
	const maxInt = int(^uint(0) >> 1)

	size := 1
	for _, dim := range shape {
		if dim <= 0 {
			return nil, errors.New("all dimensions must be positive")
		}
		// Guard the product before multiplying: a wrapped (possibly negative)
		// size would make the make() call below panic.
		if size > maxInt/dim {
			return nil, errors.New("shape size overflows int")
		}
		size *= dim
	}
	data := make([]float64, size)
	for i := range data {
		data[i] = 1.0
	}
	// NewMatrix reports the empty-shape case.
	return NewMatrix(data, shape)
}
// NewVector returns data as a column vector, i.e. a matrix of shape
// [len(data), 1]. Validation (including rejecting empty data) is left to
// NewMatrix.
func NewVector(data []float64) (*Matrix, error) {
	rows := len(data)
	columnShape := []int{rows, 1}
	return NewMatrix(data, columnShape)
}
// NewIdentity returns the size x size identity matrix: 1 on the main
// diagonal and 0 everywhere else.
//
// An error is returned if size is not positive or if size*size would
// overflow int.
func NewIdentity(size int) (*Matrix, error) {
	const maxInt = int(^uint(0) >> 1)

	if size <= 0 {
		return nil, errors.New("size must be positive")
	}
	// Without this guard a wrapped size*size makes make() panic, or yields
	// an empty slice so that data[0] panics below.
	if size > maxInt/size {
		return nil, errors.New("size overflows int")
	}
	data := make([]float64, size*size)
	for i := 0; i < size; i++ {
		// Element (i, i) in row-major storage.
		data[i*size+i] = 1.0
	}
	return NewMatrix(data, []int{size, size})
}
// Get returns the element at the given position, one index per dimension.
//
// An error is returned if the number of indices differs from the matrix's
// dimension count, or if any index lies outside [0, shape[axis]).
func (m *Matrix) Get(indices ...int) (float64, error) {
	if len(indices) != m.mdim {
		return 0, errors.New("number of indices must match matrix dimensions")
	}
	// Validate each index and accumulate its contribution to the flat
	// offset in the same pass.
	offset := 0
	for axis, idx := range indices {
		if idx < 0 || idx >= m.shape[axis] {
			return 0, errors.New("index out of bounds")
		}
		offset += idx * m.strides[axis]
	}
	return m.data[offset], nil
}
// Set stores value at the given position, one index per dimension. The
// matrix is modified in place.
//
// An error is returned (and nothing is written) if the number of indices
// differs from the matrix's dimension count, or if any index lies outside
// [0, shape[axis]).
func (m *Matrix) Set(value float64, indices ...int) error {
	if len(indices) != m.mdim {
		return errors.New("number of indices must match matrix dimensions")
	}
	// Validate each index and accumulate its contribution to the flat
	// offset in the same pass.
	offset := 0
	for axis, idx := range indices {
		if idx < 0 || idx >= m.shape[axis] {
			return errors.New("index out of bounds")
		}
		offset += idx * m.strides[axis]
	}
	m.data[offset] = value
	return nil
}
// Add returns the element-wise sum m + other as a new matrix. Both operands
// must have identical shapes; neither is modified.
func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
	if !m.hasSameShape(other) {
		return nil, errors.New("matrices must have the same shape for addition")
	}
	sum := make([]float64, m.size)
	for i := range sum {
		sum[i] = m.data[i] + other.data[i]
	}
	return NewMatrix(sum, m.shape)
}
// Sum returns the sum of all elements in the matrix, accumulated in storage
// order.
func (m *Matrix) Sum() float64 {
	total := 0.0
	for _, v := range m.data[:m.size] {
		total += v
	}
	return total
}
// Subtract returns the element-wise difference m - other as a new matrix.
// Both operands must have identical shapes; neither is modified.
func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
	if !m.hasSameShape(other) {
		return nil, errors.New("matrices must have the same shape for subtraction")
	}
	diff := make([]float64, m.size)
	for i, a := range m.data[:m.size] {
		diff[i] = a - other.data[i]
	}
	return NewMatrix(diff, m.shape)
}
// Multiply returns the element-wise (Hadamard) product of m and other as a
// new matrix. Both operands must have identical shapes; neither is modified.
// For the linear-algebra product see MatMul.
func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
	if !m.hasSameShape(other) {
		return nil, errors.New("matrices must have the same shape for element-wise multiplication")
	}
	product := make([]float64, m.size)
	for i := range product {
		product[i] = m.data[i] * other.data[i]
	}
	return NewMatrix(product, m.shape)
}
// MatMul returns the matrix product m · other of two 2-D matrices.
//
// If m is n x p, other must be p x q, and the result is n x q. An error is
// returned if either operand is not 2-D or if the inner dimensions differ.
func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
	if m.mdim != 2 || other.mdim != 2 {
		return nil, errors.New("matmul only supported for 2D matrices")
	}
	if m.shape[1] != other.shape[0] {
		return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]",
			m.shape[0], m.shape[1], other.shape[0], other.shape[1])
	}

	n, p, q := m.shape[0], m.shape[1], other.shape[1]
	out := make([]float64, n*q)
	for row := 0; row < n; row++ {
		// Row `row` of m is a contiguous run of p values.
		lhsRow := m.data[row*p : (row+1)*p]
		for col := 0; col < q; col++ {
			acc := 0.0
			for k, a := range lhsRow {
				acc += a * other.data[k*q+col]
			}
			out[row*q+col] = acc
		}
	}
	return NewMatrix(out, []int{n, q})
}
// Scale returns a new matrix whose elements are those of m multiplied by
// factor. m is not modified.
func (m *Matrix) Scale(factor float64) *Matrix {
	scaled := make([]float64, m.size)
	for i := range scaled {
		scaled[i] = m.data[i] * factor
	}
	// Give the result its own shape and strides (as Copy does) so the two
	// matrices share no mutable slices.
	shape := make([]int, len(m.shape))
	copy(shape, m.shape)
	strides := make([]int, len(m.strides))
	copy(strides, m.strides)
	return &Matrix{
		data:    scaled,
		shape:   shape,
		strides: strides,
		mdim:    m.mdim,
		size:    m.size,
	}
}
// Equal reports whether m and other have exactly the same shape (same number
// of dimensions and the same size along each) and identical elements.
//
// Elements are compared with ==, so NaN is never equal to NaN and +0 equals
// -0. Two nil matrices are equal; a nil and a non-nil matrix are not.
func (m *Matrix) Equal(other *Matrix) bool {
	if m == nil || other == nil {
		return m == other
	}
	// Compare every dimension: checking only shape[0] and shape[1] panicked
	// on 1-D matrices and treated e.g. [2,3] and [2,3,1] as equal.
	if len(m.shape) != len(other.shape) {
		return false
	}
	for i := range m.shape {
		if m.shape[i] != other.shape[i] {
			return false
		}
	}
	// Equal shapes imply equal sizes, so other.data[i] is in range.
	for i := 0; i < m.size; i++ {
		if m.data[i] != other.data[i] {
			return false
		}
	}
	return true
}
// Transpose returns a new matrix that is the transpose of m: an r x c input
// becomes c x r, with element (i, j) moved to (j, i). Only 2-D matrices are
// supported; m is not modified.
func (m *Matrix) Transpose() (*Matrix, error) {
	if m.mdim != 2 {
		return nil, errors.New("transpose only supported for 2D matrices")
	}
	r, c := m.shape[0], m.shape[1]
	out := make([]float64, m.size)
	// Fill the output row by row: output row j is input column j.
	for j := 0; j < c; j++ {
		for i := 0; i < r; i++ {
			out[j*r+i] = m.data[i*c+j]
		}
	}
	return NewMatrix(out, []int{c, r})
}
// Copy returns a deep copy of m. The data, shape and strides slices are all
// freshly allocated, so modifying the copy never affects m (and vice versa).
func (m *Matrix) Copy() *Matrix {
	dataCopy := make([]float64, m.size)
	copy(dataCopy, m.data)
	shapeCopy := make([]int, m.mdim)
	copy(shapeCopy, m.shape)
	// Copy from m.strides (not m.shape) so that Get/Set on the copy compute
	// the same offsets as on m.
	stridesCopy := make([]int, m.mdim)
	copy(stridesCopy, m.strides)
	return &Matrix{
		data:    dataCopy,
		shape:   shapeCopy,
		strides: stridesCopy,
		mdim:    m.mdim,
		size:    m.size,
	}
}
// String implements fmt.Stringer.
//
// A 2-D matrix is rendered one row per line, each element with two decimal
// places, e.g.
//
//	[
//	 [1.00, 2.00]
//	 [3.00, 4.00]
//	]
//
// Any other rank falls back to a raw dump of the backing data and shape.
func (m *Matrix) String() string {
	if m.mdim != 2 {
		return fmt.Sprintf("Matrix{data:%v, shape:%v}", m.data, m.shape)
	}
	rows, cols := m.shape[0], m.shape[1]
	// A Builder avoids the quadratic copying of repeated string += inside
	// the nested loops.
	var b strings.Builder
	b.WriteString("[\n")
	for i := 0; i < rows; i++ {
		b.WriteString(" [")
		for j := 0; j < cols; j++ {
			if j > 0 {
				b.WriteString(", ")
			}
			fmt.Fprintf(&b, "%.2f", m.data[i*cols+j])
		}
		b.WriteString("]\n")
	}
	b.WriteString("]")
	return b.String()
}
// Shape returns a fresh copy of the matrix's dimension sizes. The caller may
// modify the returned slice without affecting the matrix.
func (m *Matrix) Shape() []int {
	return append(make([]int, 0, m.mdim), m.shape...)
}
// Size reports the total number of elements, i.e. the product of all
// dimension sizes.
func (m *Matrix) Size() int { return m.size }
// hasSameShape reports whether m and other have the same number of
// dimensions and the same size along every dimension.
func (m *Matrix) hasSameShape(other *Matrix) bool {
	if m.mdim != other.mdim {
		return false
	}
	for axis, extent := range m.shape[:m.mdim] {
		if other.shape[axis] != extent {
			return false
		}
	}
	return true
}