```
feat: 添加Go矩阵库基础功能 - 实现矩阵数据结构和基本操作 - 添加矩阵创建、加法、乘法、转置等功能 - 实现零矩阵、单位矩阵、全一矩阵创建方法 - 添加矩阵元素访问和修改功能 - 提供完整的矩阵运算示例程序 - 添加项目许可证文件 - 配置Go模块依赖 ```
This commit is contained in:
parent
6f07b38370
commit
b3b9017dd9
|
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 kingecg
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.kingecg.top/kingecg/gomatrix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 示例程序演示矩阵功能
|
||||||
|
func Example() {
|
||||||
|
fmt.Println("创建矩阵示例:")
|
||||||
|
|
||||||
|
// 创建一个2x3的矩阵
|
||||||
|
data1 := []float64{1, 2, 3, 4, 5, 6}
|
||||||
|
mat1, err := gomatrix.NewMatrix(data1, []int{2, 3})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("创建矩阵失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("矩阵1:\n%s\n", mat1.String())
|
||||||
|
|
||||||
|
// 创建另一个2x3的矩阵
|
||||||
|
data2 := []float64{7, 8, 9, 10, 11, 12}
|
||||||
|
mat2, err := gomatrix.NewMatrix(data2, []int{2, 3})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("创建矩阵失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("矩阵2:\n%s\n", mat2.String())
|
||||||
|
|
||||||
|
// 矩阵加法
|
||||||
|
sum, err := mat1.Add(mat2)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("矩阵加法失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("矩阵1 + 矩阵2:\n%s\n", sum.String())
|
||||||
|
|
||||||
|
// 创建一个3x2的矩阵用于矩阵乘法
|
||||||
|
data3 := []float64{1, 2, 3, 4, 5, 6}
|
||||||
|
mat3, err := gomatrix.NewMatrix(data3, []int{3, 2})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("创建矩阵失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("矩阵3 (3x2):\n%s\n", mat3.String())
|
||||||
|
|
||||||
|
// 创建一个2x2的矩阵用于矩阵乘法
|
||||||
|
data4 := []float64{1, 2, 3, 4}
|
||||||
|
mat4, err := gomatrix.NewMatrix(data4, []int{2, 2})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("创建矩阵失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("矩阵4 (2x2):\n%s\n", mat4.String())
|
||||||
|
|
||||||
|
// 矩阵乘法
|
||||||
|
product, err := mat3.MatMul(mat4)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("矩阵乘法失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("矩阵3 × 矩阵4:\n%s\n", product.String())
|
||||||
|
|
||||||
|
// 矩阵转置
|
||||||
|
transposed, err := mat1.Transpose()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("矩阵转置失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("矩阵1的转置:\n%s\n", transposed.String())
|
||||||
|
|
||||||
|
// 矩阵数乘
|
||||||
|
scaled := mat1.Scale(2.0)
|
||||||
|
fmt.Printf("矩阵1 × 2:\n%s\n", scaled.String())
|
||||||
|
|
||||||
|
// 创建零矩阵
|
||||||
|
zeros, err := gomatrix.NewZeros([]int{2, 2})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("创建零矩阵失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("2x2零矩阵:\n%s\n", zeros.String())
|
||||||
|
|
||||||
|
// 创建单位矩阵
|
||||||
|
identity, err := gomatrix.NewIdentity(3)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("创建单位矩阵失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("3x3单位矩阵:\n%s\n", identity.String())
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
Example()
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,312 @@
|
||||||
|
package gomatrix
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Matrix struct {
|
||||||
|
data []float64
|
||||||
|
shape []int
|
||||||
|
strides []int
|
||||||
|
mdim int
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMatrix 创建一个新的矩阵
|
||||||
|
func NewMatrix(data []float64, shape []int) (*Matrix, error) {
|
||||||
|
if len(data) == 0 || len(shape) == 0 {
|
||||||
|
return nil, errors.New("data and shape cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算期望的大小
|
||||||
|
expectedSize := 1
|
||||||
|
for _, dim := range shape {
|
||||||
|
if dim <= 0 {
|
||||||
|
return nil, errors.New("all dimensions must be positive")
|
||||||
|
}
|
||||||
|
expectedSize *= dim
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) != expectedSize {
|
||||||
|
return nil, fmt.Errorf("data length %d does not match expected size %d for given shape", len(data), expectedSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算步长
|
||||||
|
strides := make([]int, len(shape))
|
||||||
|
stride := 1
|
||||||
|
for i := len(shape) - 1; i >= 0; i-- {
|
||||||
|
strides[i] = stride
|
||||||
|
stride *= shape[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Matrix{
|
||||||
|
data: data,
|
||||||
|
shape: shape,
|
||||||
|
strides: strides,
|
||||||
|
mdim: len(shape),
|
||||||
|
size: expectedSize,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewZeros 创建一个全零矩阵
|
||||||
|
func NewZeros(shape []int) (*Matrix, error) {
|
||||||
|
expectedSize := 1
|
||||||
|
for _, dim := range shape {
|
||||||
|
if dim <= 0 {
|
||||||
|
return nil, errors.New("all dimensions must be positive")
|
||||||
|
}
|
||||||
|
expectedSize *= dim
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]float64, expectedSize)
|
||||||
|
return NewMatrix(data, shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOnes 创建一个全一矩阵
|
||||||
|
func NewOnes(shape []int) (*Matrix, error) {
|
||||||
|
expectedSize := 1
|
||||||
|
for _, dim := range shape {
|
||||||
|
if dim <= 0 {
|
||||||
|
return nil, errors.New("all dimensions must be positive")
|
||||||
|
}
|
||||||
|
expectedSize *= dim
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]float64, expectedSize)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = 1.0
|
||||||
|
}
|
||||||
|
return NewMatrix(data, shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIdentity 创建一个单位矩阵
|
||||||
|
func NewIdentity(size int) (*Matrix, error) {
|
||||||
|
if size <= 0 {
|
||||||
|
return nil, errors.New("size must be positive")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]float64, size*size)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
data[i*size+i] = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMatrix(data, []int{size, size})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取指定位置的值
|
||||||
|
func (m *Matrix) Get(indices ...int) (float64, error) {
|
||||||
|
if len(indices) != m.mdim {
|
||||||
|
return 0, errors.New("number of indices must match matrix dimensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, idx := range indices {
|
||||||
|
if idx < 0 || idx >= m.shape[i] {
|
||||||
|
return 0, errors.New("index out of bounds")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pos := 0
|
||||||
|
for i, idx := range indices {
|
||||||
|
pos += idx * m.strides[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.data[pos], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set 设置指定位置的值
|
||||||
|
func (m *Matrix) Set(value float64, indices ...int) error {
|
||||||
|
if len(indices) != m.mdim {
|
||||||
|
return errors.New("number of indices must match matrix dimensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, idx := range indices {
|
||||||
|
if idx < 0 || idx >= m.shape[i] {
|
||||||
|
return errors.New("index out of bounds")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pos := 0
|
||||||
|
for i, idx := range indices {
|
||||||
|
pos += idx * m.strides[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
m.data[pos] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 矩阵加法
|
||||||
|
func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
|
||||||
|
if !m.hasSameShape(other) {
|
||||||
|
return nil, errors.New("matrices must have the same shape for addition")
|
||||||
|
}
|
||||||
|
|
||||||
|
resultData := make([]float64, m.size)
|
||||||
|
for i := 0; i < m.size; i++ {
|
||||||
|
resultData[i] = m.data[i] + other.data[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMatrix(resultData, m.shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subtract 矩阵减法
|
||||||
|
func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
|
||||||
|
if !m.hasSameShape(other) {
|
||||||
|
return nil, errors.New("matrices must have the same shape for subtraction")
|
||||||
|
}
|
||||||
|
|
||||||
|
resultData := make([]float64, m.size)
|
||||||
|
for i := 0; i < m.size; i++ {
|
||||||
|
resultData[i] = m.data[i] - other.data[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMatrix(resultData, m.shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply 矩阵乘法(逐元素相乘)
|
||||||
|
func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
|
||||||
|
if !m.hasSameShape(other) {
|
||||||
|
return nil, errors.New("matrices must have the same shape for element-wise multiplication")
|
||||||
|
}
|
||||||
|
|
||||||
|
resultData := make([]float64, m.size)
|
||||||
|
for i := 0; i < m.size; i++ {
|
||||||
|
resultData[i] = m.data[i] * other.data[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMatrix(resultData, m.shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatMul 矩阵点乘(线性代数乘法)
|
||||||
|
func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
|
||||||
|
if m.mdim != 2 || other.mdim != 2 {
|
||||||
|
return nil, errors.New("matmul only supported for 2D matrices")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.shape[1] != other.shape[0] {
|
||||||
|
return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]",
|
||||||
|
m.shape[0], m.shape[1], other.shape[0], other.shape[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, cols, inner := m.shape[0], other.shape[1], m.shape[1]
|
||||||
|
resultData := make([]float64, rows*cols)
|
||||||
|
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
for j := 0; j < cols; j++ {
|
||||||
|
sum := 0.0
|
||||||
|
for k := 0; k < inner; k++ {
|
||||||
|
mIdx := i*m.shape[1] + k
|
||||||
|
oIdx := k*other.shape[1] + j
|
||||||
|
sum += m.data[mIdx] * other.data[oIdx]
|
||||||
|
}
|
||||||
|
resultData[i*cols+j] = sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMatrix(resultData, []int{rows, cols})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scale 矩阵数乘
|
||||||
|
func (m *Matrix) Scale(factor float64) *Matrix {
|
||||||
|
resultData := make([]float64, m.size)
|
||||||
|
for i := 0; i < m.size; i++ {
|
||||||
|
resultData[i] = m.data[i] * factor
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Matrix{
|
||||||
|
data: resultData,
|
||||||
|
shape: m.shape,
|
||||||
|
strides: m.strides,
|
||||||
|
mdim: m.mdim,
|
||||||
|
size: m.size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transpose 矩阵转置(仅支持2维矩阵)
|
||||||
|
func (m *Matrix) Transpose() (*Matrix, error) {
|
||||||
|
if m.mdim != 2 {
|
||||||
|
return nil, errors.New("transpose only supported for 2D matrices")
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, cols := m.shape[0], m.shape[1]
|
||||||
|
resultData := make([]float64, m.size)
|
||||||
|
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
for j := 0; j < cols; j++ {
|
||||||
|
resultData[j*rows+i] = m.data[i*cols+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMatrix(resultData, []int{cols, rows})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy 复制矩阵
|
||||||
|
func (m *Matrix) Copy() *Matrix {
|
||||||
|
dataCopy := make([]float64, m.size)
|
||||||
|
copy(dataCopy, m.data)
|
||||||
|
|
||||||
|
shapeCopy := make([]int, m.mdim)
|
||||||
|
copy(shapeCopy, m.shape)
|
||||||
|
|
||||||
|
stridesCopy := make([]int, m.mdim)
|
||||||
|
copy(stridesCopy, m.strides)
|
||||||
|
|
||||||
|
return &Matrix{
|
||||||
|
data: dataCopy,
|
||||||
|
shape: shapeCopy,
|
||||||
|
strides: stridesCopy,
|
||||||
|
mdim: m.mdim,
|
||||||
|
size: m.size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String 实现Stringer接口,用于打印矩阵
|
||||||
|
func (m *Matrix) String() string {
|
||||||
|
if m.mdim == 2 {
|
||||||
|
result := "[\n"
|
||||||
|
rows, cols := m.shape[0], m.shape[1]
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
result += " ["
|
||||||
|
for j := 0; j < cols; j++ {
|
||||||
|
idx := i*cols + j
|
||||||
|
result += fmt.Sprintf("%.2f", m.data[idx])
|
||||||
|
if j < cols-1 {
|
||||||
|
result += ", "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result += "]\n"
|
||||||
|
}
|
||||||
|
result += "]"
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// 简单处理其他维度的情况
|
||||||
|
return fmt.Sprintf("Matrix{data:%v, shape:%v}", m.data, m.shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shape 返回矩阵的形状
|
||||||
|
func (m *Matrix) Shape() []int {
|
||||||
|
shape := make([]int, m.mdim)
|
||||||
|
copy(shape, m.shape)
|
||||||
|
return shape
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size 返回矩阵的大小
|
||||||
|
func (m *Matrix) Size() int {
|
||||||
|
return m.size
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasSameShape 检查两个矩阵是否具有相同的形状
|
||||||
|
func (m *Matrix) hasSameShape(other *Matrix) bool {
|
||||||
|
if m.mdim != other.mdim {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < m.mdim; i++ {
|
||||||
|
if m.shape[i] != other.shape[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue