From 0910c21fd7a298f62077d7abc6b4846dab8f0637 Mon Sep 17 00:00:00 2001
From: Kubat <maelle.martin@proton.me>
Date: Thu, 1 May 2025 20:11:18 +0200
Subject: [PATCH] [VN] Add a way to use vn types in scripts and spells

---
 grimoire/src/error.rs           |   8 +-
 grimoire/src/lib.rs             |  17 ++-
 grimoire/src/spell/arguments.rs |   8 +-
 grimoire/src/spell/factory.rs   |   2 +-
 grimoire/src/spell/std/log.rs   |  20 ++--
 grimoire/src/state.rs           | 178 +++++++++++++++++++-------------
 grimoire/src/vn/actor.rs        |   3 +
 grimoire/src/vn/mod.rs          |  67 ++++++++++++
 grimoire/src/vn/printer.rs      |   3 +
 9 files changed, 218 insertions(+), 88 deletions(-)
 create mode 100644 grimoire/src/vn/actor.rs
 create mode 100644 grimoire/src/vn/mod.rs
 create mode 100644 grimoire/src/vn/printer.rs

diff --git a/grimoire/src/error.rs b/grimoire/src/error.rs
index e61a1a0..7750588 100644
--- a/grimoire/src/error.rs
+++ b/grimoire/src/error.rs
@@ -9,13 +9,13 @@ pub enum Error {
     RedefinedSpell(&'static str, &'static str),
 
     #[error("already defined spell argument '{0}'")]
-    RedefinedSpellArgument(String),
+    RedefinedSpellArgument(std::borrow::Cow<'static, str>),
 
     #[error("unexpect: {0}")]
-    Unexpected(&'static str),
+    Unexpected(std::borrow::Cow<'static, str>),
 
     #[error("undefined spell '{0}'")]
-    Undefined(String),
+    Undefined(std::borrow::Cow<'static, str>),
 
     #[error("can't cast '{0}' into {1:?}")]
     CantCast(crate::ast::Const, crate::types::Type),
@@ -30,7 +30,7 @@ pub enum Error {
     ParseBool(#[from] ParseBoolError),
 
     #[error("empty {0}")]
-    Empty(&'static str),
+    Empty(std::borrow::Cow<'static, str>),
 
     #[error("error in spell '{0}' instantiation: {1}")]
     BuildSpell(&'static str, Box<Error>),
diff --git a/grimoire/src/lib.rs b/grimoire/src/lib.rs
index 4a31919..e41df7b 100644
--- a/grimoire/src/lib.rs
+++ b/grimoire/src/lib.rs
@@ -1,6 +1,7 @@
 mod error;
 mod state;
 mod types;
+mod vn;
 
 pub mod ast;
 pub mod parser;
@@ -18,11 +19,19 @@ pub mod prelude {
         types::Type as GrimoireType,
     };
 
-    // Spell structs/traits re-exports.
-    pub use crate::spell::{BuildableSpell, Spell, SpellArguments, SpellFactory};
+    // The VN structs.
+    pub mod vn {
+        pub use crate::vn::ConstOrVnType;
+    }
 
-    // Standard spells re-rexports.
+    // Spell structs + standard spells re-rexports.
     pub mod spell {
-        pub use crate::spell::std::*;
+        pub use crate::spell::{SpellArguments as Arguments, SpellFactory as Factory, std::*};
     }
+
+    // Traits.
+    pub use crate::{
+        spell::{BuildableSpell as GrimoireBuildableSpell, Spell as GrimoireSpell},
+        vn::{Actor as GrimoireActor, Printer as GrimoirePrinter},
+    };
 }
diff --git a/grimoire/src/spell/arguments.rs b/grimoire/src/spell/arguments.rs
index 89cb131..f0ace1a 100644
--- a/grimoire/src/spell/arguments.rs
+++ b/grimoire/src/spell/arguments.rs
@@ -43,13 +43,13 @@ impl SpellArguments {
         names: [&str; N],
     ) -> Result<[Expression; N], Error> {
         if let Some(missing) = names.iter().find(|&&name| !self.named.contains_key(name)) {
-            Err(Error::Undefined(format!(
+            Err(Error::Undefined(Into::into(format!(
                 "mendatory spell argument '{missing}'"
-            )))
+            ))))
         } else if self.named.len() > N {
-            Err(Error::Unexpected(
+            Err(Error::Unexpected(Into::into(
                 "too many named arguments in in spell arguments",
-            ))
+            )))
         } else {
             Ok(names.map(|name| self.remove(name).unwrap()))
         }
diff --git a/grimoire/src/spell/factory.rs b/grimoire/src/spell/factory.rs
index e9502db..a444218 100644
--- a/grimoire/src/spell/factory.rs
+++ b/grimoire/src/spell/factory.rs
@@ -45,6 +45,6 @@ impl SpellFactory {
         self.dispatch
             .get(name.as_ref())
             .copied()
-            .ok_or(Error::Undefined(name.as_ref().to_string()))
+            .ok_or(Error::Undefined(name.as_ref().to_string().into()))
     }
 }
diff --git a/grimoire/src/spell/std/log.rs b/grimoire/src/spell/std/log.rs
index 6caf55e..0d96b9e 100644
--- a/grimoire/src/spell/std/log.rs
+++ b/grimoire/src/spell/std/log.rs
@@ -7,10 +7,14 @@ pub struct Log {
     messages: Vec<GrimoireExpression>,
 }
 
-impl Spell for Log {
+impl GrimoireSpell for Log {
     fn cast(&self, state: GrimoireState) -> Result<GrimoireState, GrimoireError> {
         if let Some(pred) = &self.if_predicate {
-            if !(state.evaluate_as(pred, GrimoireType::Boolean)?).unwrap_boolean() {
+            if !state
+                .evaluate_as(pred, GrimoireType::Boolean)?
+                .try_into_const()?
+                .unwrap_boolean()
+            {
                 return Ok(state);
             }
         }
@@ -22,11 +26,15 @@ impl Spell for Log {
 
         let level = state
             .evaluate_as(&self.level, GrimoireType::String)?
+            .try_into_const()?
             .to_string();
 
         let logger = |ident: bool, line: &GrimoireExpression| -> Result<(), GrimoireError> {
             let ident = ident.then_some("\t").unwrap_or_default();
-            let line = state.evaluate_as(line, GrimoireType::String)?.to_string();
+            let line = state
+                .evaluate_as(line, GrimoireType::String)?
+                .try_into_const()?
+                .to_string();
 
             match level.as_str() {
                 "error" => log::error!(target: "grimoire", "{ident}{line}"),
@@ -45,10 +53,10 @@ impl Spell for Log {
     }
 }
 
-impl TryFrom<SpellArguments> for Log {
+impl TryFrom<spell::Arguments> for Log {
     type Error = GrimoireError;
 
-    fn try_from(mut args: SpellArguments) -> Result<Self, Self::Error> {
+    fn try_from(mut args: spell::Arguments) -> Result<Self, Self::Error> {
         let if_predicate = args.remove("if");
         let [level] = args.remove_exact(["level"])?;
         let messages = args
@@ -64,7 +72,7 @@ impl TryFrom<SpellArguments> for Log {
     }
 }
 
-impl BuildableSpell for Log {
+impl GrimoireBuildableSpell for Log {
     fn name() -> &'static str {
         "log"
     }
diff --git a/grimoire/src/state.rs b/grimoire/src/state.rs
index 6b390fb..2e41616 100644
--- a/grimoire/src/state.rs
+++ b/grimoire/src/state.rs
@@ -2,24 +2,26 @@ use crate::{
     ast::{Const, Expression, VarOrConst, operator},
     error::Error,
     types::Type,
+    vn::{Actor, ConstOrVnType, Printer},
+};
+use std::{
+    collections::{HashMap, HashSet},
+    rc::Rc,
 };
-use std::collections::{HashMap, HashSet};
 
-#[derive(Debug, Default)]
+#[derive(Default)]
 pub struct State {
     variables: HashMap<String, Const>,
     identifiers: HashSet<String>,
-
-    // To transform into HashMap...
-    actors: HashSet<String>,
-    printers: HashSet<String>,
+    actors: HashMap<String, Rc<dyn Actor>>,
+    printers: HashMap<String, Rc<dyn Printer>>,
 }
 
 impl State {
-    pub fn resolve(&self, var: impl AsRef<str>) -> Result<&Const, Error> {
+    pub fn resolve_variable(&self, var: impl AsRef<str>) -> Result<&Const, Error> {
         self.variables
             .get(var.as_ref())
-            .ok_or_else(|| Error::Undefined(var.as_ref().to_string()))
+            .ok_or_else(|| Error::Undefined(var.as_ref().to_string().into()))
     }
 
     fn evaluate_numeric_binop(
@@ -28,8 +30,8 @@ impl State {
         op: operator::Numeric,
         right: impl AsRef<Expression>,
     ) -> Result<Const, Error> {
-        let left = self.evaluate_as(left, Type::Number)?;
-        let right = self.evaluate_as(right, Type::Number)?;
+        let left = self.evaluate_as(left, Type::Number)?.try_into_const()?;
+        let right = self.evaluate_as(right, Type::Number)?.try_into_const()?;
 
         match matches!(left, Const::Flt(_)) || matches!(right, Const::Flt(_)) {
             // Use float operations
@@ -87,8 +89,14 @@ impl State {
         right: impl AsRef<Expression>,
     ) -> Result<Const, Error> {
         // Here we evaluate eagerly to catch errors!
-        let left = self.evaluate_as(left, Type::Boolean)?.unwrap_boolean();
-        let right = self.evaluate_as(right, Type::Boolean)?.unwrap_boolean();
+        let left = self
+            .evaluate_as(left, Type::Boolean)?
+            .try_into_const()?
+            .unwrap_boolean();
+        let right = self
+            .evaluate_as(right, Type::Boolean)?
+            .try_into_const()?
+            .unwrap_boolean();
         Ok(Const::Bool(match op {
             operator::Logic::Or => left || right,
             operator::Logic::And => left && right,
@@ -102,61 +110,70 @@ impl State {
         op: operator::Equal,
         right: impl AsRef<Expression>,
     ) -> Result<Const, Error> {
-        Ok(Const::Bool(
-            match (self.evaluate(left)?, self.evaluate(right)?) {
-                // Easy!
-                (Const::Int(x), Const::Int(y)) => op.apply(x == y),
-                (Const::Flt(x), Const::Flt(y)) => op.apply(x == y),
-                (Const::Bool(x), Const::Bool(y)) => op.apply(x == y),
-
-                (Const::Str(x), Const::Ident(y))
-                | (Const::Ident(x), Const::Str(y))
-                | (Const::Ident(x), Const::Ident(y))
-                | (Const::Str(x), Const::Str(y)) => op.apply(x == y),
-
-                // Cast, handle permutations and cast priority.
-                (Const::Int(x), Const::Flt(y)) => op.apply(x == y as i32),
-                (Const::Flt(x), Const::Int(y)) => op.apply(x == y as f32),
-
-                // Cast with permutation. Here we cast all into a boolean first!
-                (x @ Const::Int(_), Const::Bool(y)) | (Const::Bool(y), x @ Const::Int(_)) => {
-                    op.apply(x.unwrap_boolean() == y)
-                }
+        let left = self.evaluate(left)?.try_into_const()?;
+        let right = self.evaluate(right)?.try_into_const()?;
+        Ok(Const::Bool(match (left, right) {
+            // Easy!
+            (Const::Int(x), Const::Int(y)) => op.apply(x == y),
+            (Const::Flt(x), Const::Flt(y)) => op.apply(x == y),
+            (Const::Bool(x), Const::Bool(y)) => op.apply(x == y),
+
+            (Const::Str(x), Const::Ident(y))
+            | (Const::Ident(x), Const::Str(y))
+            | (Const::Ident(x), Const::Ident(y))
+            | (Const::Str(x), Const::Str(y)) => op.apply(x == y),
+
+            // Cast, handle permutations and cast priority.
+            (Const::Int(x), Const::Flt(y)) => op.apply(x == y as i32),
+            (Const::Flt(x), Const::Int(y)) => op.apply(x == y as f32),
+
+            // Cast with permutation. Here we cast all into a boolean first!
+            (x @ Const::Int(_), Const::Bool(y)) | (Const::Bool(y), x @ Const::Int(_)) => {
+                op.apply(x.unwrap_boolean() == y)
+            }
 
-                (x @ Const::Flt(_), Const::Bool(y)) | (Const::Bool(y), x @ Const::Flt(_)) => {
-                    op.apply(x.unwrap_boolean() == y)
-                }
+            (x @ Const::Flt(_), Const::Bool(y)) | (Const::Bool(y), x @ Const::Flt(_)) => {
+                op.apply(x.unwrap_boolean() == y)
+            }
 
-                // Can be inter-casted, they are not equal!
-                _ => op.check_not_equals(),
-            },
-        ))
+            // Can be inter-casted, they are not equal!
+            _ => op.check_not_equals(),
+        }))
     }
 
-    pub fn evaluate(&self, expr: impl AsRef<Expression>) -> Result<Const, Error> {
+    pub fn evaluate(&self, expr: impl AsRef<Expression>) -> Result<ConstOrVnType, Error> {
         match expr.as_ref() {
             // Leaf, simple.
-            Expression::Leaf(_, VarOrConst::Const(constant)) => Ok(constant.clone()),
-            Expression::Leaf(_, VarOrConst::Var(var)) => self.resolve(var).cloned(),
+            Expression::Leaf(_, VarOrConst::Const(constant)) => Ok(constant.clone().into()),
+            Expression::Leaf(_, VarOrConst::Var(var)) => {
+                self.resolve_variable(var).cloned().map(Into::into)
+            }
 
             // Unary, not complicated.
             Expression::Unary(_, op, inner) => match op {
                 operator::Unary::Not => Ok(Const::Bool(
-                    !self.evaluate_as(inner, Type::Boolean)?.unwrap_boolean(),
+                    !self
+                        .evaluate_as(inner, Type::Boolean)?
+                        .try_into_const()?
+                        .unwrap_boolean(),
                 )),
-                operator::Unary::Neg => match self.evaluate_as(inner, Type::Number)? {
-                    Const::Int(x) => Ok(Const::Int(-x)),
-                    Const::Flt(x) => Ok(Const::Flt(-x)),
-                    _ => unreachable!(),
-                },
-            },
+                operator::Unary::Neg => {
+                    match self.evaluate_as(inner, Type::Number)?.try_into_const()? {
+                        Const::Int(x) => Ok(Const::Int(-x)),
+                        Const::Flt(x) => Ok(Const::Flt(-x)),
+                        _ => unreachable!(),
+                    }
+                }
+            }
+            .map(Into::into),
 
             // Binary, just long to write...
             Expression::Binary(_, left, op, right) => match op {
                 operator::Binary::Numeric(op) => self.evaluate_numeric_binop(left, *op, right),
                 operator::Binary::Logic(op) => self.evaluate_logic_binop(left, *op, right),
                 operator::Binary::Equal(op) => self.evaluate_equal_binop(left, *op, right),
-            },
+            }
+            .map(Into::into),
         }
     }
 
@@ -164,33 +181,56 @@ impl State {
         let x = ident.into();
         match self.identifiers.contains(&x) {
             true => Ok(Const::Ident(x)),
-            false => Err(Error::Undefined(x)),
+            false => Err(Error::Undefined(x.into())),
         }
     }
 
-    fn find_actor(&self, actor: impl Into<String>) -> Result<Const, Error> {
-        let x = actor.into();
-        match self.actors.contains(&x) {
-            true => Ok(Const::Ident(x)),
-            false => Err(Error::Undefined(x)),
-        }
+    fn find_actor(&self, actor: impl Into<String>) -> Result<Rc<dyn Actor>, Error> {
+        let actor = actor.into();
+        self.actors
+            .get(&actor)
+            .cloned()
+            .ok_or(Error::Undefined(actor.into()))
     }
 
-    fn find_printer(&self, printer: impl Into<String>) -> Result<Const, Error> {
-        let x = printer.into();
-        match self.printers.contains(&x) {
-            true => Ok(Const::Ident(x)),
-            false => Err(Error::Undefined(x)),
-        }
+    fn find_printer(&self, printer: impl Into<String>) -> Result<Rc<dyn Printer>, Error> {
+        let printer = printer.into();
+        self.printers
+            .get(&printer)
+            .cloned()
+            .ok_or(Error::Undefined(printer.into()))
     }
 
-    pub fn evaluate_as(&self, expr: impl AsRef<Expression>, ty: Type) -> Result<Const, Error> {
-        let constant = self.evaluate(expr)?;
+    pub fn evaluate_as(
+        &self,
+        expr: impl AsRef<Expression>,
+        ty: Type,
+    ) -> Result<ConstOrVnType, Error> {
+        let res = self.evaluate(expr)?;
+
         match ty {
-            Type::Identifier => self.check_identifier(constant.unwrap_identifier()),
-            Type::Actor => self.find_actor(constant.unwrap_identifier()),
-            Type::Printer => self.find_printer(constant.unwrap_identifier()),
-            ty => constant.cast_into(ty),
+            Type::Identifier => self
+                .check_identifier(res.try_into_const()?.unwrap_identifier())
+                .map(ConstOrVnType::Const),
+
+            Type::Actor => match res.is_actor() {
+                true => Ok(res),
+                false => self
+                    .find_actor(res.try_into_const()?.unwrap_identifier())
+                    .map(ConstOrVnType::Actor),
+            },
+
+            Type::Printer => match res.is_printer() {
+                true => Ok(res),
+                false => self
+                    .find_printer(res.try_into_const()?.unwrap_identifier())
+                    .map(ConstOrVnType::Printer),
+            },
+
+            ty => res
+                .try_into_const()?
+                .cast_into(ty)
+                .map(ConstOrVnType::Const),
         }
     }
 }
diff --git a/grimoire/src/vn/actor.rs b/grimoire/src/vn/actor.rs
new file mode 100644
index 0000000..dd864ea
--- /dev/null
+++ b/grimoire/src/vn/actor.rs
@@ -0,0 +1,3 @@
+pub trait Actor {
+    fn name(&self) -> &'static str;
+}
diff --git a/grimoire/src/vn/mod.rs b/grimoire/src/vn/mod.rs
new file mode 100644
index 0000000..73b5390
--- /dev/null
+++ b/grimoire/src/vn/mod.rs
@@ -0,0 +1,67 @@
+mod actor;
+mod printer;
+
+use crate::{ast::Const, error::Error};
+use std::{fmt, rc::Rc};
+
+pub use self::{actor::Actor, printer::Printer};
+
+#[derive(Clone)]
+pub enum ConstOrVnType {
+    Const(Const),
+    Actor(Rc<dyn Actor>),
+    Printer(Rc<dyn Printer>),
+}
+
+impl fmt::Debug for ConstOrVnType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::Const(arg0) => f.debug_tuple("Const").field(arg0).finish(),
+            Self::Actor(arg0) => f.debug_tuple("Actor").field(&arg0.name()).finish(),
+            Self::Printer(arg0) => f.debug_tuple("Printer").field(&arg0.name()).finish(),
+        }
+    }
+}
+
+impl ConstOrVnType {
+    pub fn is_actor(&self) -> bool {
+        matches!(self, ConstOrVnType::Actor(_))
+    }
+
+    pub fn is_printer(&self) -> bool {
+        matches!(self, ConstOrVnType::Printer(_))
+    }
+
+    pub fn try_into_const(self) -> Result<Const, Error> {
+        use ConstOrVnType::*;
+        match self {
+            Const(constant) => Ok(constant),
+            Actor(vn) => Err(Error::Unexpected(Into::into(format!(
+                "can't cast actor '{name}' into constant",
+                name = vn.name()
+            )))),
+            Printer(vn) => Err(Error::Unexpected(Into::into(format!(
+                "can't cast printer '{name}' into constant",
+                name = vn.name()
+            )))),
+        }
+    }
+}
+
+impl From<Const> for ConstOrVnType {
+    fn from(value: Const) -> Self {
+        Self::Const(value)
+    }
+}
+
+impl From<Rc<dyn Actor>> for ConstOrVnType {
+    fn from(value: Rc<dyn Actor>) -> Self {
+        Self::Actor(value)
+    }
+}
+
+impl From<Rc<dyn Printer>> for ConstOrVnType {
+    fn from(value: Rc<dyn Printer>) -> Self {
+        Self::Printer(value)
+    }
+}
diff --git a/grimoire/src/vn/printer.rs b/grimoire/src/vn/printer.rs
new file mode 100644
index 0000000..99963ed
--- /dev/null
+++ b/grimoire/src/vn/printer.rs
@@ -0,0 +1,3 @@
+pub trait Printer {
+    fn name(&self) -> &'static str;
+}
-- 
GitLab