Skip to content
Extraits de code Groupes Projets
Valider c8723ac1 rédigé par ultrakatiz's avatar ultrakatiz
Parcourir les fichiers

cleaned up matrix.lua

parent 6ee4b70e
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
require("kara3d.vector")
-- ********** DEFINITION AND GENERIC FUNCTIONS **********
matrix = {class = "matrix"}
matrix.__index = matrix
-- Creates a new matrix with r rows and c columns,
-- initialized with the table vals
function matrix.new(r, c, vals)
local m = {rows = r or 1, cols = c or 1, values = {}}
vals = vals or {}
......@@ -13,6 +17,12 @@ function matrix.new(r, c, vals)
return m
end
-- Clones this matrix
function matrix:clone()
return matrix.new(self.rows, self.cols, self.values)
end
-- Returns a string representing the matrix
function matrix:tostring()
local str = self.rows .. " x " .. self.cols .. "\n"
for i = 1, self.rows do
......@@ -23,16 +33,67 @@ function matrix:tostring()
return str
end
-- ********** STATIC FUNCTIONS **********
-- Returns a matrix with r rows and c columns,
-- with 1 on the diagonal and 0 in the other cells
function matrix.identity(r, c)
local m = matrix.new(r, c)
for i = 1, math.min(r, c) do m:set(i, i, 1) end
return m
end
function matrix:clone()
return matrix.new(self.rows, self.cols, self.values)
-- Returns the product of two matrices, or a scalar and a matrix
-- Only works if m1 or m2 is a scalar,
-- or if m1 has a number of columns equal to m2's number of rows
function matrix.mul_matrix(m1, m2)
if (type(m1) == "number") then return matrix.mul_scalar(m2, m1) end
if (type(m2) == "number") then return matrix.mul_scalar(m1, m2) end
if (m1.cols ~= m2.rows) then return m1 end
local m = matrix.new(m1.rows, m2.cols)
for i = 1, m1.rows do
for j = 1, m2.cols do
local sum = 0
for k = 1, m1.cols do
sum = sum + m1:get(i, k) * m2:get(k, j)
end
m:set(i, j, sum)
end
end
return m
end
-- Returns the product of a matrix and a vector
-- Only works if m has a number of rows equal to the size of v
function matrix.mul_vector(m, v)
if (m.cols ~= v.size) then return v end
local vals = {}
for i = 1, m.rows do
local sum = 0
for j = 1, m.cols do
sum = sum + m:get(i, j) * v:get(j)
end
vals[i] = sum
end
return vector.new(m.rows, vals)
end
-- Returns the product of a scalar and a matrix
function matrix.mul_scalar(m, s)
local r = matrix.new(m.rows, m.cols)
for i = 1, #(r.values) do
r.values[i] = s * m.values[i]
end
return r
end
-- ********** INSTANCE FUNCTIONS **********
-- Returns the value on the (i, j) cell column of this matrix
function matrix:get(i, j)
if (i <= 0 or i > self.rows) then return 0 end
if (j <= 0 or j > self.cols) then return 0 end
......@@ -40,6 +101,13 @@ function matrix:get(i, j)
return n > #self.values and 0 or self.values[n]
end
-- Changes the (i, j) cell to be value
function matrix:set(i, j, value)
self.values[(i-1) * self.cols + j] = value
return self
end
-- Returns a vector corresponding to the ith column of this matrix
function matrix:column(i)
if (i <= 0 or i > self.cols) then return vector.zero(self.rows) end
local vals = {}
......@@ -47,11 +115,11 @@ function matrix:column(i)
return vector.new(self.rows, vals)
end
function matrix:set(i, j, value)
self.values[(i-1) * self.cols + j] = value
return self
end
-- ********** OPERATOR OVERLOADS **********
-- + binary operator overload
-- Only works on matrices with the same dimensions
-- Returns a matrix m' such that for all (i, j), m'(i, j) = m1(i, j) + m2(i, j)
function matrix.__add(m1, m2)
if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return m1 end
......@@ -63,6 +131,8 @@ function matrix.__add(m1, m2)
return m
end
-- - unary operator overload
-- Returns a matrix m' such that for all (i, j), m'(i, j) = -m(i, j)
function matrix.__unm(m)
local n = matrix.new(m.rows, m.cols)
for i = 1, #(n.values) do
......@@ -71,6 +141,8 @@ function matrix.__unm(m)
return n
end
-- - binary operator overload
-- Returns a matrix m' such that for all (i, j), m'(i, j) = m1(i, j) - m2(i, j)
function matrix.__sub(m1, m2)
if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return m1 end
......@@ -82,48 +154,9 @@ function matrix.__sub(m1, m2)
return m
end
function matrix.mul_matrix(m1, m2)
if (type(m1) == "number") then return matrix.mul_scalar(m2, m1) end
if (type(m2) == "number") then return matrix.mul_scalar(m1, m2) end
if (m1.cols ~= m2.rows) then return m1 end
local m = matrix.new(m1.rows, m2.cols)
for i = 1, m1.rows do
for j = 1, m2.cols do
local sum = 0
for k = 1, m1.cols do
sum = sum + m1:get(i, k) * m2:get(k, j)
end
m:set(i, j, sum)
end
end
return m
end
function matrix.mul_vector(m, v)
if (m.cols ~= v.size) then return v end
local vals = {}
for i = 1, m.rows do
local sum = 0
for j = 1, m.cols do
sum = sum + m:get(i, j) * v:get(j)
end
vals[i] = sum
end
return vector.new(m.rows, vals)
end
function matrix.mul_scalar(m, s)
local r = matrix.new(m.rows, m.cols)
for i = 1, #(r.values) do
r.values[i] = s * m.values[i]
end
return r
end
-- * binary operator overload
-- Returns the product of a matrix or scalar (left)
-- and a matrix, vector or scalar (right)
function matrix.__mul(m, o)
if (getmetatable(o) == matrix) then return matrix.mul_matrix(m, o)
elseif (getmetatable(o) == vector) then return matrix.mul_vector(m, o)
......@@ -131,6 +164,8 @@ function matrix.__mul(m, o)
return m
end
-- == operator overload
-- Returns "for all (i, j), m1(i, j) == m2(i, j)"
function matrix.__eq(m1, m2)
if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return false end
......
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Veuillez vous inscrire ou vous pour commenter