require("kara3d.vector")

-- ********** DEFINITION AND GENERIC FUNCTIONS **********

-- The matrix class table. It is also installed as the metatable of every
-- matrix instance (see matrix.new), so instance method lookups resolve
-- through __index back to this table.
-- NOTE(review): deliberately global (no `local`); presumably other kara3d
-- modules refer to `matrix` directly — confirm before localizing.
matrix = {}
matrix.class = "matrix"
matrix.__index = matrix

-- Creates a new matrix with r rows and c columns (each defaulting to 1),
-- filled in row-major order from the table vals.
-- Cells with no corresponding (non-nil, non-false) entry in vals are set to
-- 0; entries of vals beyond r*c are ignored. vals is copied, never aliased.
function matrix.new(r, c, vals)
  -- Resolve the defaults once and use them for BOTH the stored dimensions
  -- and the fill loop (using the raw r*c crashed when r or c was nil).
  local rows, cols = r or 1, c or 1
  local m = {rows = rows, cols = cols, values = {}}
  local src = vals or {}
  for i = 1, rows * cols do
    -- Index directly instead of comparing against #src: the length of a
    -- table with holes is unspecified, so `i <= #src` was unreliable.
    m.values[i] = src[i] or 0
  end
  return setmetatable(m, matrix)
end

-- Returns a new matrix with the same dimensions as this one and its own
-- copy of the cell values (built through matrix.new, which copies them).
function matrix:clone()
  local rows, cols, cells = self.rows, self.cols, self.values
  local copy = matrix.new(rows, cols, cells)
  return copy
end

-- Returns a printable representation of the matrix: a "rows x cols" header
-- line, then one line per row with its cells separated by tabs.
function matrix:tostring()
  local out = self.rows .. " x " .. self.cols .. "\n"
  for row = 1, self.rows do
    -- Build each row separately, then append it to the result.
    local line = ""
    for col = 1, self.cols do
      local sep
      if col < self.cols then
        sep = "\t"
      else
        sep = "\n"
      end
      line = line .. self:get(row, col) .. sep
    end
    out = out .. line
  end
  return out
end

-- ********** STATIC FUNCTIONS **********

-- Returns an r x c matrix with 1 on the main diagonal and 0 in every other
-- cell.
-- r defaults to 1 and c defaults to r, so matrix.identity(3) is the 3 x 3
-- identity (previously a missing c made matrix.new crash).
function matrix.identity(r, c)
  r = r or 1
  c = c or r
  local m = matrix.new(r, c)
  for i = 1, math.min(r, c) do m:set(i, i, 1) end
  return m
end

-- Multiplies two matrices, or a matrix and a scalar given on either side.
-- A numeric operand is delegated to matrix.mul_scalar.
-- For two matrices, m1.cols must equal m2.rows. Otherwise m1 itself is
-- returned unchanged (no error is raised).
-- Returns a new m1.rows x m2.cols matrix.
function matrix.mul_matrix(m1, m2)
  if type(m1) == "number" then
    return matrix.mul_scalar(m2, m1)
  end
  if type(m2) == "number" then
    return matrix.mul_scalar(m1, m2)
  end
  if m1.cols ~= m2.rows then
    return m1
  end

  local shared = m1.cols
  local product = matrix.new(m1.rows, m2.cols)
  for row = 1, m1.rows do
    for col = 1, m2.cols do
      -- Dot product of row `row` of m1 with column `col` of m2.
      local acc = 0
      for k = 1, shared do
        acc = acc + m1:get(row, k) * m2:get(k, col)
      end
      product:set(row, col, acc)
    end
  end
  return product
end

-- Returns the product m * v of a matrix and a column vector.
-- The result is a new vector of size m.rows whose i-th entry is the dot
-- product of row i of m with v.
-- Only defined when m's number of COLUMNS equals v.size. Otherwise v itself
-- is returned unchanged (no error is raised).
function matrix.mul_vector(m, v)
  if (m.cols ~= v.size) then return v end

  local vals = {}
  for i = 1, m.rows do
    local sum = 0
    for j = 1, m.cols do
      sum = sum + m:get(i, j) * v:get(j)
    end
    vals[i] = sum
  end

  return vector.new(m.rows, vals)
end

-- Returns a new matrix with the same dimensions as m, where every cell is
-- s times the corresponding cell of m. m is not modified.
function matrix.mul_scalar(m, s)
  local scaled = matrix.new(m.rows, m.cols)
  local src, dst = m.values, scaled.values
  local count = #dst
  for idx = 1, count do
    dst[idx] = s * src[idx]
  end
  return scaled
