-- Select a Git revision (residue copied from the Git web UI: "Sélectionner une révision Git")
-- matrix.lua, 4.63 KiB
require("kara3d.vector")
-- ********** DEFINITION AND GENERIC FUNCTIONS **********
-- Prototype and metatable shared by every matrix instance.
-- Instances resolve methods (and the `class` tag) through __index.
matrix = {}
matrix.class = "matrix"
matrix.__index = matrix
-- Creates a new matrix with r rows and c columns (both default to 1),
-- filled in row-major order from the table vals.
-- Cells with no usable entry in vals (missing, nil or false) are set to 0;
-- entries of vals beyond r*c are ignored.
-- @param r number of rows (default 1)
-- @param c number of columns (default 1)
-- @param vals optional array of cell values, row-major
-- @return the new matrix
function matrix.new(r, c, vals)
  local rows, cols = r or 1, c or 1
  local m = {rows = rows, cols = cols, values = {}}
  vals = vals or {}
  -- Loop over the defaulted sizes: using r*c directly raised an
  -- arithmetic-on-nil error whenever r or c was omitted.
  -- Read vals[i] directly instead of testing i <= #vals: # is unspecified
  -- on tables with holes and was recomputed on every iteration.
  local values = m.values
  for i = 1, rows * cols do
    values[i] = vals[i] or 0
  end
  return setmetatable(m, matrix)
end
-- Returns an independent copy of this matrix: same dimensions, with the
-- cell values copied into a fresh values table by matrix.new.
function matrix:clone()
  local rows, cols, cells = self.rows, self.cols, self.values
  return matrix.new(rows, cols, cells)
end
-- Returns a string representation of this matrix: a "rows x cols" header
-- line, then one line per row with cells separated by tabs.
-- Pieces are collected in a buffer and joined once with table.concat,
-- avoiding the quadratic cost of repeated `..` inside the nested loop.
function matrix:tostring()
  local parts = {self.rows .. " x " .. self.cols .. "\n"}
  local n = 1
  local cols = self.cols
  for i = 1, self.rows do
    for j = 1, cols do
      n = n + 1
      parts[n] = self:get(i, j)
      n = n + 1
      parts[n] = j < cols and "\t" or "\n"
    end
  end
  -- Explicit bounds so a nil cell raises (as `..` did) instead of
  -- silently truncating the output.
  return table.concat(parts, "", 1, n)
end
-- ********** STATIC FUNCTIONS **********
-- Returns an r x c matrix with 1 on the main diagonal and 0 elsewhere.
-- c defaults to r, so matrix.identity(n) builds the n x n identity
-- (previously that call raised inside matrix.new); r defaults to 1.
-- @param r number of rows (default 1)
-- @param c number of columns (default r)
-- @return the identity-like matrix
function matrix.identity(r, c)
  r = r or 1
  c = c or r
  local m = matrix.new(r, c)
  for i = 1, math.min(r, c) do
    m:set(i, i, 1)
  end
  return m
end
-- Multiplies two matrices, or a scalar and a matrix (in either order).
-- For two matrices, m1.cols must equal m2.rows; otherwise m1 itself is
-- returned unchanged (no error is raised).
-- The matrix product has m1.rows rows and m2.cols columns.
function matrix.mul_matrix(m1, m2)
  if type(m1) == "number" then
    return matrix.mul_scalar(m2, m1)
  end
  if type(m2) == "number" then
    return matrix.mul_scalar(m1, m2)
  end
  local inner = m1.cols
  if inner ~= m2.rows then
    return m1
  end
  local product = matrix.new(m1.rows, m2.cols)
  for row = 1, m1.rows do
    for col = 1, m2.cols do
      local acc = 0
      for k = 1, inner do
        acc = acc + m1:get(row, k) * m2:get(k, col)
      end
      product:set(row, col, acc)
    end
  end
  return product
end
-- Returns the product m * v as a new vector of size m.rows.
-- Only works if m has a number of columns equal to the size of v
-- (v.size); otherwise v itself is returned unchanged.
-- v is read through its `size` field and `get(j)` method.
function matrix.mul_vector(m, v)
  if m.cols ~= v.size then return v end
  local vals = {}
  for i = 1, m.rows do
    local sum = 0
    for j = 1, m.cols do
      sum = sum + m:get(i, j) * v:get(j)
    end
    vals[i] = sum
  end
  return vector.new(m.rows, vals)
