From 976206d74445dfc43edce46cc1ffae749170bf8a Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 14 Aug 2024 16:49:18 +0200
Subject: [PATCH] Put simplest functional RoiSet pipeline model_server for now;
 derived channel and boundary seg alternatives in trec-adaptive-feedback

---
 model_server/base/pipelines/roiset_obmap.py | 56 ++-------------------
 tests/test_ilastik/test_roiset_workflow.py  | 42 ++++++++--------
 2 files changed, 26 insertions(+), 72 deletions(-)

diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py
index 8a891b4d..38313675 100644
--- a/model_server/base/pipelines/roiset_obmap.py
+++ b/model_server/base/pipelines/roiset_obmap.py
@@ -7,8 +7,9 @@ from pydantic import Field
 
 from model_server.base.accessors import GenericImageDataAccessor, PatchStack
 from model_server.base.pipelines.params import PipelineParams, PipelineRecord
+from model_server.base.pipelines.segment import segment_pipeline
 from model_server.base.pipelines.util import call_pipeline
-from model_server.base.roiset import RoiSet, RoiSetMetaParams, RoiSetExportParams
+from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams
 from model_server.base.session import session
 from model_server.base.util import PipelineTrace
 
@@ -73,8 +74,6 @@ class RoiSetObjectMapParams(PipelineParams):
         None,
         description='Derived channels to send to object classifier; use all if empty'
     )
-    # TODO: move to subclassed pipeline
-    # label_params: Union[LabelFromBoundarySegParams, None] = None
     export_label_interm: bool = False
 
 
@@ -92,7 +91,6 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
     assert isinstance(rois, RoiSet)
     table = rois.get_serializable_dataframe()
 
-
     # TODO: instead, explicitly let trace include RoiSets
     session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table)
     ret = RoiSetToObjectMapRecord(
@@ -101,20 +99,6 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
     )
     return ret
 
-# def _segment(d: PipelineTrace, models: Dict[str, Model], **k) -> PipelineTrace:
-#     # MIP if no zmask z-index is given, then classify pixels
-#     zmi = k['zmask_zindex']
-#     sch = k['segmentation_channel']
-#
-#     if isinstance(zmi, int):
-#         assert 0 < zmi < d.last.nz
-#         d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi)
-#     else:
-#         d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True))
-#
-#     d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last)
-#     return d
-
 
 def roiset_object_map_pipeline(
         accessors: Dict[str, GenericImageDataAccessor],
@@ -124,10 +108,7 @@ def roiset_object_map_pipeline(
     if not isinstance(models['pixel_classifier_segmentation_model'], SemanticSegmentationModel):
         raise IncompatibleModelsError('Expecting a pixel classification model')
 
-    d = PipelineTrace()
-    d['raw'] = accessors['accessor']
-    # d = _segment(d, models, **k)
-
+    d = PipelineTrace(accessors['accessor'])
     zmi = k['zmask_zindex']
     sch = k['segmentation_channel']
 
@@ -138,28 +119,8 @@ def roiset_object_map_pipeline(
         d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True))
 
     d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last)
-
-    rois = RoiSet.from_binary_mask(d['raw'], d.last, RoiSetMetaParams(**k['roi_params']))
-
-    # TODO: split out derived channel logic
-    # # optionally derive additional inputs for object classification
-    # if dpmod := models.get('pixel_classifier_derived_model'):
-    #     ic = k['derived_channels_input_channel']
-    #     ocs = k['derived_channels_output_channels']
-    #     assert ic < d['raw'].chroma
-    #     if not isinstance(dpmod, SemanticSegmentationModel):
-    #         raise IncompatibleModelsError('Expecting pixel_classifier_derived to be a pixel classification model')
-    #
-    #     def _derive(acc: GenericImageDataAccessor, oc: int):
-    #         acc_mono = acc.get_channels([ic])
-    #         pxmap = dpmod.infer_patch_stack(acc_mono)[0]
-    #         assert oc < pxmap.chroma
-    #         return PatchStack((pxmap.get_channels([oc]).data * 255).astype('uint8'))
-    #     derived_channel_handles = [
-    #         lambda a: _derive(a, oc) for oc in ocs
-    #     ]
-    # else:
-    #     derived_channel_handles = None
+    d['segmentation'] = get_label_ids(d.last)
+    rois = RoiSet.from_object_ids(d['input'], d['segmentation'], RoiSetMetaParams(**k['roi_params']))
 
     # optionally run an object classifier if specified
     if obmod := models.get('object_classifier_model'):
@@ -169,16 +130,9 @@ def roiset_object_map_pipeline(
             obmod_name,
             [k['patches_channel']],
             obmod,
-            # derived_channel_functions=derived_channel_handles
         )
-        # TODO: split out derived channel logic
-        # # add derived channels to intermediate data
-        # for di, dacc in enumerate(rois.accs_derived):
-        #     d[f'derived_{di:02d}'] = dacc
-
         d[obmod_name] = rois.get_object_class_map(obmod_name)
 
-
     return d, rois
 
 class Error(Exception):
diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py
index 27fd0d7f..bfe51ed0 100644
--- a/tests/test_ilastik/test_roiset_workflow.py
+++ b/tests/test_ilastik/test_roiset_workflow.py
@@ -118,27 +118,27 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
 
 
-    def test_object_map_workflow_with_derived_channels(self):
-        acc_in = generate_file_accessor(self.fpi)
-        params = RoiSetObjectMapParams(
-            **self._pipeline_params(),
-            pixel_classifier_derived_model_id='id_px',
-            derived_channels_input_channel=0,
-            derived_channels_output_channels=[1],
-        )
-        models = self._get_models()
-        models['pixel_classifier_derived'] = models['pixel_classifier_segmentation']  # re-use same classifier
-
-        trace, _ = roiset_object_map_pipeline(
-            {'accessor': acc_in},
-            {f'{k}_model': v['model'] for k, v in models.items()},
-            **params.dict()
-        )
-
-        self.assertTrue('ob_id' in trace.keys())
-        self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
-        self.assertTrue('derived_00' in trace.keys())
-        self.assertEqual(trace['derived_00'].chroma, 1)
+    # def test_object_map_workflow_with_derived_channels(self):
+    #     acc_in = generate_file_accessor(self.fpi)
+    #     params = RoiSetObjectMapParams(
+    #         **self._pipeline_params(),
+    #         pixel_classifier_derived_model_id='id_px',
+    #         derived_channels_input_channel=0,
+    #         derived_channels_output_channels=[1],
+    #     )
+    #     models = self._get_models()
+    #     models['pixel_classifier_derived'] = models['pixel_classifier_segmentation']  # re-use same classifier
+    #
+    #     trace, _ = roiset_object_map_pipeline(
+    #         {'accessor': acc_in},
+    #         {f'{k}_model': v['model'] for k, v in models.items()},
+    #         **params.dict()
+    #     )
+    #
+    #     self.assertTrue('ob_id' in trace.keys())
+    #     self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
+    #     self.assertTrue('derived_00' in trace.keys())
+    #     self.assertEqual(trace['derived_00'].chroma, 1)
 
 
 class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
-- 
GitLab