require("kara3d.vector")

-- ********** DEFINITION AND GENERIC FUNCTIONS **********

-- Class table for matrices. It is deliberately global, mirroring how this
-- file reaches `vector` as a global after requiring kara3d.vector.
matrix = {class = "matrix"}
matrix.__index = matrix

-- Creates a new matrix with r rows and c columns (both default to 1),
-- initialized in row-major order from the table vals: vals[(i-1)*c + j]
-- becomes cell (i, j). Cells that vals does not supply are set to 0.
-- @tparam[opt=1] number r number of rows
-- @tparam[opt=1] number c number of columns
-- @tparam[opt] table vals row-major cell values
-- @treturn matrix
function matrix.new(r, c, vals)
  local m = {rows = r or 1, cols = c or 1, values = {}}
  vals = vals or {}
  -- Use the defaulted dimensions: the raw r*c raises when r or c is nil.
  for i = 1, m.rows * m.cols do
    -- `or 0` also covers holes in vals, where #vals is unreliable.
    m.values[i] = vals[i] or 0
  end
  return setmetatable(m, matrix)
end

-- Returns a new matrix with the same dimensions and cell values as this one.
-- matrix.new copies the values into a fresh table, so the clone and the
-- original can be modified independently.
function matrix:clone()
  return matrix.new(self.rows, self.cols, self.values)
end

