From 5855ccb8287dc3cf52c1216e826aef3db5d30cb4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ma=C3=ABlle=20MARTIN?= <maelle.martin@proton.me>
Date: Fri, 30 Aug 2024 12:29:52 +0200
Subject: [PATCH] SCRIPT: Continue to implement LLVM bindings

---
 src/Rust/vvs_llvm/src/bindings/basic_block.rs | 125 +++++++++++++++---
 src/Rust/vvs_llvm/src/bindings/builder.rs     |   4 +-
 src/Rust/vvs_llvm/src/bindings/function.rs    |  98 ++++++++++----
 src/Rust/vvs_llvm/src/bindings/module.rs      |   4 +-
 src/Rust/vvs_llvm/src/bindings/value.rs       |  23 +++-
 5 files changed, 205 insertions(+), 49 deletions(-)

diff --git a/src/Rust/vvs_llvm/src/bindings/basic_block.rs b/src/Rust/vvs_llvm/src/bindings/basic_block.rs
index 16883051..ed65e7b1 100644
--- a/src/Rust/vvs_llvm/src/bindings/basic_block.rs
+++ b/src/Rust/vvs_llvm/src/bindings/basic_block.rs
@@ -1,14 +1,53 @@
 use crate::prelude::*;
 use llvm_sys::{core::*, prelude::*};
 