end

-- ********** INSTANCE FUNCTIONS **********

-- Returns the value stored in cell (i, j), i.e. row i, column j.
-- Cells are stored row-major in self.values.
-- Indices outside 1..rows / 1..cols, and positions past the end of the
-- values table, read as 0.
function matrix:get(i, j)
  if i <= 0 or i > self.rows or j <= 0 or j > self.cols then
    return 0
  end
  local cells = self.values
  local idx = (i - 1) * self.cols + j
  if idx > #cells then
    return 0
  end
  return cells[idx]
end

-- Sets cell (i, j) (row i, column j) to value and returns this matrix, so
-- calls can be chained.
-- Out-of-range indices are ignored, mirroring matrix:get, which reads them
-- as 0. Without this guard, an out-of-range (i, j) silently overwrote a
-- different cell (e.g. (1, cols+1) aliases (2, 1)). It could also grow the
-- values table beyond rows * cols, skewing every `#values`-based loop.
function matrix:set(i, j, value)
  if i <= 0 or i > self.rows or j <= 0 or j > self.cols then
    return self
  end
  self.values[(i-1) * self.cols + j] = value
  return self
end

-- Returns column i of this matrix as a new vector of size self.rows.
-- An out-of-range column index yields vector.zero(self.rows).
function matrix:column(i)
  local height = self.rows
  if i <= 0 or i > self.cols then
    return vector.zero(height)
  end
  local entries = {}
  for row = 1, height do
    entries[row] = self:get(row, i)
  end
  return vector.new(height, entries)
end

-- ********** OPERATOR OVERLOADS **********

-- Metamethod for binary `+`.
-- For matrices of identical dimensions, returns a new matrix whose cell
-- (i, j) is m1(i, j) + m2(i, j).
-- If the dimensions differ, m1 itself is returned unchanged (no error).
function matrix.__add(m1, m2)
  if m1.rows == m2.rows and m1.cols == m2.cols then
    local total = m1:clone()
    local acc, rhs = total.values, m2.values
    for idx, lhs in ipairs(acc) do
      acc[idx] = lhs + rhs[idx]
    end
    return total
  end
  return m1
end

-- Metamethod for unary `-`.
-- Returns a new matrix of the same dimensions with every cell negated.
function matrix.__unm(m)
  local negated = matrix.new(m.rows, m.cols)
  local src, dst = m.values, negated.values
  local count = #dst
  for idx = 1, count do
    dst[idx] = -src[idx]
  end
  return negated
end

-- Metamethod for binary `-`.
-- For matrices of identical dimensions, returns a new matrix whose cell
-- (i, j) is m1(i, j) - m2(i, j).
-- If the dimensions differ, m1 itself is returned unchanged (no error).
function matrix.__sub(m1, m2)
  local same_shape = m1.rows == m2.rows and m1.cols == m2.cols
  if not same_shape then
    return m1
  end

  local diff = matrix.new(m1.rows, m1.cols)
  local a, b, out = m1.values, m2.values, diff.values
  local count = #out
  for idx = 1, count do
    out[idx] = a[idx] - b[idx]
  end
  return diff
end

-- Metamethod for binary `*`. The behavior depends on the right operand o:
--   matrix -> matrix.mul_matrix(m, o)
--   vector -> matrix.mul_vector(m, o)
--   number -> matrix.mul_scalar(m, o)
--   other  -> m is returned unchanged
-- For `2 * m`, Lua invokes this with m = 2 and o = the matrix. That case is
-- handled by mul_matrix's scalar branch.
function matrix.__mul(m, o)
  local mt = getmetatable(o)
  if mt == matrix then
    return matrix.mul_matrix(m, o)
  end
  if mt == vector then
    return matrix.mul_vector(m, o)
  end
  if type(o) == "number" then
    return matrix.mul_scalar(m, o)
  end
  return m
end

-- Metamethod for `==`.
-- Two matrices are equal when they have the same dimensions and every
-- stored value of m1 equals the value at the same position in m2.
function matrix.__eq(m1, m2)
  local same_shape = m1.rows == m2.rows and m1.cols == m2.cols
  if not same_shape then
    return false
  end

  local a, b = m1.values, m2.values
  local count = #a
  local idx = 1
  while idx <= count do
    if not (a[idx] == b[idx]) then
      return false
    end
    idx = idx + 1
  end
  return true
end