diff --git a/matrix.go b/matrix.go index 3c82841..4c28431 100644 --- a/matrix.go +++ b/matrix.go @@ -84,6 +84,9 @@ func NewOnes(shape []int) (*Matrix, error) { } return NewMatrix(data, shape) } +func NewVector(data []float64) (*Matrix, error) { + return NewMatrix(data, []int{len(data), 1}) +} // NewIdentity 创建一个单位矩阵 // size: 单位矩阵的大小 @@ -197,7 +200,7 @@ func (m *Matrix) MatMul(other *Matrix) (*Matrix, error) { } if m.shape[1] != other.shape[0] { - return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]", + return nil, fmt.Errorf("cannot multiply matrices with shapes [%d,%d] and [%d,%d]", m.shape[0], m.shape[1], other.shape[0], other.shape[1]) } @@ -323,4 +326,4 @@ func (m *Matrix) hasSameShape(other *Matrix) bool { } return true -} \ No newline at end of file +}