-crate::bindings::declare! { mut LLVMBasicBlockRef as BB<'a> }
+crate::bindings::declare! { const LLVMBasicBlockRef as BB<'a>    }
+crate::bindings::declare! { mut   LLVMBasicBlockRef as BBMut<'a> }
+
+/// Share code between the [BB::terminator] and [BB::terminator_mut]
+macro_rules! get_terminator {
+    ($self:ident as $value:ident) => {{
+        use llvm_sys::LLVMOpcode::*;
+        match unsafe { LLVMGetInstructionOpcode(LLVMGetLastInstruction($self.as_ptr())) } {
+            LLVMRet | LLVMBr | LLVMSwitch | LLVMIndirectBr | LLVMCallBr => unsafe {
+                Some($value::from_ptr(LLVMGetBasicBlockTerminator($self.as_ptr())))
+            },
+
+            LLVMCleanupRet | LLVMCatchRet | LLVMCatchPad | LLVMCleanupPad | LLVMCatchSwitch => unsafe {
+                Some($value::from_ptr(LLVMGetBasicBlockTerminator($self.as_ptr())))
+            },
+
+            LLVMUnreachable => None,
+
+            _ => None,
+        }
+    }};
+}
+
+impl<'a> BBMut<'a> {
+    /// Get a const reference to the basic block.
+    pub fn into_ref(self) -> BB<'a> {
+        todo!()
+    }
+
+    /// Get the terminator in a mutable way of the basic block if it exists.
+    pub fn terminator_mut(&mut self) -> Option<ValueMut> {
+        get_terminator!(self as ValueMut)
+    }
+
+    /// Iterate over the instructions of a basic block in a mutable way.
+    pub fn iter_mut(&'a mut self) -> BBIter<'a, &'a mut BBMut<'a>> {
+        match self.as_ref().is_empty() {
+            true => BBIter { curr: None, last: None, bb: self },
+            false => todo!(),
+        }
+    }
+}
 
 impl<'a> BB<'a> {
     /// Get the insturction count of this basic block.
     pub fn len(&self) -> usize {
-        let len = unsafe { LLVMGetLastInstruction(self.as_ptr()).offset_from(LLVMGetFirstInstruction(self.as_ptr())) };
-        assert!(len >= 0, "something went really wrong");
-        len as usize
+        (unsafe { LLVMGetLastInstruction(self.as_ptr()).offset_from(LLVMGetFirstInstruction(self.as_ptr())) }) as usize
     }
 
     /// Tells wether there are any instructions in this basic block.
@@ -16,6 +55,11 @@ impl<'a> BB<'a> {
         self.len() == 0
     }
 
+    /// Get the terminator of the basic block if it exists.
+    pub fn terminator(&self) -> Option<Value> {
+        get_terminator!(self as Value)
+    }
+
     /// Iterate over the instructions of a basic block.
     pub fn iter(&'a self) -> BBIter<'a, &'a BB<'a>> {
         match self.is_empty() {
@@ -23,13 +67,11 @@ impl<'a> BB<'a> {
             false => todo!(),
         }
     }
+}
 
-    /// Iterate over the instructions of a basic block in a mutable way.
-    pub fn iter_mut(&'a mut self) -> BBIter<'a, &'a mut BB<'a>> {
-        match self.is_empty() {
-            true => BBIter { curr: None, last: None, bb: self },
-            false => todo!(),
-        }
+impl<'a> AsRef<BB<'a>> for BBMut<'a> {
+    fn as_ref(&self) -> &BB<'a> {
+        todo!()
     }
 }
 
@@ -40,12 +82,23 @@ impl PartialEq for BB<'_> {
     }
 }
 
+impl Eq for BBMut<'_> {}
+impl PartialEq for BBMut<'_> {
+    fn eq(&self, other: &Self) -> bool {
+        self.inner == other.inner
+    }
+}
+
 mod sealed {
     pub trait BBRef<'a>: AsRef<super::BB<'a>> {}
+    pub trait BBMutRef<'a>: AsMut<super::BBMut<'a>> + BBRef<'a> {}
 }
 
 impl<'a> sealed::BBRef<'a> for &'a BB<'a> {}
 impl<'a> sealed::BBRef<'a> for &'a mut BB<'a> {}
+impl<'a> sealed::BBRef<'a> for &'a BBMut<'a> {}
+impl<'a> sealed::BBRef<'a> for &'a mut BBMut<'a> {}
+impl<'a> sealed::BBMutRef<'a> for &'a mut BBMut<'a> {}
 
 impl<'a> AsRef<BB<'a>> for &'a BB<'a> {
     fn as_ref(&self) -> &BB<'a> {
@@ -59,6 +112,12 @@ impl<'a> AsRef<BB<'a>> for &'a mut BB<'a> {
     }
 }
 
+impl<'a> AsMut<BBMut<'a>> for &'a mut BBMut<'a> {
+    fn as_mut(&mut self) -> &mut BBMut<'a> {
+        todo!()
+    }
+}
+
 /// Iterate over the instructions in a basic block.
 pub struct BBIter<'a, B: sealed::BBRef<'a>> {
     curr: Option<Value<'a>>,
@@ -66,22 +125,56 @@ pub struct BBIter<'a, B: sealed::BBRef<'a>> {
     bb: B,
 }
 
-impl<'a, B: sealed::BBRef<'a>> Iterator for BBIter<'a, B> {
-    type Item = Value<'a>;
+/// Iterate over the instructions in a basic block.
+pub struct BBIterMut<'a, B: sealed::BBMutRef<'a>> {
+    curr: Option<ValueMut<'a>>,
+    last: Option<ValueMut<'a>>,
+    bb: B,
+}
 
-    fn next(&mut self) -> Option<Self::Item> {
-        match (self.curr.take(), self.last.as_ref()) {
+macro_rules! next {
+    ($self:ident as $value:ident) => {
+        match ($self.curr.take(), $self.last.as_ref()) {
             (Some(_), None) | (None, Some(_)) => unreachable!(),
             (None, None) => None,
             (Some(curr), Some(last)) if curr == *last => {
-                self.last = None;
+                $self.last = None;
                 Some(curr)
             }
             (Some(curr), Some(_)) => unsafe {
-                self.curr = Some(Value::from_ptr(LLVMGetNextInstruction(curr.as_ptr())));
+                $self.curr = Some($value::from_ptr(LLVMGetNextInstruction(curr.as_ptr())));
                 Some(curr)
             },
         }
+    };
+}
+
+impl<'a, B: sealed::BBRef<'a>> Iterator for BBIter<'a, B> {
+    type Item = Value<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        next!(self as Value)
+    }
+
+    fn last(self) -> Option<Self::Item> {
+        self.last
+    }
+
+    fn count(self) -> usize {
+        self.bb.as_ref().len()
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let size = self.bb.as_ref().len();
+        (size, Some(size))
+    }
+}
+
+impl<'a, B: sealed::BBMutRef<'a>> Iterator for BBIterMut<'a, B> {
+    type Item = ValueMut<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        next!(self as ValueMut)
     }
 
     fn last(self) -> Option<Self::Item> {
diff --git a/src/Rust/vvs_llvm/src/bindings/builder.rs b/src/Rust/vvs_llvm/src/bindings/builder.rs
index 2b9b23c8..74ee630f 100644
--- a/src/Rust/vvs_llvm/src/bindings/builder.rs
+++ b/src/Rust/vvs_llvm/src/bindings/builder.rs
@@ -19,13 +19,13 @@ macro_rules! build {
 
 impl<'a> Builder<'a> {
     /// Position the builder at the end of a basic block.
-    pub fn position_at_end(&mut self, bb: &BB) -> &mut Self {
+    pub fn position_at_end(&mut self, bb: &mut BBMut) -> &mut Self {
         unsafe { LLVMPositionBuilderAtEnd(self.as_ptr(), bb.as_ptr()) };
         self
     }
 
     /// Position the builder before a value
-    pub fn position_before(&mut self, value: &Value) -> &mut Self {
+    pub fn position_before(&mut self, value: &mut ValueMut) -> &mut Self {
         unsafe { LLVMPositionBuilderBefore(self.as_ptr(), value.as_ptr()) };
         self
     }
diff --git a/src/Rust/vvs_llvm/src/bindings/function.rs b/src/Rust/vvs_llvm/src/bindings/function.rs
index b12f0cca..fd79b262 100644
--- a/src/Rust/vvs_llvm/src/bindings/function.rs
+++ b/src/Rust/vvs_llvm/src/bindings/function.rs
@@ -1,13 +1,31 @@
 use crate::{bindings::cstr, prelude::*};
 use llvm_sys::{core::*, prelude::*, LLVMLinkage, LLVMVisibility};
 
-crate::bindings::declare! { mut LLVMValueRef as FunctionDeclaration<'a> }
-crate::bindings::declare! { mut LLVMValueRef as Function<'a> {
+crate::bindings::declare! { const LLVMValueRef as FunctionDeclaration<'a> }
+crate::bindings::declare! { mut   LLVMValueRef as FunctionBuilder<'a> }
+crate::bindings::declare! { mut   LLVMValueRef as Function<'a> {
     entry_point: BB<'a>,
     basic_block_count: u64,
 } }
 
 impl<'a> FunctionDeclaration<'a> {
+    /// Get the type of this function declaration.
+    pub fn ty(&self) -> FunctionType {
+        unsafe { FunctionType::from_ptr(LLVMTypeOf(self.as_ptr())) }
+    }
+
+    /// Get the LLVM context of the function.
+    pub fn context(&'a self) -> Context<'a> {
+        unsafe { Context::from_ptr(LLVMGetTypeContext(LLVMTypeOf(self.inner)), true) }
+    }
+}
+
+impl<'a> FunctionBuilder<'a> {
+    /// Get the associated declaration.
+    pub fn declaration(&self) -> FunctionDeclaration {
+        todo!()
+    }
+
     /// Set the [LLVMLinkage] and [LLVMVisibility] accordingly.
     pub fn into_public(self, public: bool) -> Self {
         match public {
@@ -23,11 +41,6 @@ impl<'a> FunctionDeclaration<'a> {
         self
     }
 
-    /// Get the type of this function declaration.
-    pub fn ty(&self) -> FunctionType {
-        unsafe { FunctionType::from_ptr(LLVMTypeOf(self.as_ptr())) }
-    }
-
     /// Transform the function declaration into a proper function, with a function body, or at
     /// lease an entry point.
     pub fn into_function(self) -> Function<'a> {
@@ -35,7 +48,7 @@ impl<'a> FunctionDeclaration<'a> {
             Function::from_ptr(
                 self.as_ptr(),
                 BB::from_ptr(LLVMAppendBasicBlockInContext(
-                    LLVMGetTypeContext(self.ty().as_ptr()),
+                    self.declaration().context().as_ptr(),
                     self.as_ptr(),
                     crate::bindings::cstr(b"entrypoint\0").as_ptr(),
                 )),
@@ -46,9 +59,9 @@ impl<'a> FunctionDeclaration<'a> {
 }
 
 impl<'a> Function<'a> {
-    /// Get the type of this function.
-    pub fn ty(&self) -> FunctionType {
-        unsafe { FunctionType::from_ptr(LLVMTypeOf(self.as_ptr())) }
+    /// Get the associated declaration.
+    pub fn declaration(&self) -> FunctionDeclaration {
+        todo!()
     }
 
     /// Get the entry point of the function.
@@ -61,17 +74,12 @@ impl<'a> Function<'a> {
         &mut self.entry_point
     }
 
-    /// Get the LLVM context of the function.
-    pub fn context(&'a self) -> Context<'a> {
-        unsafe { Context::from_ptr(LLVMGetTypeContext(LLVMTypeOf(self.inner)), true) }
-    }
-
     /// Add a new basic block to the function and returns it.
     pub fn new_basic_block(&'a mut self) -> BB {
         self.basic_block_count += 1;
         unsafe {
             BB::from_ptr(LLVMAppendBasicBlockInContext(
-                self.context().as_ptr(),
+                self.declaration().context().as_ptr(),
                 self.inner,
                 cstr(format!("{}", self.basic_block_count).as_bytes()).as_ptr(),
             ))
@@ -93,13 +101,13 @@ impl<'a> Function<'a> {
     }
 
     /// Iterate over all the basic blocks of the function in a mutable way.
-    pub fn iter_mut(&'a mut self) -> FunctionIter<'a, &'a mut Function> {
+    pub fn iter_mut(&'a mut self) -> FunctionIterMut<'a, &'a mut Function> {
         unsafe {
             match LLVMCountBasicBlocks(self.as_ptr()) {
-                0 => FunctionIter { curr: None, last: None, func: self },
-                _ => FunctionIter {
-                    curr: Some(BB::from_ptr(LLVMGetFirstBasicBlock(self.inner))),
-                    last: Some(BB::from_ptr(LLVMGetLastBasicBlock(self.inner))),
+                0 => FunctionIterMut { curr: None, last: None, func: self },
+                _ => FunctionIterMut {
+                    curr: Some(BBMut::from_ptr(LLVMGetFirstBasicBlock(self.inner))),
+                    last: Some(BBMut::from_ptr(LLVMGetLastBasicBlock(self.inner))),
                     func: self,
                 },
             }
@@ -133,22 +141,56 @@ pub struct FunctionIter<'a, F: sealed::FuncRef<'a>> {
     func: F,
 }
 
-impl<'a, F: sealed::FuncRef<'a>> Iterator for FunctionIter<'a, F> {
-    type Item = BB<'a>;
+/// Iterate over the basic blocks in a function.
+pub struct FunctionIterMut<'a, F: sealed::FuncRef<'a>> {
+    curr: Option<BBMut<'a>>,
+    last: Option<BBMut<'a>>,
+    func: F,
+}
 
-    fn next(&mut self) -> Option<Self::Item> {
-        match (self.curr.take(), self.last.as_ref()) {
+macro_rules! next {
+    ($self:ident as $bb:ident) => {
+        match ($self.curr.take(), $self.last.as_ref()) {
             (Some(_), None) | (None, Some(_)) => unreachable!(),
             (None, None) => None,
             (Some(curr), Some(last)) if curr == *last => {
-                self.last = None;
+                $self.last = None;
                 Some(curr)
             }
             (Some(curr), Some(_)) => unsafe {
-                self.curr = Some(BB::from_ptr(LLVMGetNextBasicBlock(curr.as_ptr())));
+                $self.curr = Some($bb::from_ptr(LLVMGetNextBasicBlock(curr.as_ptr())));
                 Some(curr)
             },
         }
+    };
+}
+
+impl<'a, F: sealed::FuncRef<'a>> Iterator for FunctionIter<'a, F> {
+    type Item = BB<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        next!(self as BB)
+    }
+
+    fn last(self) -> Option<Self::Item> {
+        self.last
+    }
+
+    fn count(self) -> usize {
+        (unsafe { LLVMCountBasicBlocks(self.func.as_ref().as_ptr()) }) as usize
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let count = (unsafe { LLVMCountBasicBlocks(self.func.as_ref().as_ptr()) }) as usize;
+        (count, Some(count))
+    }
+}
+
+impl<'a, F: sealed::FuncRef<'a>> Iterator for FunctionIterMut<'a, F> {
+    type Item = BBMut<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        next!(self as BBMut)
     }
 
     fn last(self) -> Option<Self::Item> {
diff --git a/src/Rust/vvs_llvm/src/bindings/module.rs b/src/Rust/vvs_llvm/src/bindings/module.rs
index 4b0f4c32..2e5b9828 100644
--- a/src/Rust/vvs_llvm/src/bindings/module.rs
+++ b/src/Rust/vvs_llvm/src/bindings/module.rs
@@ -8,7 +8,7 @@ impl Module<'_> {
     /// Add a new function in this module.
     ///
     /// Note that by default all functions are private.
-    pub fn add_function(&mut self, name: impl AsRef<str>, ty: FunctionType) -> Result<FunctionDeclaration, NulError> {
+    pub fn add_function(&mut self, name: impl AsRef<str>, ty: FunctionType) -> Result<FunctionBuilder, NulError> {
         let function = unsafe { LLVMAddFunction(self.inner, CString::new(name.as_ref())?.as_ptr(), ty.as_ptr()) };
 
         // All the functions are private to the module!
@@ -17,7 +17,7 @@ impl Module<'_> {
             LLVMSetVisibility(function, LLVMVisibility::LLVMHiddenVisibility);
         }
 
-        Ok(unsafe { FunctionDeclaration::from_ptr(function) })
+        Ok(unsafe { FunctionBuilder::from_ptr(function) })
     }
 
     /// Get a function out of the a module as a declaration.
diff --git a/src/Rust/vvs_llvm/src/bindings/value.rs b/src/Rust/vvs_llvm/src/bindings/value.rs
index 98d8496d..08832609 100644
--- a/src/Rust/vvs_llvm/src/bindings/value.rs
+++ b/src/Rust/vvs_llvm/src/bindings/value.rs
@@ -1,7 +1,8 @@
 use crate::prelude::*;
 use llvm_sys::{core::*, prelude::*};
 
-crate::bindings::declare! { mut LLVMValueRef as Value<'a> }
+crate::bindings::declare! { const LLVMValueRef as Value<'a> }
+crate::bindings::declare! { mut   LLVMValueRef as ValueMut<'a> }
 
 impl Value<'_> {
     /// Get the type of the value.
@@ -26,9 +27,29 @@ impl Value<'_> {
     }
 }
 
+impl<'a> ValueMut<'a> {
+    /// Get a const reference to the internal value out of the mutable value.
+    pub fn into_ref(self) -> Value<'a> {
+        unsafe { Value::from_ptr(self.as_ptr()) }
+    }
+}
+
 impl Eq for Value<'_> {}
 impl PartialEq for Value<'_> {
     fn eq(&self, other: &Self) -> bool {
         self.inner == other.inner
     }
 }
+
+impl Eq for ValueMut<'_> {}
+impl PartialEq for ValueMut<'_> {
+    fn eq(&self, other: &Self) -> bool {
+        self.inner == other.inner
+    }
+}
+
+impl<'a> AsRef<Value<'a>> for ValueMut<'a> {
+    fn as_ref(&self) -> &Value<'a> {
+        todo!()
+    }
+}
-- 
GitLab