diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..35b8419 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 kingecg + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/exmaples/example.go b/exmaples/example.go new file mode 100644 index 0000000..229fad9 --- /dev/null +++ b/exmaples/example.go @@ -0,0 +1,92 @@ +package main + +import ( + "fmt" + + "git.kingecg.top/kingecg/gomatrix" +) + +// 示例程序演示矩阵功能 +func Example() { + fmt.Println("创建矩阵示例:") + + // 创建一个2x3的矩阵 + data1 := []float64{1, 2, 3, 4, 5, 6} + mat1, err := gomatrix.NewMatrix(data1, []int{2, 3}) + if err != nil { + fmt.Printf("创建矩阵失败: %v\n", err) + return + } + fmt.Printf("矩阵1:\n%s\n", mat1.String()) + + // 创建另一个2x3的矩阵 + data2 := []float64{7, 8, 9, 10, 11, 12} + mat2, err := gomatrix.NewMatrix(data2, []int{2, 3}) + if err != nil { + fmt.Printf("创建矩阵失败: %v\n", err) + return + } + fmt.Printf("矩阵2:\n%s\n", mat2.String()) + + // 矩阵加法 + sum, err := mat1.Add(mat2) + if err != nil { + fmt.Printf("矩阵加法失败: %v\n", err) + return + } + fmt.Printf("矩阵1 + 矩阵2:\n%s\n", sum.String()) + + // 创建一个3x2的矩阵用于矩阵乘法 + data3 := []float64{1, 2, 3, 4, 5, 6} + mat3, err := gomatrix.NewMatrix(data3, []int{3, 2}) + if err != nil { + fmt.Printf("创建矩阵失败: %v\n", err) + return + } + fmt.Printf("矩阵3 (3x2):\n%s\n", mat3.String()) + + // 创建一个2x2的矩阵用于矩阵乘法 + data4 := []float64{1, 2, 3, 4} + mat4, err := gomatrix.NewMatrix(data4, []int{2, 2}) + if err != nil { + fmt.Printf("创建矩阵失败: %v\n", err) + return + } + fmt.Printf("矩阵4 (2x2):\n%s\n", mat4.String()) + + // 矩阵乘法 + product, err := mat3.MatMul(mat4) + if err != nil { + fmt.Printf("矩阵乘法失败: %v\n", err) + return + } + fmt.Printf("矩阵3 × 矩阵4:\n%s\n", product.String()) + + // 矩阵转置 + transposed, err := mat1.Transpose() + if err != nil { + fmt.Printf("矩阵转置失败: %v\n", err) + return + } + fmt.Printf("矩阵1的转置:\n%s\n", transposed.String()) + + // 矩阵数乘 + scaled := mat1.Scale(2.0) + fmt.Printf("矩阵1 × 2:\n%s\n", scaled.String()) + + // 创建零矩阵 + zeros, err := gomatrix.NewZeros([]int{2, 2}) + if err != nil { + fmt.Printf("创建零矩阵失败: %v\n", err) + return + } + fmt.Printf("2x2零矩阵:\n%s\n", zeros.String()) + + // 创建单位矩阵 + identity, err := gomatrix.NewIdentity(3) + if err != nil { + fmt.Printf("创建单位矩阵失败: %v\n", err) + return + } + fmt.Printf("3x3单位矩阵:\n%s\n", identity.String()) +} diff --git a/exmaples/main.go b/exmaples/main.go new file mode 100644 index 0000000..7dbb448 --- /dev/null +++ b/exmaples/main.go @@ -0,0 +1,5 @@ +package main + +func main() { + Example() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ec89284 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.kingecg.top/kingecg/gomatrix + +go 1.25.1 diff --git a/matrix.go b/matrix.go new file mode 100644 index 0000000..207c32f --- /dev/null +++ b/matrix.go @@ -0,0 +1,312 @@ +package gomatrix + +import ( + "errors" + "fmt" +) + +type Matrix struct { + data []float64 + shape []int + strides []int + mdim int + size int +} + +// NewMatrix 创建一个新的矩阵 +func NewMatrix(data []float64, shape []int) (*Matrix, error) { + if len(data) == 0 || len(shape) == 0 { + return nil, errors.New("data and shape cannot be empty") + } + + // 计算期望的大小 + expectedSize := 1 + for _, dim := range shape { + if dim <= 0 { + return nil, errors.New("all dimensions must be positive") + } + expectedSize *= dim + } + + if len(data) != expectedSize { + return nil, fmt.Errorf("data length %d does not match expected size %d for given shape", len(data), expectedSize) + } + + // 计算步长 + strides := make([]int, len(shape)) + stride := 1 + for i := len(shape) - 1; i >= 0; i-- { + strides[i] = stride + stride *= shape[i] + } + + return &Matrix{ + data: data, + shape: shape, + strides: strides, + mdim: len(shape), + size: expectedSize, + }, nil +} + +// NewZeros 创建一个全零矩阵 +func NewZeros(shape []int) (*Matrix, error) { + expectedSize := 1 + for _, dim := range shape { + if dim <= 0 { + return nil, errors.New("all dimensions must be positive") + } + expectedSize *= dim + } + + data := make([]float64, expectedSize) + return NewMatrix(data, shape) +} + +// NewOnes 创建一个全一矩阵 +func NewOnes(shape []int) (*Matrix, error) { + expectedSize := 1 + for _, dim := range shape { + if dim <= 0 { + return nil, errors.New("all dimensions must be positive") + } + expectedSize *= dim + } + + data := make([]float64, expectedSize) + for i := range data { + data[i] = 1.0 + } + return NewMatrix(data, shape) +} + +// NewIdentity 创建一个单位矩阵 +func NewIdentity(size int) (*Matrix, error) { + if size <= 0 { + return nil, errors.New("size must be positive") + } + + data := make([]float64, size*size) + for i := 0; i < size; i++ { + data[i*size+i] = 1.0 + } + + return NewMatrix(data, []int{size, size}) +} + +// Get 获取指定位置的值 +func (m *Matrix) Get(indices ...int) (float64, error) { + if len(indices) != m.mdim { + return 0, errors.New("number of indices must match matrix dimensions") + } + + for i, idx := range indices { + if idx < 0 || idx >= m.shape[i] { + return 0, errors.New("index out of bounds") + } + } + + pos := 0 + for i, idx := range indices { + pos += idx * m.strides[i] + } + + return m.data[pos], nil +} + +// Set 设置指定位置的值 +func (m *Matrix) Set(value float64, indices ...int) error { + if len(indices) != m.mdim { + return errors.New("number of indices must match matrix dimensions") + } + + for i, idx := range indices { + if idx < 0 || idx >= m.shape[i] { + return errors.New("index out of bounds") + } + } + + pos := 0 + for i, idx := range indices { + pos += idx * m.strides[i] + } + + m.data[pos] = value + return nil +} + +// Add 矩阵加法 +func (m *Matrix) Add(other *Matrix) (*Matrix, error) { + if !m.hasSameShape(other) { + return nil, errors.New("matrices must have the same shape for addition") + } + + resultData := make([]float64, m.size) + for i := 0; i < m.size; i++ { + resultData[i] = m.data[i] + other.data[i] + } + + return NewMatrix(resultData, m.shape) +} + +// Subtract 矩阵减法 +func (m *Matrix) Subtract(other *Matrix) (*Matrix, error) { + if !m.hasSameShape(other) { + return nil, errors.New("matrices must have the same shape for subtraction") + } + + resultData := make([]float64, m.size) + for i := 0; i < m.size; i++ { + resultData[i] = m.data[i] - other.data[i] + } + + return NewMatrix(resultData, m.shape) +} + +// Multiply 矩阵乘法(逐元素相乘) +func (m *Matrix) Multiply(other *Matrix) (*Matrix, error) { + if !m.hasSameShape(other) { + return nil, errors.New("matrices must have the same shape for element-wise multiplication") + } + + resultData := make([]float64, m.size) + for i := 0; i < m.size; i++ { + resultData[i] = m.data[i] * other.data[i] + } + + return NewMatrix(resultData, m.shape) +} + +// MatMul 矩阵点乘(线性代数乘法) +func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) { + if m.mdim != 2 || other.mdim != 2 { + return nil, errors.New("matmul only supported for 2D matrices") + } + + if m.shape[1] != other.shape[0] { + return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]", + m.shape[0], m.shape[1], other.shape[0], other.shape[1]) + } + + rows, cols, inner := m.shape[0], other.shape[1], m.shape[1] + resultData := make([]float64, rows*cols) + + for i := 0; i < rows; i++ { + for j := 0; j < cols; j++ { + sum := 0.0 + for k := 0; k < inner; k++ { + mIdx := i*m.shape[1] + k + oIdx := k*other.shape[1] + j + sum += m.data[mIdx] * other.data[oIdx] + } + resultData[i*cols+j] = sum + } + } + + return NewMatrix(resultData, []int{rows, cols}) +} + +// Scale 矩阵数乘 +func (m *Matrix) Scale(factor float64) *Matrix { + resultData := make([]float64, m.size) + for i := 0; i < m.size; i++ { + resultData[i] = m.data[i] * factor + } + + return &Matrix{ + data: resultData, + shape: m.shape, + strides: m.strides, + mdim: m.mdim, + size: m.size, + } +} + +// Transpose 矩阵转置(仅支持2维矩阵) +func (m *Matrix) Transpose() (*Matrix, error) { + if m.mdim != 2 { + return nil, errors.New("transpose only supported for 2D matrices") + } + + rows, cols := m.shape[0], m.shape[1] + resultData := make([]float64, m.size) + + for i := 0; i < rows; i++ { + for j := 0; j < cols; j++ { + resultData[j*rows+i] = m.data[i*cols+j] + } + } + + return NewMatrix(resultData, []int{cols, rows}) +} + +// Copy 复制矩阵 +func (m *Matrix) Copy() *Matrix { + dataCopy := make([]float64, m.size) + copy(dataCopy, m.data) + + shapeCopy := make([]int, m.mdim) + copy(shapeCopy, m.shape) + + stridesCopy := make([]int, m.mdim) + copy(stridesCopy, m.strides) + + return &Matrix{ + data: dataCopy, + shape: shapeCopy, + strides: stridesCopy, + mdim: m.mdim, + size: m.size, + } +} + +// String 实现Stringer接口,用于打印矩阵 +func (m *Matrix) String() string { + if m.mdim == 2 { + result := "[\n" + rows, cols := m.shape[0], m.shape[1] + for i := 0; i < rows; i++ { + result += " [" + for j := 0; j < cols; j++ { + idx := i*cols + j + result += fmt.Sprintf("%.2f", m.data[idx]) + if j < cols-1 { + result += ", " + } + } + result += "]\n" + } + result += "]" + return result + } + + // 简单处理其他维度的情况 + return fmt.Sprintf("Matrix{data:%v, shape:%v}", m.data, m.shape) +} + +// Shape 返回矩阵的形状 +func (m *Matrix) Shape() []int { + shape := make([]int, m.mdim) + copy(shape, m.shape) + return shape +} + +// Size 返回矩阵的大小 +func (m *Matrix) Size() int { + return m.size +} + +// hasSameShape 检查两个矩阵是否具有相同的形状 +func (m *Matrix) hasSameShape(other *Matrix) bool { + if m.mdim != other.mdim { + return false + } + + for i := 0; i < m.mdim; i++ { + if m.shape[i] != other.shape[i] { + return false + } + } + + return true +} \ No newline at end of file