From 9010a5a196f2fe8a7325ef22c4f2d56677324d4d Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 30 Aug 2023 16:44:59 +0200
Subject: [PATCH] Stuck on type validation in test. def
 test_load_ilastik_model()

---
 api.py                         | 13 ++++++++++---
 model_server/ilastik.py        | 32 +++++++++++++++++++-------------
 model_server/model.py          | 11 -----------
 model_server/model_registry.py | 17 +++++++++++++++++
 model_server/session.py        |  3 ++-
 run_server.py                  |  7 +++++++
 tests/test_api.py              | 12 +++++++++++-
 tests/test_ilastik.py          | 25 ++++++++++++++++++++++++-
 8 files changed, 90 insertions(+), 30 deletions(-)
 create mode 100644 model_server/model_registry.py
 create mode 100644 run_server.py

diff --git a/api.py b/api.py
index 32eea187..0167a17c 100644
--- a/api.py
+++ b/api.py
@@ -2,7 +2,7 @@ from typing import Dict
 
 from fastapi import FastAPI, HTTPException
 
-from model_server.session import Session
+from model_server.session import CouldNotFindModelError, Session
 from model_server.workflow import infer_image_to_image
 
 app = FastAPI(debug=True)
@@ -21,13 +21,20 @@ def list_active_models():
     return session.describe_loaded_models()
 
 @app.put('/models/load/')
-def load_model(model_id: str, params: Dict[str, str] = None) -> dict:
+# def load_model(model_id: str, misc: Dict[str, str]) -> dict:
+def load_model(model_id: str, misc: dict) -> dict:
     if model_id in session.models.keys():
         raise HTTPException(
             status_code=409,
             detail=f'Model with id {model_id} has already been loaded'
         )
-    session.load_model(model_id, params=params)
+    try:
+        session.load_model(model_id, params=misc)
+    except CouldNotFindModelError:
+        raise HTTPException(
+            status_code=404,
+            detail=f'Could not find {model_id} in defined models'
+        )
     return session.describe_loaded_models()
 
 @app.put('/i2i/infer/')
diff --git a/model_server/ilastik.py b/model_server/ilastik.py
index 99a812a7..606c93d4 100644
--- a/model_server/ilastik.py
+++ b/model_server/ilastik.py
@@ -1,11 +1,5 @@
 import os
 
-from ilastik import app
-from ilastik.applets.dataSelection import DatasetInfo, FilesystemDatasetInfo
-from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo
-from ilastik.workflows.pixelClassification import PixelClassificationWorkflow
-from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflow
-
 import numpy as np
 import vigra
 
@@ -20,11 +14,15 @@ class IlastikImageToImageModel(ImageToImageModel):
             raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
         self.project_file = str(params['project_file'])
         self.shell = None
-        # self.operator = None
         super().__init__(autoload, params)
 
 
     def load(self):
+        from ilastik import app
+        from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo
+
+        self.PreloadedArrayDatasetInfo = PreloadedArrayDatasetInfo
+
         os.environ["LAZYFLOW_THREADS"] = "8"
         os.environ["LAZYFLOW_TOTAL_RAM_MB"] = "24000"
 
@@ -33,7 +31,7 @@ class IlastikImageToImageModel(ImageToImageModel):
         args.project = self.project_file
         shell = app.main(args)
 
-        if not isinstance(shell.workflow, self.workflow):
+        if not isinstance(shell.workflow, self.get_workflow()):
             raise ParameterExpectedError(
                 f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
             )
@@ -44,13 +42,17 @@ class IlastikImageToImageModel(ImageToImageModel):
 
 class IlastikPixelClassifierModel(IlastikImageToImageModel):
     model_id = 'ilastik_pixel_classification'
-    workflow = PixelClassificationWorkflow
+
+    @staticmethod
+    def get_workflow():
+        from ilastik.workflows.pixelClassification import PixelClassificationWorkflow
+        return PixelClassificationWorkflow
 
     def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict):
         tagged_input_data = vigra.taggedView(input_img.data, 'xycz')
         dsi = [
             {
-                'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
+                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
             }
         ]
         pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [1 x w x h x n]
@@ -66,7 +68,11 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
 
 class IlastikObjectClassifierModel(IlastikImageToImageModel):
     model_id = 'ilastik_object_classification'
-    workflow = ObjectClassificationWorkflow
+
+    @staticmethod
+    def get_workflow():
+        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflow
+        return ObjectClassificationWorkflow
 
     def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
         tagged_input_data = vigra.taggedView(input_img.data, 'xycz')
@@ -74,8 +80,8 @@ class IlastikObjectClassifierModel(IlastikImageToImageModel):
 
         dsi = [
             {
-                'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
-                'Prediction Maps': PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
+                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
+                'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
             }
         ]
 
diff --git a/model_server/model.py b/model_server/model.py
index 7ee63af4..a2d75c7b 100644
--- a/model_server/model.py
+++ b/model_server/model.py
@@ -27,17 +27,6 @@ class Model(ABC):
             raise CouldNotLoadModelError()
         return None
 
