From b34891526a695d43c6f2df6690dd1a5426d16665 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9line?= <liuceline2001@yahoo.fr>
Date: Thu, 2 Feb 2023 12:15:56 +0100
Subject: [PATCH] machine learning / yml

---
 .gitlab-ci.yml      | 24 ++++++++++++++++++++++++
 machine_learning.py | 21 +++++++++++++++++++++
 test_prediction.py  | 17 +++++++++++++++++
 3 files changed, 62 insertions(+)
 create mode 100644 .gitlab-ci.yml
 create mode 100644 machine_learning.py
 create mode 100644 test_prediction.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 0000000..4386ca5
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,24 @@
+image: python:3.9
+
+stages:
+  - build
+  - test
+  - export
+
+# Build stage
+build:
+  stage: build
+  script:
+    - python machine_learning.py
+
+# Test stage
+test:
+  stage: test
+  script:
+    - pytest test_prediction.py
+
+# Export stage
+export:
+  stage: export
+  script:
+    - mv digits_model.joblib artifacts/
\ No newline at end of file
diff --git a/machine_learning.py b/machine_learning.py
new file mode 100644
index 0000000..571411c
--- /dev/null
+++ b/machine_learning.py
@@ -0,0 +1,21 @@
+from joblib import dump
+from sklearn import svm
+from sklearn import datasets
+
+
+def train_model():
+    model = svm.SVC()
+    X, y = datasets.load_digits(return_X_y=True)
+    model.fit(X, y)
+    return model
+
+def export_model(model):
+    dump(model, './digits_model.joblib')
+
+def start():
+    model = train_model()
+    export_model(model)
+    print('Model successfully exported.')
+
+if __name__ == '__main__':
+    start()
diff --git a/test_prediction.py b/test_prediction.py
new file mode 100644
index 0000000..5721ce8
--- /dev/null
+++ b/test_prediction.py
@@ -0,0 +1,17 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn import datasets
+from machine_learning import train_model
+
+
+def test_inference_sample():
+    model = train_model()
+    X, _ = datasets.load_digits(return_X_y=True)
+    prediction = model.predict(X[0:1])[0]
+    assert prediction == 0
+
+def test_inference_batch():
+    model = train_model()
+    X, _ = datasets.load_digits(return_X_y=True)
+    predictions = model.predict(X[0:100])
+    assert np.all(predictions < 10)
\ No newline at end of file
-- 
GitLab