-- Returns a string representing the matrix: a "rows x cols" header line,
-- then one line per row with its cells separated by tabs.
function matrix:tostring()
  -- Collect the pieces and join them once; repeated `..` in the loop would
  -- copy the growing string on every step.
  local parts = {self.rows .. " x " .. self.cols .. "\n"}
  for i = 1, self.rows do
    for j = 1, self.cols do
      parts[#parts + 1] = self:get(i, j)
      parts[#parts + 1] = j < self.cols and "\t" or "\n"
    end
  end
  return table.concat(parts)
end

-- Lets tostring(m) and print(m) produce the same representation.
matrix.__tostring = matrix.tostring

-- ********** STATIC FUNCTIONS **********

-- Returns a matrix with r rows and c columns, with 1 on the main diagonal
-- and 0 in every other cell.
-- @tparam[opt=1] number r number of rows
-- @tparam[opt=r] number c number of columns (defaults to a square matrix)
-- @treturn matrix
function matrix.identity(r, c)
  r = r or 1
  c = c or r
  local m = matrix.new(r, c)
  for i = 1, math.min(r, c) do
    m:set(i, i, 1)
  end
  return m
end

-- Returns the product of two matrices, or of a scalar and a matrix
-- (the scalar may be on either side).
-- Two matrices can only be multiplied when m1 has as many columns as m2 has
-- rows; otherwise m1 itself (not a copy) is returned unchanged.
function matrix.mul_matrix(m1, m2)
  if type(m1) == "number" then return matrix.mul_scalar(m2, m1) end
  if type(m2) == "number" then return matrix.mul_scalar(m1, m2) end
  if m1.cols ~= m2.rows then return m1 end

  local m = matrix.new(m1.rows, m2.cols)
  for i = 1, m1.rows do
    for j = 1, m2.cols do
      local sum = 0
      for k = 1, m1.cols do
        sum = sum + m1:get(i, k) * m2:get(k, j)
      end
      m:set(i, j, sum)
    end
  end
  return m
end

-- Returns the product m * v of a matrix and a vector, as a new vector whose
-- ith component is the sum over j of m(i, j) * v(j).
-- Only works if m has a number of columns equal to the size of v; otherwise
-- v itself is returned unchanged.
function matrix.mul_vector(m, v)
  if m.cols ~= v.size then return v end
  local vals = {}
  for i = 1, m.rows do
    local sum = 0
    for j = 1, m.cols do
      sum = sum + m:get(i, j) * v:get(j)
    end
    vals[i] = sum
  end
  return vector.new(m.rows, vals)
end
-- Returns the product of a scalar s and a matrix m, as a new matrix.
function matrix.mul_scalar(m, s)
  local r = matrix.new(m.rows, m.cols)
  for i = 1, #(r.values) do
    r.values[i] = s * m.values[i]
  end
  return r
end

-- ********** INSTANCE FUNCTIONS **********

-- Returns the value of cell (i, j) of this matrix (row i, column j),
-- or 0 when (i, j) lies outside the matrix.
function matrix:get(i, j)
  if i <= 0 or i > self.rows then return 0 end
  if j <= 0 or j > self.cols then return 0 end
  -- `or 0` covers cells missing from values, including holes, for which
  -- a `#self.values` comparison is unreliable.
  return self.values[(i - 1) * self.cols + j] or 0
end

-- Changes cell (i, j) to be value and returns this matrix, so calls can be
-- chained. Writes outside the matrix are ignored, mirroring get, which reads
-- 0 there. Without this check, an out-of-range column would overwrite a cell
-- of the next row, and an out-of-range row would grow values past rows*cols.
function matrix:set(i, j, value)
  if i <= 0 or i > self.rows or j <= 0 or j > self.cols then
    return self
  end
  self.values[(i - 1) * self.cols + j] = value
  return self
end

-- Returns a vector corresponding to the ith column of this matrix,
-- or a zero vector with self.rows components when i lies outside the matrix.
function matrix:column(i)
  if i <= 0 or i > self.cols then return vector.zero(self.rows) end
  local vals = {}
  for j = 1, self.rows do
    vals[j] = self:get(j, i)
  end
  return vector.new(self.rows, vals)
end

-- ********** OPERATOR OVERLOADS **********

-- + binary operator overload
-- Only works on matrices with the same dimensions; otherwise m1 itself is
-- returned unchanged.
-- Returns a new matrix m' such that for all (i, j), m'(i, j) = m1(i, j) + m2(i, j)
function matrix.__add(m1, m2)
  if m1.rows ~= m2.rows or m1.cols ~= m2.cols then return m1 end
  local m = matrix.new(m1.rows, m1.cols)
  for i = 1, #(m.values) do
    m.values[i] = m1.values[i] + m2.values[i]
  end
  return m
end

-- - unary operator overload
-- Returns a new matrix m' such that for all (i, j), m'(i, j) = -m(i, j)
function matrix.__unm(m)
  local n = matrix.new(m.rows, m.cols)
  for i = 1, #(n.values) do
    n.values[i] = -m.values[i]
  end
  return n
end

-- - binary operator overload
-- Only works on matrices with the same dimensions; otherwise m1 itself is
-- returned unchanged.
-- Returns a new matrix m' such that for all (i, j), m'(i, j) = m1(i, j) - m2(i, j)
function matrix.__sub(m1, m2)
  if m1.rows ~= m2.rows or m1.cols ~= m2.cols then return m1 end
  local m = matrix.new(m1.rows, m1.cols)
  for i = 1, #(m.values) do
    m.values[i] = m1.values[i] - m2.values[i]
  end
  return m
end

-- * binary operator overload
-- Returns the product of a matrix or scalar (left) and a matrix, vector or
-- scalar (right). Lua also invokes this for `s * m` with a number s on the
-- left; mul_matrix handles that case. Any other right operand returns m.
function matrix.__mul(m, o)
  local mt = getmetatable(o)
  if mt == matrix then
    return matrix.mul_matrix(m, o)
  elseif mt == vector then
    return matrix.mul_vector(m, o)
  elseif type(o) == "number" then
    return matrix.mul_scalar(m, o)
  end
  return m
end

-- == operator overload
-- Returns true when both matrices have the same dimensions and
-- for all (i, j), m1(i, j) == m2(i, j); false otherwise.
function matrix.__eq(m1, m2)
  if m1.rows ~= m2.rows or m1.cols ~= m2.cols then return false end
  for i = 1, #(m1.values) do
    if m1.values[i] ~= m2.values[i] then return false end
  end
  return true
end