Skip to content
Extraits de code Groupes Projets
Sélectionner une révision Git
  • 298c33f4d2bd34b675ebe50981229a25e44dba62
  • master par défaut protégée
  • convert-to-dlang
  • clear-warnings
  • update-structs
  • unittests
  • bjarne-stroustrup
  • 5.67.0
8 résultats

mylib.c

Blame
  • matrix.lua 4,63 Kio
    require("kara3d.vector")
    
    -- ********** DEFINITION AND GENERIC FUNCTIONS **********
    
    -- The matrix "class" table. It is assigned to a global (not a local), and
    -- it also serves as the metatable of every instance: pointing __index at
    -- itself makes method lookups on instances fall through to this table.
    matrix = {}
    matrix.class = "matrix"
    matrix.__index = matrix
    
    -- Creates a new matrix with r rows and c columns (each defaulting to 1).
    -- Cells are filled in row-major order from the table vals; any cell for
    -- which vals has no entry (nil or false) is set to 0. Entries of vals
    -- beyond rows*cols are ignored.
    -- @param r number of rows (default 1)
    -- @param c number of columns (default 1)
    -- @param vals optional table of initial cell values
    -- @return the new matrix, with `matrix` as its metatable
    function matrix.new(r, c, vals)
      local rows, cols = r or 1, c or 1
      local m = {rows = rows, cols = cols, values = {}}
      vals = vals or {}
      -- Loop over the defaulted dimensions: `r*c` would raise an
      -- arithmetic-on-nil error whenever r or c is omitted.
      for i = 1, rows * cols do
        -- Index vals directly instead of comparing against #vals: `#` is
        -- unspecified for tables with nil holes, and this avoids recomputing
        -- it on every iteration.
        m.values[i] = vals[i] or 0
      end
      return setmetatable(m, matrix)
    end
    
    -- Returns an independent copy of this matrix: a new matrix with the same
    -- dimensions, whose cells are copied (via matrix.new) from this one.
    function matrix:clone()
      local rows, cols = self.rows, self.cols
      return matrix.new(rows, cols, self.values)
    end
    
    -- Returns a human-readable string for the matrix: a "rows x cols" header
    -- line, followed by one line per row with cells separated by tabs.
    function matrix:tostring()
      -- Collect the pieces in a buffer and join them once at the end.
      -- Appending with `..` inside the loop copies the whole growing string
      -- for every cell, which is quadratic in the number of cells.
      local parts = {self.rows .. " x " .. self.cols .. "\n"}
      local n = 1
      for i = 1, self.rows do
        for j = 1, self.cols do
          parts[n + 1] = self:get(i, j)
          parts[n + 2] = j < self.cols and "\t" or "\n"
          n = n + 2
        end
      end
      -- Explicit bounds make a non-string/number cell raise an error (as `..`
      -- would) instead of being skipped.
      return table.concat(parts, "", 1, n)
    end
    
    -- ********** STATIC FUNCTIONS **********
    
    -- Returns a matrix with r rows and c columns, with 1 on the main diagonal
    -- and 0 in every other cell.
    -- @param r number of rows (default 1)
    -- @param c number of columns (defaults to r, so identity(n) is n x n)
    function matrix.identity(r, c)
      -- Without these defaults, math.min(r, c) (and matrix.new's r*c) raise
      -- on a nil argument, so identity() and identity(n) used to crash.
      r = r or 1
      c = c or r
      local m = matrix.new(r, c)
      for i = 1, math.min(r, c) do m:set(i, i, 1) end
      return m
    end
    
    -- Multiplies two operands, at least one of which is a matrix.
    -- If either operand is a plain number, the work is delegated to
    -- matrix.mul_scalar. Otherwise m1 must have as many columns as m2 has
    -- rows; when it does not, m1 itself is returned unchanged.
    -- @return a new m1.rows x m2.cols matrix holding the product
    function matrix.mul_matrix(m1, m2)
      if type(m1) == "number" then
        return matrix.mul_scalar(m2, m1)
      elseif type(m2) == "number" then
        return matrix.mul_scalar(m1, m2)
      elseif m1.cols ~= m2.rows then
        return m1
      end

      local inner = m1.cols
      local product = matrix.new(m1.rows, m2.cols)
      for row = 1, m1.rows do
        for col = 1, m2.cols do
          -- Dot product of row `row` of m1 with column `col` of m2.
          local acc = 0
          for k = 1, inner do
            acc = acc + m1:get(row, k) * m2:get(k, col)
          end
          product:set(row, col, acc)
        end
      end
      return product
    end
    
    -- Returns the matrix-vector product m * v: element i of the result is the
    -- dot product of row i of m with v, and the result is built with
    -- vector.new(m.rows, ...).
    -- Only works if m has a number of COLUMNS equal to the size of v; on a
    -- size mismatch, v itself is returned unchanged.
    function matrix.mul_vector(m, v)
      if (m.cols ~= v.size) then return v end

      local vals = {}
      for i = 1, m.rows do
        local sum = 0
        for j = 1, m.cols do
          sum = sum + m:get(i, j) * v:get(j)
        end
        vals[i] = sum
      end

      return vector.new(m.rows, vals)
    end
    
    -- Scales every cell of matrix m by the number s. The result is a new
    -- matrix with m's dimensions; m itself is not modified.
    function matrix.mul_scalar(m, s)
      local scaled = matrix.new(m.rows, m.cols)
      local cells, source = scaled.values, m.values
      for k = 1, #cells do
        cells[k] = s * source[k]
      end
      return scaled
    end
    
    -- ********** INSTANCE FUNCTIONS **********
    
    -- Returns the value stored in cell (i, j) of this matrix.
    -- Indices outside 1..rows / 1..cols, and cells with no stored value,
    -- read as 0.
    function matrix:get(i, j)
      if (i <= 0 or i > self.rows) then return 0 end
      if (j <= 0 or j > self.cols) then return 0 end
      -- Look the cell up directly rather than comparing against #self.values:
      -- once any cell has been set to nil, `#` may return any border, which
      -- could make valid cells stored after the hole read as 0.
      return self.values[(i-1) * self.cols + j] or 0
    end
    
    -- Sets cell (i, j) to value and returns this matrix, so calls can be
    -- chained. Out-of-range indices are ignored (the matrix is returned
    -- unchanged), mirroring get, which reads them as 0.
    function matrix:set(i, j, value)
      -- Cells live in a single row-major array, so without this check a
      -- column index past self.cols would silently overwrite a cell of the
      -- next row (and a row index past self.rows would grow the array).
      if (i <= 0 or i > self.rows or j <= 0 or j > self.cols) then
        return self
      end
      self.values[(i-1) * self.cols + j] = value
      return self
    end
    
    -- Returns column i of this matrix as a vector built with
    -- vector.new(self.rows, ...). An out-of-range i yields
    -- vector.zero(self.rows) instead.
    function matrix:column(i)
      local height = self.rows
      local out_of_range = i <= 0 or i > self.cols
      if out_of_range then
        return vector.zero(height)
      end

      local entries = {}
      for row = 1, height do
        entries[row] = self:get(row, i)
      end
      return vector.new(height, entries)
    end
    
    -- ********** OPERATOR OVERLOADS **********
    
    -- Addition metamethod (m1 + m2) for two matrices of identical dimensions.
    -- Produces a new matrix in which every cell is the sum of the matching
    -- cells of m1 and m2. On a dimension mismatch, m1 is returned as-is.
    function matrix.__add(m1, m2)
      local same_shape = m1.rows == m2.rows and m1.cols == m2.cols
      if not same_shape then return m1 end

      local total = m1:clone()
      local cells, other = total.values, m2.values
      for k = 1, #cells do
        cells[k] = cells[k] + other[k]
      end
      return total
    end
    
    -- Unary minus metamethod (-m): returns a new matrix with m's dimensions
    -- in which every cell is the negation of the matching cell of m.
    function matrix.__unm(m)
      local negated = matrix.new(m.rows, m.cols)
      local dst, src = negated.values, m.values
      for k = 1, #dst do
        dst[k] = -src[k]
      end
      return negated
    end
    
    -- Subtraction metamethod (m1 - m2) for two matrices of identical
    -- dimensions. Produces a new matrix in which every cell is
    -- m1's cell minus m2's cell. On a dimension mismatch, m1 is returned as-is.
    function matrix.__sub(m1, m2)
      local same_shape = m1.rows == m2.rows and m1.cols == m2.cols
      if not same_shape then return m1 end

      local diff = matrix.new(m1.rows, m1.cols)
      local out, lhs, rhs = diff.values, m1.values, m2.values
      for k = 1, #out do
        out[k] = lhs[k] - rhs[k]
      end
      return diff
    end
    
    -- Multiplication metamethod (m * o). The left operand is a matrix or a
    -- number; the right operand picks the operation:
    --   matrix -> matrix.mul_matrix
    --   vector -> matrix.mul_vector
    --   number -> matrix.mul_scalar
    -- For any other right operand, m is returned unchanged.
    function matrix.__mul(m, o)
      local mt = getmetatable(o)
      if mt == matrix then
        return matrix.mul_matrix(m, o)
      end
      if mt == vector then
        return matrix.mul_vector(m, o)
      end
      if type(o) == "number" then
        return matrix.mul_scalar(m, o)
      end
      return m
    end
    
    -- Equality metamethod (m1 == m2): true when both matrices have the same
    -- dimensions and each cell of m1 equals the matching cell of m2.
    function matrix.__eq(m1, m2)
      if m1.rows ~= m2.rows or m1.cols ~= m2.cols then
        return false
      end

      local lhs, rhs = m1.values, m2.values
      local k, last = 1, #lhs
      while k <= last do
        if lhs[k] ~= rhs[k] then
          return false
        end
        k = k + 1
      end
      return true
    end