From b34891526a695d43c6f2df6690dd1a5426d16665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9line?= <liuceline2001@yahoo.fr> Date: Thu, 2 Feb 2023 12:15:56 +0100 Subject: [PATCH] machine learning / yml --- .gitlab-ci.yml | 24 ++++++++++++++++++++++++ machine_learning.py | 21 +++++++++++++++++++++ test_prediction.py | 17 +++++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 .gitlab-ci.yml create mode 100644 machine_learning.py create mode 100644 test_prediction.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000..4386ca5 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,24 @@ +image: python:3.9 + +stages: + - build + - test + - export + +# Build stage +build: + stage: build + script: + - python machine_learning.py + +# Test stage +test: + stage: test + script: + - pytest test_prediction.py + +# Export stage +export: + stage: export + script: + - mv digits_model.joblib artifacts/ \ No newline at end of file diff --git a/machine_learning.py b/machine_learning.py new file mode 100644 index 0000000..571411c --- /dev/null +++ b/machine_learning.py @@ -0,0 +1,21 @@ +from joblib import dump +from sklearn import svm +from sklearn import datasets + + +def train_model(): + model = svm.SVC() + X, y = datasets.load_digits(return_X_y=True) + model.fit(X, y) + return model + +def export_model(model): + dump(model, './digits_model.joblib') + +def start(): + model = train_model() + export_model(model) + print('Model successfully exported.') + +if __name__ == '__main__': + start() diff --git a/test_prediction.py b/test_prediction.py new file mode 100644 index 0000000..5721ce8 --- /dev/null +++ b/test_prediction.py @@ -0,0 +1,17 @@ +import numpy as np +import matplotlib.pyplot as plt +from sklearn import datasets +from machine_learning import train_model + + +def test_inference_sample(): + model = train_model() + X, _ = datasets.load_digits(return_X_y=True) + prediction = model.predict(X[0:1])[0] + assert prediction == 0 + +def test_inference_batch(): + model = train_model() + X, _ = datasets.load_digits(return_X_y=True) + predictions = model.predict(X[0:100]) + assert np.all(predictions < 10) \ No newline at end of file -- GitLab