```
feat: 添加Go矩阵库基础功能 - 实现矩阵数据结构和基本操作 - 添加矩阵创建、加法、乘法、转置等功能 - 实现零矩阵、单位矩阵、全一矩阵创建方法 - 添加矩阵元素访问和修改功能 - 提供完整的矩阵运算示例程序 - 添加项目许可证文件 - 配置Go模块依赖 ```
This commit is contained in:
parent
6f07b38370
commit
b3b9017dd9
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025 kingecg
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.kingecg.top/kingecg/gomatrix"
|
||||
)
|
||||
|
||||
// 示例程序演示矩阵功能
|
||||
func Example() {
|
||||
fmt.Println("创建矩阵示例:")
|
||||
|
||||
// 创建一个2x3的矩阵
|
||||
data1 := []float64{1, 2, 3, 4, 5, 6}
|
||||
mat1, err := gomatrix.NewMatrix(data1, []int{2, 3})
|
||||
if err != nil {
|
||||
fmt.Printf("创建矩阵失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("矩阵1:\n%s\n", mat1.String())
|
||||
|
||||
// 创建另一个2x3的矩阵
|
||||
data2 := []float64{7, 8, 9, 10, 11, 12}
|
||||
mat2, err := gomatrix.NewMatrix(data2, []int{2, 3})
|
||||
if err != nil {
|
||||
fmt.Printf("创建矩阵失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("矩阵2:\n%s\n", mat2.String())
|
||||
|
||||
// 矩阵加法
|
||||
sum, err := mat1.Add(mat2)
|
||||
if err != nil {
|
||||
fmt.Printf("矩阵加法失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("矩阵1 + 矩阵2:\n%s\n", sum.String())
|
||||
|
||||
// 创建一个3x2的矩阵用于矩阵乘法
|
||||
data3 := []float64{1, 2, 3, 4, 5, 6}
|
||||
mat3, err := gomatrix.NewMatrix(data3, []int{3, 2})
|
||||
if err != nil {
|
||||
fmt.Printf("创建矩阵失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("矩阵3 (3x2):\n%s\n", mat3.String())
|
||||
|
||||
// 创建一个2x2的矩阵用于矩阵乘法
|
||||
data4 := []float64{1, 2, 3, 4}
|
||||
mat4, err := gomatrix.NewMatrix(data4, []int{2, 2})
|
||||
if err != nil {
|
||||
fmt.Printf("创建矩阵失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("矩阵4 (2x2):\n%s\n", mat4.String())
|
||||
|
||||
// 矩阵乘法
|
||||
product, err := mat3.MatMul(mat4)
|
||||
if err != nil {
|
||||
fmt.Printf("矩阵乘法失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("矩阵3 × 矩阵4:\n%s\n", product.String())
|
||||
|
||||
// 矩阵转置
|
||||
transposed, err := mat1.Transpose()
|
||||
if err != nil {
|
||||
fmt.Printf("矩阵转置失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("矩阵1的转置:\n%s\n", transposed.String())
|
||||
|
||||
// 矩阵数乘
|
||||
scaled := mat1.Scale(2.0)
|
||||
fmt.Printf("矩阵1 × 2:\n%s\n", scaled.String())
|
||||
|
||||
// 创建零矩阵
|
||||
zeros, err := gomatrix.NewZeros([]int{2, 2})
|
||||
if err != nil {
|
||||
fmt.Printf("创建零矩阵失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("2x2零矩阵:\n%s\n", zeros.String())
|
||||
|
||||
// 创建单位矩阵
|
||||
identity, err := gomatrix.NewIdentity(3)
|
||||
if err != nil {
|
||||
fmt.Printf("创建单位矩阵失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("3x3单位矩阵:\n%s\n", identity.String())
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
package main
|
||||
|
||||
func main() {
|
||||
Example()
|
||||
}
|
||||
|
|
@ -0,0 +1,312 @@
|
|||
package gomatrix
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Matrix struct {
|
||||
data []float64
|
||||
shape []int
|
||||
strides []int
|
||||
mdim int
|
||||
size int
|
||||
}
|
||||
|
||||
// NewMatrix 创建一个新的矩阵
|
||||
func NewMatrix(data []float64, shape []int) (*Matrix, error) {
|
||||
if len(data) == 0 || len(shape) == 0 {
|
||||
return nil, errors.New("data and shape cannot be empty")
|
||||
}
|
||||
|
||||
// 计算期望的大小
|
||||
expectedSize := 1
|
||||
for _, dim := range shape {
|
||||
if dim <= 0 {
|
||||
return nil, errors.New("all dimensions must be positive")
|
||||
}
|
||||
expectedSize *= dim
|
||||
}
|
||||
|
||||
if len(data) != expectedSize {
|
||||
return nil, fmt.Errorf("data length %d does not match expected size %d for given shape", len(data), expectedSize)
|
||||
}
|
||||
|
||||
// 计算步长
|
||||
strides := make([]int, len(shape))
|
||||
stride := 1
|
||||
for i := len(shape) - 1; i >= 0; i-- {
|
||||
strides[i] = stride
|
||||
stride *= shape[i]
|
||||
}
|
||||
|
||||
return &Matrix{
|
||||
data: data,
|
||||
shape: shape,
|
||||
strides: strides,
|
||||
mdim: len(shape),
|
||||
size: expectedSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewZeros 创建一个全零矩阵
|
||||
func NewZeros(shape []int) (*Matrix, error) {
|
||||
expectedSize := 1
|
||||
for _, dim := range shape {
|
||||
if dim <= 0 {
|
||||
return nil, errors.New("all dimensions must be positive")
|
||||
}
|
||||
expectedSize *= dim
|
||||
}
|
||||
|
||||
data := make([]float64, expectedSize)
|
||||
return NewMatrix(data, shape)
|
||||
}
|
||||
|
||||
// NewOnes 创建一个全一矩阵
|
||||
func NewOnes(shape []int) (*Matrix, error) {
|
||||
expectedSize := 1
|
||||
for _, dim := range shape {
|
||||
if dim <= 0 {
|
||||
return nil, errors.New("all dimensions must be positive")
|
||||
}
|
||||
expectedSize *= dim
|
||||
}
|
||||
|
||||
data := make([]float64, expectedSize)
|
||||
for i := range data {
|
||||
data[i] = 1.0
|
||||
}
|
||||
return NewMatrix(data, shape)
|
||||
}
|
||||
|
||||
// NewIdentity 创建一个单位矩阵
|
||||
func NewIdentity(size int) (*Matrix, error) {
|
||||
if size <= 0 {
|
||||
return nil, errors.New("size must be positive")
|
||||
}
|
||||
|
||||
data := make([]float64, size*size)
|
||||
for i := 0; i < size; i++ {
|
||||
data[i*size+i] = 1.0
|
||||
}
|
||||
|
||||
return NewMatrix(data, []int{size, size})
|
||||
}
|
||||
|
||||
// Get 获取指定位置的值
|
||||
func (m *Matrix) Get(indices ...int) (float64, error) {
|
||||
if len(indices) != m.mdim {
|
||||
return 0, errors.New("number of indices must match matrix dimensions")
|
||||
}
|
||||
|
||||
for i, idx := range indices {
|
||||
if idx < 0 || idx >= m.shape[i] {
|
||||
return 0, errors.New("index out of bounds")
|
||||
}
|
||||
}
|
||||
|
||||
pos := 0
|
||||
for i, idx := range indices {
|
||||
pos += idx * m.strides[i]
|
||||
}
|
||||
|
||||
return m.data[pos], nil
|
||||
}
|
||||
|
||||
// Set 设置指定位置的值
|
||||
func (m *Matrix) Set(value float64, indices ...int) error {
|
||||
if len(indices) != m.mdim {
|
||||
return errors.New("number of indices must match matrix dimensions")
|
||||
}
|
||||
|
||||
for i, idx := range indices {
|
||||
if idx < 0 || idx >= m.shape[i] {
|
||||
return errors.New("index out of bounds")
|
||||
}
|
||||
}
|
||||
|
||||
pos := 0
|
||||
for i, idx := range indices {
|
||||
pos += idx * m.strides[i]
|
||||
}
|
||||
|
||||
m.data[pos] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add 矩阵加法
|
||||
func (m *Matrix) Add(other *Matrix) (*Matrix, error) {
|
||||
if !m.hasSameShape(other) {
|
||||
return nil, errors.New("matrices must have the same shape for addition")
|
||||
}
|
||||
|
||||
resultData := make([]float64, m.size)
|
||||
for i := 0; i < m.size; i++ {
|
||||
resultData[i] = m.data[i] + other.data[i]
|
||||
}
|
||||
|
||||
return NewMatrix(resultData, m.shape)
|
||||
}
|
||||
|
||||
// Subtract 矩阵减法
|
||||
func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) {
|
||||
if !m.hasSameShape(other) {
|
||||
return nil, errors.New("matrices must have the same shape for subtraction")
|
||||
}
|
||||
|
||||
resultData := make([]float64, m.size)
|
||||
for i := 0; i < m.size; i++ {
|
||||
resultData[i] = m.data[i] - other.data[i]
|
||||
}
|
||||
|
||||
return NewMatrix(resultData, m.shape)
|
||||
}
|
||||
|
||||
// Multiply 矩阵乘法(逐元素相乘)
|
||||
func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) {
|
||||
if !m.hasSameShape(other) {
|
||||
return nil, errors.New("matrices must have the same shape for element-wise multiplication")
|
||||
}
|
||||
|
||||
resultData := make([]float64, m.size)
|
||||
for i := 0; i < m.size; i++ {
|
||||
resultData[i] = m.data[i] * other.data[i]
|
||||
}
|
||||
|
||||
return NewMatrix(resultData, m.shape)
|
||||
}
|
||||
|
||||
// MatMul 矩阵点乘(线性代数乘法)
|
||||
func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) {
|
||||
if m.mdim != 2 || other.mdim != 2 {
|
||||
return nil, errors.New("matmul only supported for 2D matrices")
|
||||
}
|
||||
|
||||
if m.shape[1] != other.shape[0] {
|
||||
return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]",
|
||||
m.shape[0], m.shape[1], other.shape[0], other.shape[1])
|
||||
}
|
||||
|
||||
rows, cols, inner := m.shape[0], other.shape[1], m.shape[1]
|
||||
resultData := make([]float64, rows*cols)
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < cols; j++ {
|
||||
sum := 0.0
|
||||
for k := 0; k < inner; k++ {
|
||||
mIdx := i*m.shape[1] + k
|
||||
oIdx := k*other.shape[1] + j
|
||||
sum += m.data[mIdx] * other.data[oIdx]
|
||||
}
|
||||
resultData[i*cols+j] = sum
|
||||
}
|
||||
}
|
||||
|
||||
return NewMatrix(resultData, []int{rows, cols})
|
||||
}
|
||||
|
||||
// Scale 矩阵数乘
|
||||
func (m *Matrix) Scale(factor float64) *Matrix {
|
||||
resultData := make([]float64, m.size)
|
||||
for i := 0; i < m.size; i++ {
|
||||
resultData[i] = m.data[i] * factor
|
||||
}
|
||||
|
||||
return &Matrix{
|
||||
data: resultData,
|
||||
shape: m.shape,
|
||||
strides: m.strides,
|
||||
mdim: m.mdim,
|
||||
size: m.size,
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose 矩阵转置(仅支持2维矩阵)
|
||||
func (m *Matrix) Transpose() (*Matrix, error) {
|
||||
if m.mdim != 2 {
|
||||
return nil, errors.New("transpose only supported for 2D matrices")
|
||||
}
|
||||
|
||||
rows, cols := m.shape[0], m.shape[1]
|
||||
resultData := make([]float64, m.size)
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < cols; j++ {
|
||||
resultData[j*rows+i] = m.data[i*cols+j]
|
||||
}
|
||||
}
|
||||
|
||||
return NewMatrix(resultData, []int{cols, rows})
|
||||
}
|
||||
|
||||
// Copy 复制矩阵
|
||||
func (m *Matrix) Copy() *Matrix {
|
||||
dataCopy := make([]float64, m.size)
|
||||
copy(dataCopy, m.data)
|
||||
|
||||
shapeCopy := make([]int, m.mdim)
|
||||
copy(shapeCopy, m.shape)
|
||||
|
||||
stridesCopy := make([]int, m.mdim)
|
||||
copy(stridesCopy, m.strides)
|
||||
|
||||
return &Matrix{
|
||||
data: dataCopy,
|
||||
shape: shapeCopy,
|
||||
strides: stridesCopy,
|
||||
mdim: m.mdim,
|
||||
size: m.size,
|
||||
}
|
||||
}
|
||||
|
||||
// String 实现Stringer接口,用于打印矩阵
|
||||
func (m *Matrix) String() string {
|
||||
if m.mdim == 2 {
|
||||
result := "[\n"
|
||||
rows, cols := m.shape[0], m.shape[1]
|
||||
for i := 0; i < rows; i++ {
|
||||
result += " ["
|
||||
for j := 0; j < cols; j++ {
|
||||
idx := i*cols + j
|
||||
result += fmt.Sprintf("%.2f", m.data[idx])
|
||||
if j < cols-1 {
|
||||
result += ", "
|
||||
}
|
||||
}
|
||||
result += "]\n"
|
||||
}
|
||||
result += "]"
|
||||
return result
|
||||
}
|
||||
|
||||
// 简单处理其他维度的情况
|
||||
return fmt.Sprintf("Matrix{data:%v, shape:%v}", m.data, m.shape)
|
||||
}
|
||||
|
||||
// Shape 返回矩阵的形状
|
||||
func (m *Matrix) Shape() []int {
|
||||
shape := make([]int, m.mdim)
|
||||
copy(shape, m.shape)
|
||||
return shape
|
||||
}
|
||||
|
||||
// Size 返回矩阵的大小
|
||||
func (m *Matrix) Size() int {
|
||||
return m.size
|
||||
}
|
||||
|
||||
// hasSameShape 检查两个矩阵是否具有相同的形状
|
||||
func (m *Matrix) hasSameShape(other *Matrix) bool {
|
||||
if m.mdim != other.mdim {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < m.mdim; i++ {
|
||||
if m.shape[i] != other.shape[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
Loading…
Reference in New Issue