From a0b7809f5d63247ad61047eb2ec2cdfa64ea7229 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 31 Oct 2023 15:09:07 +0100
Subject: [PATCH] Merge in ilastik API support for segmentation-to-object
 workflow

---
 .gitignore                               |  4 +-
 conf/testing.py                          |  7 +-
 extensions/ilastik/router.py             | 20 +++---
 extensions/ilastik/tests/test_ilastik.py | 82 ++++++++++++++----------
 4 files changed, 67 insertions(+), 46 deletions(-)

diff --git a/.gitignore b/.gitignore
index f5570b3f..7efb7f0b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,4 @@
 */.idea/*
-*__pycache__*
\ No newline at end of file
+*__pycache__*
+/clients/imagej/.idea/workspace.xml
+/clients/imagej/.idea/
diff --git a/conf/testing.py b/conf/testing.py
index 3c52b95c..1c2b0cf3 100644
--- a/conf/testing.py
+++ b/conf/testing.py
@@ -53,9 +53,10 @@ monozstackmask = {
     'z': 85
 }
 
-ilastik = {
-    'pixel_classifier': 'demo_px.ilp',
-    'object_classifier': 'demo_obj.ilp',
+ilastik_classifiers = {
+    'px': 'demo_px.ilp',
+    'pxmap_to_obj': 'demo_obj.ilp',
+    'seg_to_obj': 'new_auto_obj.ilp',
 }
 
 output_path = root / 'testing_output'
diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py
index 3287dc47..ba01d1b9 100644
--- a/extensions/ilastik/router.py
+++ b/extensions/ilastik/router.py
@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException
 from model_server.session import Session
 from model_server.validators import  validate_workflow_inputs
 
-from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel
+from extensions.ilastik import models as ilm
 from model_server.models import ParameterExpectedError
 from extensions.ilastik.workflows import infer_px_then_ob_model
 
@@ -14,7 +14,7 @@ router = APIRouter(
 
 session = Session()
 
-def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
+def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
     """
     Load an ilastik model of a given class and project filename.
     :param model_class:
@@ -40,13 +40,17 @@ def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str,
         )
     return result
 
-@router.put('/pixel_classification/load/')
-def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict:
-    return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate)
+@router.put('/px/load/')
+def load_px_model(project_file: str, duplicate: bool = True) -> dict:
+    return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate)
 
-@router.put('/object_classification/load/')
-def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
-    return load_ilastik_model(IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate)
+@router.put('/pxmap_to_obj/load/')
+def load_pxmap_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
+    return load_ilastik_model(ilm.IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate)
+
+@router.put('/seg_to_obj/load/')
+def load_seg_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
+    return load_ilastik_model(ilm.IlastikObjectClassifierFromSegmentationModel, project_file, duplicate=duplicate)
 
 @router.put('/pixel_then_object_classification/infer')
 def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index 217b00c9..d2498af2 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -3,15 +3,15 @@ import unittest
 
 import numpy as np
 
-import conf.testing
+from conf.testing import czifile, ilastik_classifiers, output_path
 from model_server.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
-from extensions.ilastik.models import IlastikObjectClassifierFromPixelPredictionsModel, IlastikPixelClassifierModel
+from extensions.ilastik import models as ilm
 from model_server.workflows import infer_image_to_image
 from tests.test_api import TestServerBaseClass
 
 class TestIlastikPixelClassification(unittest.TestCase):
     def setUp(self) -> None:
-        self.cf = CziImageFileAccessor(conf.testing.czifile['path'])
+        self.cf = CziImageFileAccessor(czifile['path'])
 
 
     def test_faulthandler(self): # recreate error that is messing up ilastik
@@ -24,8 +24,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
 
 
     def test_raise_error_if_autoload_disabled(self):
-        model = IlastikPixelClassifierModel(
-            {'project_file': conf.testing.ilastik['pixel_classifier']},
+        model = ilm.IlastikPixelClassifierModel(
+            {'project_file': ilastik_classifiers['px']},
             autoload=False
         )
         w = 512
@@ -38,8 +38,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
 
 
     def test_run_pixel_classifier_on_random_data(self):
-        model = IlastikPixelClassifierModel(
-            {'project_file': conf.testing.ilastik['pixel_classifier']},
+        model = ilm.IlastikPixelClassifierModel(
+            {'project_file': ilastik_classifiers['px']},
         )
         w = 512
         h = 256
@@ -52,16 +52,16 @@ class TestIlastikPixelClassification(unittest.TestCase):
 
     def test_run_pixel_classifier(self):
         channel = 0
-        model = IlastikPixelClassifierModel(
-            {'project_file': conf.testing.ilastik['pixel_classifier']}
+        model = ilm.IlastikPixelClassifierModel(
+            {'project_file': ilastik_classifiers['px']}
         )
         cf = CziImageFileAccessor(
-            conf.testing.czifile['path']
+            czifile['path']
         )
         mono_image = cf.get_one_channel_data(channel)
 
-        self.assertEqual(mono_image.shape_dict['X'], conf.testing.czifile['w'])
-        self.assertEqual(mono_image.shape_dict['Y'], conf.testing.czifile['h'])
+        self.assertEqual(mono_image.shape_dict['X'], czifile['w'])
+        self.assertEqual(mono_image.shape_dict['Y'], czifile['h'])
         self.assertEqual(mono_image.shape_dict['C'], 1)
         self.assertEqual(mono_image.shape_dict['Z'], 1)
 
@@ -72,7 +72,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.assertEqual(pxmap.shape_dict['Z'], 1)
         self.assertTrue(
             write_accessor_data_to_file(
-                conf.testing.output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
+                output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
                 pxmap
             )
         )
@@ -82,15 +82,15 @@ class TestIlastikPixelClassification(unittest.TestCase):
 
     def test_run_object_classifier(self):
         self.test_run_pixel_classifier()
-        fp = conf.testing.czifile['path']
-        model = IlastikObjectClassifierFromPixelPredictionsModel(
-            {'project_file': conf.testing.ilastik['object_classifier']}
+        fp = czifile['path']
+        model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
+            {'project_file': ilastik_classifiers['pxmap_to_obj']}
         )
         objmap, _ = model.infer(self.mono_image, self.pxmap)
 
         self.assertTrue(
             write_accessor_data_to_file(
-                conf.testing.output_path / f'obmap_{fp.stem}.tif',
+                output_path / f'obmap_{fp.stem}.tif',
                 objmap,
             )
         )
@@ -98,11 +98,11 @@ class TestIlastikPixelClassification(unittest.TestCase):
 
     def test_ilastik_pixel_classification_as_workflow(self):
         result = infer_image_to_image(
-            conf.testing.czifile['path'],
-            IlastikPixelClassifierModel(
-                {'project_file': conf.testing.ilastik['pixel_classifier']}
+            czifile['path'],
+            ilm.IlastikPixelClassifierModel(
+                {'project_file': ilastik_classifiers['px']}
             ),
-            conf.testing.output_path,
+            output_path,
             channel=0,
         )
         self.assertTrue(result.success)
@@ -112,7 +112,7 @@ class TestIlastikOverApi(TestServerBaseClass):
 
     def test_httpexception_if_incorrect_project_file_loaded(self):
         resp_load = requests.put(
-            self.uri + 'ilastik/pixel_classification/load/',
+            self.uri + 'ilastik/px/load/',
             params={'project_file': 'improper.ilp'},
         )
         self.assertEqual(resp_load.status_code, 404)
@@ -120,8 +120,8 @@ class TestIlastikOverApi(TestServerBaseClass):
 
     def test_load_ilastik_pixel_model(self):
         resp_load = requests.put(
-            self.uri + 'ilastik/pixel_classification/load/',
-            params={'project_file': str(conf.testing.ilastik['pixel_classifier'])},
+            self.uri + 'ilastik/px/load/',
+            params={'project_file': str(ilastik_classifiers['px'])},
         )
         model_id = resp_load.json()['model_id']
 
@@ -137,18 +137,18 @@ class TestIlastikOverApi(TestServerBaseClass):
         resp_list_1st = requests.get(self.uri + 'models').json()
         self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
         resp_load_2nd = requests.put(
-            self.uri + 'ilastik/pixel_classification/load/',
+            self.uri + 'ilastik/px/load/',
             params={
-                'project_file': str(conf.testing.ilastik['pixel_classifier']),
+                'project_file': str(ilastik_classifiers['px']),
                 'duplicate': True,
             },
         )
         resp_list_2nd = requests.get(self.uri + 'models').json()
         self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
         resp_load_3rd = requests.put(
-            self.uri + 'ilastik/pixel_classification/load/',
+            self.uri + 'ilastik/px/load/',
             params={
-                'project_file': str(conf.testing.ilastik['pixel_classifier']),
+                'project_file': str(ilastik_classifiers['px']),
                 'duplicate': False,
             },
         )
@@ -156,10 +156,10 @@ class TestIlastikOverApi(TestServerBaseClass):
         self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
 
 
-    def test_load_ilastik_object_model(self):
+    def test_load_ilastik_pxmap_to_obj_model(self):
         resp_load = requests.put(
-            self.uri + 'ilastik/object_classification/load/',
-            params={'project_file': str(conf.testing.ilastik['object_classifier'])},
+            self.uri + 'ilastik/pxmap_to_obj/load/',
+            params={'project_file': str(ilastik_classifiers['pxmap_to_obj'])},
         )
         model_id = resp_load.json()['model_id']
 
@@ -170,6 +170,20 @@ class TestIlastikOverApi(TestServerBaseClass):
         self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel')
         return model_id
 
+    def test_load_ilastik_seg_to_obj_model(self):
+        resp_load = requests.put(
+            self.uri + 'ilastik/seg_to_obj/load/',
+            params={'project_file': str(ilastik_classifiers['seg_to_obj'])},
+        )
+        model_id = resp_load.json()['model_id']
+
+        self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        resp_list = requests.get(self.uri + 'models')
+        self.assertEqual(resp_list.status_code, 200)
+        rj = resp_list.json()
+        self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromSegmentationModel')
+        return model_id
+
     def test_ilastik_infer_pixel_probability(self):
         self.copy_input_file_to_server()
         model_id = self.test_load_ilastik_pixel_model()
@@ -178,7 +192,7 @@ class TestIlastikOverApi(TestServerBaseClass):
             self.uri + f'infer/from_image_file',
             params={
                 'model_id': model_id,
-                'input_filename': conf.testing.czifile['filename'],
+                'input_filename': czifile['filename'],
                 'channel': 0,
             },
         )
@@ -187,14 +201,14 @@ class TestIlastikOverApi(TestServerBaseClass):
     def test_ilastik_infer_px_then_ob(self):
         self.copy_input_file_to_server()
         px_model_id = self.test_load_ilastik_pixel_model()
-        ob_model_id = self.test_load_ilastik_object_model()
+        ob_model_id = self.test_load_ilastik_pxmap_to_obj_model()
 
         resp_infer = requests.put(
             self.uri + f'ilastik/pixel_then_object_classification/infer/',
             params={
                 'px_model_id': px_model_id,
                 'ob_model_id': ob_model_id,
-                'input_filename': conf.testing.czifile['filename'],
+                'input_filename': czifile['filename'],
                 'channel': 0,
             }
         )
-- 
GitLab