-    @classmethod
-    def get_all_subclasses(cls):
-        """
-        Recursively find all subclasses of Model
-        :return: set of all subclasses of Model
-        """
-        def get_all_subclasses_of(cc):
-            return set(cc.__subclasses__()).union(
-                [s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)])
-        return get_all_subclasses_of(cls)
-
     @abstractmethod
     def load(self):
         """
diff --git a/model_server/model_registry.py b/model_server/model_registry.py
new file mode 100644
index 00000000..f6686dc0
--- /dev/null
+++ b/model_server/model_registry.py
@@ -0,0 +1,17 @@
+import model_server.ilastik
+import model_server.model
+
+def get_all_model_subclasses():
+    """
+    Recursively find all subclasses of Model
+    :return: set of all subclasses of Model
+    """
+
+    def get_all_subclasses_of(cc):
+        return set(cc.__subclasses__()).union(
+            [s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)])
+
+    return get_all_subclasses_of(model_server.model.Model)
+
+if __name__ == '__main__':
+    print(get_all_model_subclasses())
\ No newline at end of file
diff --git a/model_server/session.py b/model_server/session.py
index 7508f103..69a94fcb 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -7,6 +7,7 @@ from typing import Dict
 
 from conf.server import paths
 from model_server.model import Model
+from model_server.model_registry import get_all_model_subclasses
 from model_server.share import SharedImageDirectory
 from model_server.workflow import WorkflowRunRecord
 
@@ -70,7 +71,7 @@ class Session(object):
         :param params: optional parameters that are passed upon loading a model
         :return: True if model successfully loaded, False if not
         """
-        models = Model.get_all_subclasses()
+        models = get_all_model_subclasses()
         for mc in models:
             if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
                 try:
diff --git a/run_server.py b/run_server.py
new file mode 100644
index 00000000..1e90251c
--- /dev/null
+++ b/run_server.py
@@ -0,0 +1,7 @@
+import uvicorn
+
+host = '127.0.0.1'
+port = 8001
+
+if __name__ == '__main__':
+    uvicorn.run('api:app', **{'host': host, 'port': port, 'log_level': 'debug'}, reload=True)
\ No newline at end of file
diff --git a/tests/test_api.py b/tests/test_api.py
index 0b2e844b..bcb366ec 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -6,7 +6,7 @@ import unittest
 from conf.testing import czifile
 from model_server.model import DummyImageToImageModel
 
-class TestApiFromAutomatedClient(unittest.TestCase):
+class TestServerBaseClass(unittest.TestCase):
     def setUp(self) -> None:
         import uvicorn
         host = '127.0.0.1'
@@ -37,6 +37,7 @@ class TestApiFromAutomatedClient(unittest.TestCase):
     def tearDown(self) -> None:
         self.server_process.terminate()
 
+class TestApiFromAutomatedClient(TestServerBaseClass):
     def test_trivial_api_response(self):
         resp = requests.get(self.uri, )
         self.assertEqual(resp.status_code, 200)
@@ -59,6 +60,15 @@ class TestApiFromAutomatedClient(unittest.TestCase):
         self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
         return model_id
 
+    def test_respond_with_error_when_invalid_model_loaded(self):
+        model_id = 'not_a_real_model'
+        resp = requests.put(
+            self.uri + f'models/load',
+            params={'model_id': model_id}
+        )
+        self.assertEqual(resp.status_code, 404)
+        print(resp.content)
+
     def test_respond_with_error_when_invalid_filepath_requested(self):
         model_id = self.test_load_dummy_model()
         resp = requests.put(
diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py
index a2d204e6..643225b2 100644
--- a/tests/test_ilastik.py
+++ b/tests/test_ilastik.py
@@ -1,3 +1,4 @@
+import requests
 import unittest
 
 import numpy as np
@@ -5,7 +6,9 @@ import numpy as np
 from conf.testing import czifile, ilastik, output_path
 from model_server.image import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.ilastik import IlastikObjectClassifierModel, IlastikPixelClassifierModel
+from model_server.model import Model
 from model_server.workflow import infer_image_to_image
+from tests.test_api import TestServerBaseClass
 
 class TestIlastikPixelClassification(unittest.TestCase):
     def setUp(self) -> None:
@@ -31,6 +34,10 @@ class TestIlastikPixelClassification(unittest.TestCase):
         with self.assertRaises(AttributeError):
             pxmap , _= model.infer(input_img)
 
+    def test_ilastik_subclasses_are_found(self):
+        self.assertIn(IlastikPixelClassifierModel, Model.get_all_subclasses())
+        self.assertIn(IlastikObjectClassifierModel, Model.get_all_subclasses())
+
     def test_run_pixel_classifier_on_random_data(self):
         model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
         w = 512
@@ -89,4 +96,20 @@ class TestIlastikPixelClassification(unittest.TestCase):
             channel=0,
         )
         self.assertTrue(result.success)
-        self.assertGreater(result.timer_results['inference'], 1.0)
\ No newline at end of file
+        self.assertGreater(result.timer_results['inference'], 1.0)
+
+class TestIlastikOverApi(TestServerBaseClass):
+    def test_load_ilastik_model(self):
+        model_id = IlastikPixelClassifierModel.model_id
+        resp_load = requests.put(
+            self.uri + f'models/load',
+            params={'model_id': model_id},
+            # data={'project_file': str(ilastik['pixel_classifier'])},
+            data={'project_file': 'hii',},
+        )
+        self.assertEqual(resp_load.status_code, 200, resp_load.content)
+        # resp_list = requests.get(self.uri + 'models')
+        # self.assertEqual(resp_list.status_code, 200)
+        # rj = resp_list.json()
+        # self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
+        # return model_id
\ No newline at end of file
-- 
GitLab