diff --git a/matrix.go b/matrix.go index 4c28431..34466c5 100644 --- a/matrix.go +++ b/matrix.go @@ -84,6 +84,8 @@ func NewOnes(shape []int) (*Matrix, error) { } return NewMatrix(data, shape) } + +// NewVector 创建一个向量 func NewVector(data []float64) (*Matrix, error) { return NewMatrix(data, []int{len(data), 1}) } @@ -239,6 +241,20 @@ func (m *Matrix) Scale(factor float64) *Matrix { } } +func (m *Matrix) Equal(other *Matrix) bool { + if m.shape[0] != other.shape[0] || m.shape[1] != other.shape[1] { + return false + } + + for i := 0; i < m.size; i++ { + if m.data[i] != other.data[i] { + return false + } + } + + return true +} + // Transpose 矩阵转置(仅支持2维矩阵) func (m *Matrix) Transpose() (*Matrix, error) { if m.mdim != 2 {