354 lines
7.7 KiB
Go
354 lines
7.7 KiB
Go
package gomatrix
|
||
|
||
import (
	"errors"
	"fmt"
	"strings"
)
|
||
|
||
// Matrix is a dense n-dimensional array of float64 values kept in a single
// flat slice.
type Matrix struct {
	// data holds every element of the matrix.
	data []float64

	// shape is the extent of each dimension, e.g. [2, 3] means 2 rows and
	// 3 columns; strides[i] is the distance in data between two neighbouring
	// indices along dimension i.
	shape, strides []int

	// mdim is the number of dimensions and size the total element count.
	mdim, size int
}
|
||
|
||
// NewMatrix 创建一个新的矩阵
|
||
// data: 矩阵数据
|
||
// shape: 矩阵形状,例如[2, 3]表示2行3列
|
||
func NewMatrix(data []float64, shape []int) (*Matrix, error) {
|
||
if len(data) == 0 || len(shape) == 0 {
|
||
return nil, errors.New("data and shape cannot be empty")
|
||
}
|
||
|
||
// 计算期望的大小
|
||
expectedSize := 1
|
||
for _, dim := range shape {
|
||
if dim <= 0 {
|
||
return nil, errors.New("all dimensions must be positive")
|
||
}
|
||
expectedSize *= dim
|
||
}
|
||
|
||
if len(data) != expectedSize {
|
||
return nil, fmt.Errorf("data length %d does not match expected size %d for given shape", len(data), expectedSize)
|
||
}
|
||
|
||
// 计算步长
|
||
strides := make([]int, len(shape))
|
||
stride := 1
|
||
for i := len(shape) - 1; i >= 0; i-- {
|
||
strides[i] = stride
|
||
stride *= shape[i]
|
||
}
|
||
|
||
return &Matrix{
|
||
data: data,
|
||
shape: shape,
|
||
strides: strides,
|
||
mdim: len(shape),
|
||
size: expectedSize,
|
||
}, nil
|
||
}
|
||
|
||
// NewZeros 创建一个全零矩阵
|
||
// shape: 矩阵形状
|
||
func NewZeros(shape []int) (*Matrix, error) {
|
||
expectedSize := 1
|
||
for _, dim := range shape {
|
||
if dim <= 0 {
|
||
return nil, errors.New("all dimensions must be positive")
|
||
}
|
||
expectedSize *= dim
|
||
}
|
||
|
||
data := make([]float64, expectedSize)
|
||
return NewMatrix(data, shape)
|
||
}
|
||
|
||
// NewOnes 创建一个全一矩阵
|
||
// shape: 矩阵形状
|
||
func NewOnes(shape []int) (*Matrix, error) {
|
||
expectedSize := 1
|
||
for _, dim := range shape {
|
||
if dim <= 0 {
|
||
return nil, errors.New("all dimensions must be positive")
|
||
}
|
||
expectedSize *= dim
|
||
}
|
||
|
||
data := make([]float64, expectedSize)
|
||
for i := range data {
|
||
data[i] = 1.0
|
||
}
|
||
return NewMatrix(data, shape)
|
||
}
|
||
|
||
// NewVector 创建一个向量
|
||
func NewVector(data []float64) (*Matrix, error) {
|
||
return NewMatrix(data, []int{len(data), 1})
|
||
}
|
||
|
||
// NewIdentity 创建一个单位矩阵
|
||
// size: 单位矩阵的大小
|
||
func NewIdentity(size int) (*Matrix, error) {
|
||
if size <= 0 {
|
||
return nil, errors.New("size must be positive")
|
||
}
|
||
|
||
data := make([]float64, size*size)
|
||
for i := 0; i < size; i++ {
|
||
data[i*size+i] = 1.0
|
||
}
|
||
|
||
return NewMatrix(data, []int{size, size})
|
||
}
|
||
|
||
// Get 获取指定位置的值
|
||
// indices: 位置索引
|
||
func (m *Matrix) Get(indices ...int) (float64, error) {
|
||
if len(indices) != m.mdim {
|
||
return 0, errors.New("number of indices must match matrix dimensions")
|
||
}
|
||
|
||
for i, idx := range indices {
|
||
if idx < 0 || idx >= m.shape[i] {
|
||
return 0, errors.New("index out of bounds")
|
||
}
|
||
}
|
||
|
||
pos := 0
|
||
for i, idx := range indices {
|
||
pos += idx * m.strides[i]
|
||
}
|
||
|
||
return m.data[pos], nil
|
||
}
|
||
|
||
// Set 设置指定位置的值
|
||
// value: 要设置的值
|
||
// indices: 位置索引
|
||
func (m *Matrix) Set(value float64, indices ...int) error {
|
||
if len(indices) != m.mdim {
|
||
return errors.New("number of indices must match matrix dimensions")
|
||
}
|
||
|
||
for i, idx := range indices {
|
||
if idx < 0 || idx >= m.shape[i] {
|
||
return errors.New("index out of bounds")
|
||
}
|
||
}
|
||
|
||
pos := 0
|
||
for i, idx := range indices {
|
||
pos += idx * m.strides[i]
|
||
}
|
||
|
||
m.data[pos] = value
|
||
return nil
|
||
}
|
||
|
||
// Add 矩阵加法
|
||
// other: 要相加的另一个矩阵
|
||
func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
|
||
if !m.hasSameShape(other) {
|
||
return nil, errors.New("matrices must have the same shape for addition")
|
||
}
|
||
|
||
resultData := make([]float64, m.size)
|
||
for i := 0; i < m.size; i++ {
|
||
resultData[i] = m.data[i] + other.data[i]
|
||
}
|
||
|
||
return NewMatrix(resultData, m.shape)
|
||
}
|
||
|
||
func (m *Matrix) Sum() float64 {
|
||
var sum float64 = 0
|
||
for i := 0; i < m.size; i++ {
|
||
sum += m.data[i]
|
||
}
|
||
return sum
|
||
}
|
||
|
||
// Subtract 矩阵减法
|
||
// other: 要相减的另一个矩阵
|
||
func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
|
||
if !m.hasSameShape(other) {
|
||
return nil, errors.New("matrices must have the same shape for subtraction")
|
||
}
|
||
|
||
resultData := make([]float64, m.size)
|
||
for i := 0; i < m.size; i++ {
|
||
resultData[i] = m.data[i] - other.data[i]
|
||
}
|
||
|
||
return NewMatrix(resultData, m.shape)
|
||
}
|
||
|
||
// Multiply 矩阵乘法(逐元素相乘)
|
||
// other: 要相乘的另一个矩阵
|
||
func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
|
||
if !m.hasSameShape(other) {
|
||
return nil, errors.New("matrices must have the same shape for element-wise multiplication")
|
||
}
|
||
|
||
resultData := make([]float64, m.size)
|
||
for i := 0; i < m.size; i++ {
|
||
resultData[i] = m.data[i] * other.data[i]
|
||
}
|
||
|
||
return NewMatrix(resultData, m.shape)
|
||
}
|
||
|
||
// MatMul 矩阵点乘(线性代数乘法)
|
||
// other: 要相乘的另一个矩阵
|
||
func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
|
||
if m.mdim != 2 || other.mdim != 2 {
|
||
return nil, errors.New("matmul only supported for 2D matrices")
|
||
}
|
||
|
||
if m.shape[1] != other.shape[0] {
|
||
return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]",
|
||
m.shape[0], m.shape[1], other.shape[0], other.shape[1])
|
||
}
|
||
|
||
rows, cols, inner := m.shape[0], other.shape[1], m.shape[1]
|
||
resultData := make([]float64, rows*cols)
|
||
|
||
for i := 0; i < rows; i++ {
|
||
for j := 0; j < cols; j++ {
|
||
sum := 0.0
|
||
for k := 0; k < inner; k++ {
|
||
mIdx := i*m.shape[1] + k
|
||
oIdx := k*other.shape[1] + j
|
||
sum += m.data[mIdx] * other.data[oIdx]
|
||
}
|
||
resultData[i*cols+j] = sum
|
||
}
|
||
}
|
||
|
||
return NewMatrix(resultData, []int{rows, cols})
|
||
}
|
||
|
||
// Scale 矩阵数乘
|
||
// factor: 缩放因子
|
||
func (m *Matrix) Scale(factor float64) *Matrix {
|
||
resultData := make([]float64, m.size)
|
||
for i := 0; i < m.size; i++ {
|
||
resultData[i] = m.data[i] * factor
|
||
}
|
||
|
||
return &Matrix{
|
||
data: resultData,
|
||
shape: m.shape,
|
||
strides: m.strides,
|
||
mdim: m.mdim,
|
||
size: m.size,
|
||
}
|
||
}
|
||
|
||
func (m *Matrix) Equal(other *Matrix) bool {
|
||
if m.shape[0] != other.shape[0] || m.shape[1] != other.shape[1] {
|
||
return false
|
||
}
|
||
|
||
for i := 0; i < m.size; i++ {
|
||
if m.data[i] != other.data[i] {
|
||
return false
|
||
}
|
||
}
|
||
|
||
return true
|
||
}
|
||
|
||
// Transpose 矩阵转置(仅支持2维矩阵)
|
||
func (m *Matrix) Transpose() (*Matrix, error) {
|
||
if m.mdim != 2 {
|
||
return nil, errors.New("transpose only supported for 2D matrices")
|
||
}
|
||
|
||
rows, cols := m.shape[0], m.shape[1]
|
||
resultData := make([]float64, m.size)
|
||
|
||
for i := 0; i < rows; i++ {
|
||
for j := 0; j < cols; j++ {
|
||
resultData[j*rows+i] = m.data[i*cols+j]
|
||
}
|
||
}
|
||
|
||
return NewMatrix(resultData, []int{cols, rows})
|
||
}
|
||
|
||
// Copy 复制矩阵
|
||
func (m *Matrix) Copy() *Matrix {
|
||
dataCopy := make([]float64, m.size)
|
||
copy(dataCopy, m.data)
|
||
|
||
shapeCopy := make([]int, m.mdim)
|
||
copy(shapeCopy, m.shape)
|
||
|
||
stridesCopy := make([]int, m.mdim)
|
||
copy(stridesCopy, m.shape)
|
||
|
||
return &Matrix{
|
||
data: dataCopy,
|
||
shape: shapeCopy,
|
||
strides: stridesCopy,
|
||
mdim: m.mdim,
|
||
size: m.size,
|
||
}
|
||
}
|
||
|
||
// String 实现Stringer接口,用于打印矩阵
|
||
func (m *Matrix) String() string {
|
||
if m.mdim == 2 {
|
||
result := "[\n"
|
||
rows, cols := m.shape[0], m.shape[1]
|
||
for i := 0; i < rows; i++ {
|
||
result += " ["
|
||
for j := 0; j < cols; j++ {
|
||
idx := i*cols + j
|
||
result += fmt.Sprintf("%.2f", m.data[idx])
|
||
if j < cols-1 {
|
||
result += ", "
|
||
}
|
||
}
|
||
result += "]\n"
|
||
}
|
||
result += "]"
|
||
return result
|
||
}
|
||
|
||
// 简单处理其他维度的情况
|
||
return fmt.Sprintf("Matrix{data:%v, shape:%v}", m.data, m.shape)
|
||
}
|
||
|
||
// Shape 返回矩阵的形状
|
||
func (m *Matrix) Shape() []int {
|
||
shape := make([]int, m.mdim)
|
||
copy(shape, m.shape)
|
||
return shape
|
||
}
|
||
|
||
// Size 返回矩阵的大小
|
||
func (m *Matrix) Size() int {
|
||
return m.size
|
||
}
|
||
|
||
// hasSameShape 检查两个矩阵是否具有相同的形状
|
||
func (m *Matrix) hasSameShape(other *Matrix) bool {
|
||
if m.mdim != other.mdim {
|
||
return false
|
||
}
|
||
|
||
for i := 0; i < m.mdim; i++ {
|
||
if m.shape[i] != other.shape[i] {
|
||
return false
|
||
}
|
||
}
|
||
|
||
return true
|
||
}
|