diff --git a/matrix.lua b/matrix.lua index e9945c07fba9b0157e457581c2e3a2a2b10c4da1..e2f0577092cdcb05bec675867dccbd6359198ebd 100644 --- a/matrix.lua +++ b/matrix.lua @@ -1,8 +1,12 @@ require("kara3d.vector") +-- ********** DEFINITION AND GENERIC FUNCTIONS ********** + matrix = {class = "matrix"} matrix.__index = matrix +-- Creates a new matrix with r rows and c columns, +-- initialized with the table vals function matrix.new(r, c, vals) local m = {rows = r or 1, cols = c or 1, values = {}} vals = vals or {} @@ -13,6 +17,12 @@ function matrix.new(r, c, vals) return m end +-- Clones this matrix +function matrix:clone() + return matrix.new(self.rows, self.cols, self.values) +end + +-- Returns a string representing the matrix function matrix:tostring() local str = self.rows .. " x " .. self.cols .. "\n" for i = 1, self.rows do @@ -23,16 +33,67 @@ function matrix:tostring() return str end +-- ********** STATIC FUNCTIONS ********** + +-- Returns a matrix with r rows and c columns, +-- with 1 on the diagonal and 0 in the other cells function matrix.identity(r, c) local m = matrix.new(r, c) for i = 1, math.min(r, c) do m:set(i, i, 1) end return m end -function matrix:clone() - return matrix.new(self.rows, self.cols, self.values) +-- Returns the product of two matrices, or a scalar and a matrix +-- Only works if m1 or m2 is a scalar, +-- or if m1 has a number of columns equal to m2's number of rows +function matrix.mul_matrix(m1, m2) + if (type(m1) == "number") then return matrix.mul_scalar(m2, m1) end + if (type(m2) == "number") then return matrix.mul_scalar(m1, m2) end + if (m1.cols ~= m2.rows) then return m1 end + + local m = matrix.new(m1.rows, m2.cols) + for i = 1, m1.rows do + for j = 1, m2.cols do + local sum = 0 + for k = 1, m1.cols do + sum = sum + m1:get(i, k) * m2:get(k, j) + end + m:set(i, j, sum) + end + end + + return m end +-- Returns the product of a matrix and a vector +-- Only works if m has a number of rows equal to the size of v +function matrix.mul_vector(m, v) + if (m.cols ~= v.size) then return v end + + local vals = {} + for i = 1, m.rows do + local sum = 0 + for j = 1, m.cols do + sum = sum + m:get(i, j) * v:get(j) + end + vals[i] = sum + end + + return vector.new(m.rows, vals) +end + +-- Returns the product of a scalar and a matrix +function matrix.mul_scalar(m, s) + local r = matrix.new(m.rows, m.cols) + for i = 1, #(r.values) do + r.values[i] = s * m.values[i] + end + return r +end + +-- ********** INSTANCE FUNCTIONS ********** + +-- Returns the value on the (i, j) cell column of this matrix function matrix:get(i, j) if (i <= 0 or i > self.rows) then return 0 end if (j <= 0 or j > self.cols) then return 0 end @@ -40,6 +101,13 @@ function matrix:get(i, j) return n > #self.values and 0 or self.values[n] end +-- Changes the (i, j) cell to be value +function matrix:set(i, j, value) + self.values[(i-1) * self.cols + j] = value + return self +end + +-- Returns a vector corresponding to the ith column of this matrix function matrix:column(i) if (i <= 0 or i > self.cols) then return vector.zero(self.rows) end local vals = {} @@ -47,11 +115,11 @@ function matrix:column(i) return vector.new(self.rows, vals) end -function matrix:set(i, j, value) - self.values[(i-1) * self.cols + j] = value - return self -end +-- ********** OPERATOR OVERLOADS ********** +-- + binary operator overload +-- Only works on matrices with the same dimensions +-- Returns a matrix m' such that for all (i, j), m'(i, j) = m1(i, j) + m2(i, j) function matrix.__add(m1, m2) if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return m1 end @@ -63,6 +131,8 @@ function matrix.__add(m1, m2) return m end +-- - unary operator overload +-- Returns a matrix m' such that for all (i, j), m'(i, j) = -m(i, j) function matrix.__unm(m) local n = matrix.new(m.rows, m.cols) for i = 1, #(n.values) do @@ -71,6 +141,8 @@ function matrix.__unm(m) return n end +-- - binary operator overload +-- Returns a matrix m' such that for all (i, j), m'(i, j) = m1(i, j) - m2(i, j) function matrix.__sub(m1, m2) if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return m1 end @@ -82,48 +154,9 @@ function matrix.__sub(m1, m2) return m end -function matrix.mul_matrix(m1, m2) - if (type(m1) == "number") then return matrix.mul_scalar(m2, m1) end - if (type(m2) == "number") then return matrix.mul_scalar(m1, m2) end - if (m1.cols ~= m2.rows) then return m1 end - - local m = matrix.new(m1.rows, m2.cols) - for i = 1, m1.rows do - for j = 1, m2.cols do - local sum = 0 - for k = 1, m1.cols do - sum = sum + m1:get(i, k) * m2:get(k, j) - end - m:set(i, j, sum) - end - end - - return m -end - -function matrix.mul_vector(m, v) - if (m.cols ~= v.size) then return v end - - local vals = {} - for i = 1, m.rows do - local sum = 0 - for j = 1, m.cols do - sum = sum + m:get(i, j) * v:get(j) - end - vals[i] = sum - end - - return vector.new(m.rows, vals) -end - -function matrix.mul_scalar(m, s) - local r = matrix.new(m.rows, m.cols) - for i = 1, #(r.values) do - r.values[i] = s * m.values[i] - end - return r -end - +-- * binary operator overload +-- Returns the product of a matrix or scalar (left) +-- and a matrix, vector or scalar (right) function matrix.__mul(m, o) if (getmetatable(o) == matrix) then return matrix.mul_matrix(m, o) elseif (getmetatable(o) == vector) then return matrix.mul_vector(m, o) @@ -131,6 +164,8 @@ function matrix.__mul(m, o) return m end +-- == operator overload +-- Returns "for all (i, j), m1(i, j) == m2(i, j)" function matrix.__eq(m1, m2) if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return false end