Skip to content
Extraits de code Groupes Projets
Valider c8723ac1 rédigé par ultrakatiz's avatar ultrakatiz
Parcourir les fichiers

cleaned up matrix.lua

parent 6ee4b70e
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
require("kara3d.vector") require("kara3d.vector")
-- ********** DEFINITION AND GENERIC FUNCTIONS **********
matrix = {class = "matrix"} matrix = {class = "matrix"}
matrix.__index = matrix matrix.__index = matrix
-- Creates a new matrix with r rows and c columns,
-- initialized with the table vals
function matrix.new(r, c, vals) function matrix.new(r, c, vals)
local m = {rows = r or 1, cols = c or 1, values = {}} local m = {rows = r or 1, cols = c or 1, values = {}}
vals = vals or {} vals = vals or {}
...@@ -13,6 +17,12 @@ function matrix.new(r, c, vals) ...@@ -13,6 +17,12 @@ function matrix.new(r, c, vals)
return m return m
end end
-- Clones this matrix
function matrix:clone()
return matrix.new(self.rows, self.cols, self.values)
end
-- Returns a string representing the matrix
function matrix:tostring() function matrix:tostring()
local str = self.rows .. " x " .. self.cols .. "\n" local str = self.rows .. " x " .. self.cols .. "\n"
for i = 1, self.rows do for i = 1, self.rows do
...@@ -23,16 +33,67 @@ function matrix:tostring() ...@@ -23,16 +33,67 @@ function matrix:tostring()
return str return str
end end
-- ********** STATIC FUNCTIONS **********
-- Returns a matrix with r rows and c columns,
-- with 1 on the diagonal and 0 in the other cells
function matrix.identity(r, c) function matrix.identity(r, c)
local m = matrix.new(r, c) local m = matrix.new(r, c)
for i = 1, math.min(r, c) do m:set(i, i, 1) end for i = 1, math.min(r, c) do m:set(i, i, 1) end
return m return m
end end
function matrix:clone() -- Returns the product of two matrices, or a scalar and a matrix
return matrix.new(self.rows, self.cols, self.values) -- Only works if m1 or m2 is a scalar,
-- or if m1 has a number of columns equal to m2's number of rows
function matrix.mul_matrix(m1, m2)
if (type(m1) == "number") then return matrix.mul_scalar(m2, m1) end
if (type(m2) == "number") then return matrix.mul_scalar(m1, m2) end
if (m1.cols ~= m2.rows) then return m1 end
local m = matrix.new(m1.rows, m2.cols)
for i = 1, m1.rows do
for j = 1, m2.cols do
local sum = 0
for k = 1, m1.cols do
sum = sum + m1:get(i, k) * m2:get(k, j)
end
m:set(i, j, sum)
end
end
return m
end end
-- Returns the product of a matrix and a vector
-- Only works if m has a number of rows equal to the size of v
function matrix.mul_vector(m, v)
if (m.cols ~= v.size) then return v end
local vals = {}
for i = 1, m.rows do
local sum = 0
for j = 1, m.cols do
sum = sum + m:get(i, j) * v:get(j)
end
vals[i] = sum
end
return vector.new(m.rows, vals)
end
-- Returns the product of a scalar and a matrix
function matrix.mul_scalar(m, s)
local r = matrix.new(m.rows, m.cols)
for i = 1, #(r.values) do
r.values[i] = s * m.values[i]
end
return r
end
-- ********** INSTANCE FUNCTIONS **********
-- Returns the value on the (i, j) cell column of this matrix
function matrix:get(i, j) function matrix:get(i, j)
if (i <= 0 or i > self.rows) then return 0 end if (i <= 0 or i > self.rows) then return 0 end
if (j <= 0 or j > self.cols) then return 0 end if (j <= 0 or j > self.cols) then return 0 end
...@@ -40,6 +101,13 @@ function matrix:get(i, j) ...@@ -40,6 +101,13 @@ function matrix:get(i, j)
return n > #self.values and 0 or self.values[n] return n > #self.values and 0 or self.values[n]
end end
-- Changes the (i, j) cell to be value
function matrix:set(i, j, value)
self.values[(i-1) * self.cols + j] = value
return self
end
-- Returns a vector corresponding to the ith column of this matrix
function matrix:column(i) function matrix:column(i)
if (i <= 0 or i > self.cols) then return vector.zero(self.rows) end if (i <= 0 or i > self.cols) then return vector.zero(self.rows) end
local vals = {} local vals = {}
...@@ -47,11 +115,11 @@ function matrix:column(i) ...@@ -47,11 +115,11 @@ function matrix:column(i)
return vector.new(self.rows, vals) return vector.new(self.rows, vals)
end end
function matrix:set(i, j, value) -- ********** OPERATOR OVERLOADS **********
self.values[(i-1) * self.cols + j] = value
return self
end
-- + binary operator overload
-- Only works on matrices with the same dimensions
-- Returns a matrix m' such that for all (i, j), m'(i, j) = m1(i, j) + m2(i, j)
function matrix.__add(m1, m2) function matrix.__add(m1, m2)
if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return m1 end if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return m1 end
...@@ -63,6 +131,8 @@ function matrix.__add(m1, m2) ...@@ -63,6 +131,8 @@ function matrix.__add(m1, m2)
return m return m
end end
-- - unary operator overload
-- Returns a matrix m' such that for all (i, j), m'(i, j) = -m(i, j)
function matrix.__unm(m) function matrix.__unm(m)
local n = matrix.new(m.rows, m.cols) local n = matrix.new(m.rows, m.cols)
for i = 1, #(n.values) do for i = 1, #(n.values) do
...@@ -71,6 +141,8 @@ function matrix.__unm(m) ...@@ -71,6 +141,8 @@ function matrix.__unm(m)
return n return n
end end
-- - binary operator overload
-- Returns a matrix m' such that for all (i, j), m'(i, j) = m1(i, j) - m2(i, j)
function matrix.__sub(m1, m2) function matrix.__sub(m1, m2)
if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return m1 end if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return m1 end
...@@ -82,48 +154,9 @@ function matrix.__sub(m1, m2) ...@@ -82,48 +154,9 @@ function matrix.__sub(m1, m2)
return m return m
end end
function matrix.mul_matrix(m1, m2) -- * binary operator overload
if (type(m1) == "number") then return matrix.mul_scalar(m2, m1) end -- Returns the product of a matrix or scalar (left)
if (type(m2) == "number") then return matrix.mul_scalar(m1, m2) end -- and a matrix, vector or scalar (right)
if (m1.cols ~= m2.rows) then return m1 end
local m = matrix.new(m1.rows, m2.cols)
for i = 1, m1.rows do
for j = 1, m2.cols do
local sum = 0
for k = 1, m1.cols do
sum = sum + m1:get(i, k) * m2:get(k, j)
end
m:set(i, j, sum)
end
end
return m
end
function matrix.mul_vector(m, v)
if (m.cols ~= v.size) then return v end
local vals = {}
for i = 1, m.rows do
local sum = 0
for j = 1, m.cols do
sum = sum + m:get(i, j) * v:get(j)
end
vals[i] = sum
end
return vector.new(m.rows, vals)
end
function matrix.mul_scalar(m, s)
local r = matrix.new(m.rows, m.cols)
for i = 1, #(r.values) do
r.values[i] = s * m.values[i]
end
return r
end
function matrix.__mul(m, o) function matrix.__mul(m, o)
if (getmetatable(o) == matrix) then return matrix.mul_matrix(m, o) if (getmetatable(o) == matrix) then return matrix.mul_matrix(m, o)
elseif (getmetatable(o) == vector) then return matrix.mul_vector(m, o) elseif (getmetatable(o) == vector) then return matrix.mul_vector(m, o)
...@@ -131,6 +164,8 @@ function matrix.__mul(m, o) ...@@ -131,6 +164,8 @@ function matrix.__mul(m, o)
return m return m
end end
-- == operator overload
-- Returns "for all (i, j), m1(i, j) == m2(i, j)"
function matrix.__eq(m1, m2) function matrix.__eq(m1, m2)
if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return false end if (m1.rows ~= m2.rows or m1.cols ~= m2.cols) then return false end
......
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter