From f8dd4aa2bd5ca7b62f8aedae242d6102398d7893 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 24 Aug 2023 13:29:31 +0200
Subject: [PATCH] Abstractified Model base class

---
 model_server/api.py   | 27 +--------------------------
 model_server/model.py | 18 ++++++++++++++----
 tests/test_model.py   |  4 ++++
 3 files changed, 19 insertions(+), 30 deletions(-)

diff --git a/model_server/api.py b/model_server/api.py
index 2289fea4..47471e32 100644
--- a/model_server/api.py
+++ b/model_server/api.py
@@ -47,29 +47,4 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
         channel=channel
     )
     session.record_workflow_run(record)
-    return record
-
-    # model = session.models[model_id]
-    #
-    # # read image file into memory
-    # # maybe this isn't accurate if e.g. czifile loads lazily
-    # t0 = time()
-    # img = generate_file_accessor(session.inbound / imgf)
-    # dt_fi = time() - t0
-    #
-    # # run model inference
-    # outdata, record = model.infer(img.data, channel=channel)
-    # dt_inf = time() - t0
-    #
-    # # write output to file
-    # outpath = session.outbound / img.fpath.stem / '.tif'
-    # WriteableTiffFileAccessor(outpath).write(outdata)
-    # dt_fo = time() - t0
-    #
-    # record['output_file'] = outpath
-    # record['times'] = {
-    #     'file_input': dt_fi,
-    #     'inference': dt_inf - dt_fi,
-    #     'file_output': dt_fo - dt_fi - dt_inf
-    # }
-    # session.register_inference(record)
\ No newline at end of file
+    return record
\ No newline at end of file
diff --git a/model_server/model.py b/model_server/model.py
index 416c0055..c9c3140b 100644
--- a/model_server/model.py
+++ b/model_server/model.py
@@ -1,11 +1,11 @@
+from abc import ABC, abstractmethod
 from math import floor
 
 import numpy as np
 
 from model_server.image import GenericImageFileAccessor
 
-# TODO: properly abstractify base class
-class Model(object):
+class Model(ABC):
 
     def __init__(self, autoload=True):
         """
@@ -15,10 +15,11 @@ class Model(object):
         """
         self.autoload = autoload
 
-    # abstract
+    @abstractmethod
     def load(self):
         pass
 
+    @abstractmethod
     def infer(self,
               img: GenericImageFileAccessor,
               channel: int = None
@@ -33,7 +34,16 @@ class Model(object):
     def reload(self):
         self.load()
 
-class ImageToImageModel(Model): # receives an image and returns an image of the same size
+class ImageToImageModel(Model):
+
+    def __init__(self, **kwargs):
+        """
+        Abstract class for models that receive an image and return an image of the same size
+        :param kwargs: variable length keyword arguments
+        """
+        return super(**kwargs)
+
+    @abstractmethod
     def infer(self, img, channel=None) -> (np.ndarray, dict):
         super().infer(img, channel)
 
diff --git a/tests/test_model.py b/tests/test_model.py
index 43568997..1920ea3b 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -7,6 +7,10 @@ class TestCziImageFileAccess(unittest.TestCase):
     def setUp(self) -> None:
         self.cf = CziImageFileAccessor(czifile['path'])
 
+    def test_instantiate_model_with_nondefault_kwarg(self):
+        model = DummyImageToImageModel(autoload=False)
+        self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
+
     def test_czifile_is_correct_shape(self):
         model = DummyImageToImageModel()
         img, _ = model.infer(self.cf, channel=1)
-- 
GitLab