From b9f27095a0ca3494043beed30afb927fe66add30 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 31 Aug 2023 16:33:41 +0200
Subject: [PATCH] Can now load ilastik models by API

---
 api.py                | 36 ++++++++++++++----------------------
 tests/test_api.py     |  1 -
 tests/test_ilastik.py | 39 ++++++++++++++++++++++++---------------
 3 files changed, 38 insertions(+), 38 deletions(-)

diff --git a/api.py b/api.py
index 1c168b44..ea23b59e 100644
--- a/api.py
+++ b/api.py
@@ -27,30 +27,22 @@ def load_dummy_model() -> dict:
     return {'model_id': session.load_model(DummyImageToImageModel)}
 
 @app.put('/models/ilastik/pixel_classification/load/')
-def load_ilastik_pixel_classification_model(params: str) -> dict:
-    return {'model_id': session.load_model(IlastikPixelClassifierModel, params)}
+def load_ilastik_pixel_classification_model(project_file: str) -> dict:
+    return {
+        'model_id': session.load_model(
+            IlastikPixelClassifierModel,
+            {'project_file': project_file}
+        )
+    }
 
 @app.put('/models/ilastik/object_classification/load/')
-def load_ilastik_object_classification_model(params: str) -> dict:
-    return {'model_id': session.load_model(IlastikObjectClassifierModel, params)}
-
-# @app.put('/models/ilastik/pixel_classification/load/')
-# def infer_ilastik_pixel_classification_from_file(input_filename: str, channel: int = None) -> dict:
-#     inpath = session.inbound.path / input_filename
-#     if not inpath.exists():
-#         raise HTTPException(
-#             status_code=404,
-#             detail=f'Could not find file:\n{inpath}'
-#         )
-#
-#     record = infer_image_to_image(
-#         inpath,
-#         session.models[model_id]['object'],
-#         session.outbound.path,
-#         channel=channel,
-#     )
-#     session.record_workflow_run(record)
-#     return record
+def load_ilastik_object_classification_model(project_file: str) -> dict:
+    return {
+        'model_id': session.load_model(
+            IlastikObjectClassifierModel,
+            {'project_file': project_file}
+        )
+    }
 
 @app.put('/infer/from_image_file')
 def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
diff --git a/tests/test_api.py b/tests/test_api.py
index c61e4c2b..e4091cb2 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -47,7 +47,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertEqual(resp.content, b'{}')
 
     def test_load_dummy_model(self):
-        # model_key = DummyImageToImageModel.__name__ + '_00'
         resp_load = requests.put(
             self.uri + f'models/dummy/load',
         )
diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py
index 643225b2..8b687874 100644
--- a/tests/test_ilastik.py
+++ b/tests/test_ilastik.py
@@ -34,9 +34,6 @@ class TestIlastikPixelClassification(unittest.TestCase):
         with self.assertRaises(AttributeError):
             pxmap , _= model.infer(input_img)
 
-    def test_ilastik_subclasses_are_found(self):
-        self.assertIn(IlastikPixelClassifierModel, Model.get_all_subclasses())
-        self.assertIn(IlastikObjectClassifierModel, Model.get_all_subclasses())
 
     def test_run_pixel_classifier_on_random_data(self):
         model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
@@ -48,6 +45,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         pxmap, _ = model.infer(input_img)
         self.assertEqual(pxmap.shape, (w, h, 2, 1))
 
+
     def test_run_pixel_classifier(self):
         channel = 0
         model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
@@ -99,17 +97,28 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.assertGreater(result.timer_results['inference'], 1.0)
 
 class TestIlastikOverApi(TestServerBaseClass):
-    def test_load_ilastik_model(self):
-        model_id = IlastikPixelClassifierModel.model_id
+    def test_load_ilastik_pixel_model(self):
+        resp_load = requests.put(
+            self.uri + 'models/ilastik/pixel_classification/load/',
+            params={'project_file': str(ilastik['pixel_classifier'])},
+        )
+        model_id = resp_load.json()['model_id']
+
+        self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        resp_list = requests.get(self.uri + 'models')
+        self.assertEqual(resp_list.status_code, 200)
+        rj = resp_list.json()
+        self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel')
+
+    def test_load_ilastik_object_model(self):
         resp_load = requests.put(
-            self.uri + f'models/load',
-            params={'model_id': model_id},
-            # data={'project_file': str(ilastik['pixel_classifier'])},
-            data={'project_file': 'hii',},
+            self.uri + 'models/ilastik/object_classification/load/',
+            params={'project_file': str(ilastik['object_classifier'])},
         )
-        self.assertEqual(resp_load.status_code, 200, resp_load.content)
-        # resp_list = requests.get(self.uri + 'models')
-        # self.assertEqual(resp_list.status_code, 200)
-        # rj = resp_list.json()
-        # self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
-        # return model_id
\ No newline at end of file
+        model_id = resp_load.json()['model_id']
+
+        self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        resp_list = requests.get(self.uri + 'models')
+        self.assertEqual(resp_list.status_code, 200)
+        rj = resp_list.json()
+        self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel')
\ No newline at end of file
-- 
GitLab