end
-- Returns a new matrix whose every cell is s times the matching cell of m.
-- m itself is not modified.
function matrix.mul_scalar(m, s)
  local scaled = matrix.new(m.rows, m.cols)
  local dst, src = scaled.values, m.values
  local count = #dst
  local k = 1
  while k <= count do
    dst[k] = s * src[k]
    k = k + 1
  end
  return scaled
end
-- ********** INSTANCE FUNCTIONS **********
-- Returns the value of cell (i, j) of this matrix (row i, column j).
-- Out-of-range indices, and cells with no stored value, read as 0.
function matrix:get(i, j)
  if i <= 0 or i > self.rows then return 0 end
  if j <= 0 or j > self.cols then return 0 end
  -- Cells are stored row-major in a flat array.
  local v = self.values[(i - 1) * self.cols + j]
  -- Test for nil directly instead of comparing against #self.values:
  -- # is unspecified on tables with holes (a hole below the border used
  -- to return nil) and cost a length computation on every call.
  if v == nil then return 0 end
  return v
end
-- Sets cell (i, j) of this matrix to value and returns self (chainable).
-- Indices outside 1..rows / 1..cols are ignored, mirroring matrix:get.
-- Without this check, j > cols overwrote a cell of the next row, and
-- other out-of-range writes landed outside the values array (creating
-- holes that break the #values loops used by the operators).
function matrix:set(i, j, value)
  if i <= 0 or i > self.rows or j <= 0 or j > self.cols then
    return self
  end
  self.values[(i - 1) * self.cols + j] = value
  return self
end
-- Extracts column i of this matrix as a vector of length self.rows.
-- An out-of-range column index yields vector.zero(self.rows).
function matrix:column(i)
  local height = self.rows
  if i <= 0 or i > self.cols then
    return vector.zero(height)
  end
  local entries = {}
  for row = 1, height do
    entries[row] = self:get(row, i)
  end
  return vector.new(height, entries)
end
-- ********** OPERATOR OVERLOADS **********
-- + operator overload for two matrices of identical dimensions.
-- Returns a new matrix m' with m'(i, j) = m1(i, j) + m2(i, j).
-- If the dimensions differ, m1 itself is returned unchanged.
function matrix.__add(m1, m2)
  local same_shape = m1.rows == m2.rows and m1.cols == m2.cols
  if not same_shape then return m1 end
  local total = m1:clone()
  local acc, rhs = total.values, m2.values
  for k = 1, #acc do
    acc[k] = acc[k] + rhs[k]
  end
  return total
end
-- Unary - operator overload.
-- Returns a new matrix m' with m'(i, j) = -m(i, j); m is not modified.
function matrix.__unm(m)
  local negated = matrix.new(m.rows, m.cols)
  local out, src = negated.values, m.values
  for k = 1, #out do
    out[k] = -src[k]
  end
  return negated
end
-- Binary - operator overload for two matrices of identical dimensions.
-- Returns a new matrix m' with m'(i, j) = m1(i, j) - m2(i, j).
-- If the dimensions differ, m1 itself is returned unchanged.
function matrix.__sub(m1, m2)
  local same_shape = m1.rows == m2.rows and m1.cols == m2.cols
  if not same_shape then return m1 end
  local diff = matrix.new(m1.rows, m1.cols)
  local out, lhs, rhs = diff.values, m1.values, m2.values
  for k = 1, #out do
    out[k] = lhs[k] - rhs[k]
  end
  return diff
end
-- * operator overload.
-- Dispatches on the right operand o:
--   matrix -> matrix.mul_matrix(m, o)
--   vector -> matrix.mul_vector(m, o)
--   number -> matrix.mul_scalar(m, o)
-- The left operand m may be a matrix or a number (number * matrix reaches
-- mul_matrix, which handles the scalar). For any other o, m is returned.
function matrix.__mul(m, o)
  local mt = getmetatable(o)
  if mt == matrix then
    return matrix.mul_matrix(m, o)
  end
  if mt == vector then
    return matrix.mul_vector(m, o)
  end
  if type(o) == "number" then
    return matrix.mul_scalar(m, o)
  end
  return m
end
-- == operator overload.
-- Two matrices are equal when their dimensions match and each stored cell
-- of m1 compares equal (with ==) to the corresponding cell of m2.
function matrix.__eq(m1, m2)
  if m1.rows ~= m2.rows or m1.cols ~= m2.cols then
    return false
  end
  local a, b = m1.values, m2.values
  local k, count = 1, #a
  while k <= count do
    if a[k] ~= b[k] then return false end
    k = k + 1
  end
  return true
end