From 976206d74445dfc43edce46cc1ffae749170bf8a Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 14 Aug 2024 16:49:18 +0200 Subject: [PATCH] Put simplest functional RoiSet pipeline model_server for now; derived channel and boundary seg alternatives in trec-adaptive-feedback --- model_server/base/pipelines/roiset_obmap.py | 56 ++------------------- tests/test_ilastik/test_roiset_workflow.py | 42 ++++++++-------- 2 files changed, 26 insertions(+), 72 deletions(-) diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 8a891b4d..38313675 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -7,8 +7,9 @@ from pydantic import Field from model_server.base.accessors import GenericImageDataAccessor, PatchStack from model_server.base.pipelines.params import PipelineParams, PipelineRecord +from model_server.base.pipelines.segment import segment_pipeline from model_server.base.pipelines.util import call_pipeline -from model_server.base.roiset import RoiSet, RoiSetMetaParams, RoiSetExportParams +from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams from model_server.base.session import session from model_server.base.util import PipelineTrace @@ -73,8 +74,6 @@ class RoiSetObjectMapParams(PipelineParams): None, description='Derived channels to send to object classifier; use all if empty' ) - # TODO: move to subclassed pipeline - # label_params: Union[LabelFromBoundarySegParams, None] = None export_label_interm: bool = False @@ -92,7 +91,6 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: assert isinstance(rois, RoiSet) table = rois.get_serializable_dataframe() - # TODO: instead, explicitly let trace include RoiSets session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table) ret = RoiSetToObjectMapRecord( @@ -101,20 +99,6 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: ) return ret -# def _segment(d: PipelineTrace, models: Dict[str, Model], **k) -> PipelineTrace: -# # MIP if no zmask z-index is given, then classify pixels -# zmi = k['zmask_zindex'] -# sch = k['segmentation_channel'] -# -# if isinstance(zmi, int): -# assert 0 < zmi < d.last.nz -# d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi) -# else: -# d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True)) -# -# d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last) -# return d - def roiset_object_map_pipeline( accessors: Dict[str, GenericImageDataAccessor], @@ -124,10 +108,7 @@ def roiset_object_map_pipeline( if not isinstance(models['pixel_classifier_segmentation_model'], SemanticSegmentationModel): raise IncompatibleModelsError('Expecting a pixel classification model') - d = PipelineTrace() - d['raw'] = accessors['accessor'] - # d = _segment(d, models, **k) - + d = PipelineTrace(accessors['accessor']) zmi = k['zmask_zindex'] sch = k['segmentation_channel'] @@ -138,28 +119,8 @@ def roiset_object_map_pipeline( d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True)) d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last) - - rois = RoiSet.from_binary_mask(d['raw'], d.last, RoiSetMetaParams(**k['roi_params'])) - - # TODO: split out derived channel logic - # # optionally derive additional inputs for object classification - # if dpmod := models.get('pixel_classifier_derived_model'): - # ic = k['derived_channels_input_channel'] - # ocs = k['derived_channels_output_channels'] - # assert ic < d['raw'].chroma - # if not isinstance(dpmod, SemanticSegmentationModel): - # raise IncompatibleModelsError('Expecting pixel_classifier_derived to be a pixel classification model') - # - # def _derive(acc: GenericImageDataAccessor, oc: int): - # acc_mono = acc.get_channels([ic]) - # pxmap = dpmod.infer_patch_stack(acc_mono)[0] - # assert oc < pxmap.chroma - # return PatchStack((pxmap.get_channels([oc]).data * 255).astype('uint8')) - # derived_channel_handles = [ - # lambda a: _derive(a, oc) for oc in ocs - # ] - # else: - # derived_channel_handles = None + d['segmentation'] = get_label_ids(d.last) + rois = RoiSet.from_object_ids(d['input'], d['segmentation'], RoiSetMetaParams(**k['roi_params'])) # optionally run an object classifier if specified if obmod := models.get('object_classifier_model'): @@ -169,16 +130,9 @@ def roiset_object_map_pipeline( obmod_name, [k['patches_channel']], obmod, - # derived_channel_functions=derived_channel_handles ) - # TODO: split out derived channel logic - # # add derived channels to intermediate data - # for di, dacc in enumerate(rois.accs_derived): - # d[f'derived_{di:02d}'] = dacc - d[obmod_name] = rois.get_object_class_map(obmod_name) - return d, rois class Error(Exception): diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py index 27fd0d7f..bfe51ed0 100644 --- a/tests/test_ilastik/test_roiset_workflow.py +++ b/tests/test_ilastik/test_roiset_workflow.py @@ -118,27 +118,27 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(len(trace['ob_id']._unique()[0]), 2) - def test_object_map_workflow_with_derived_channels(self): - acc_in = generate_file_accessor(self.fpi) - params = RoiSetObjectMapParams( - **self._pipeline_params(), - pixel_classifier_derived_model_id='id_px', - derived_channels_input_channel=0, - derived_channels_output_channels=[1], - ) - models = self._get_models() - models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier - - trace, _ = roiset_object_map_pipeline( - {'accessor': acc_in}, - {f'{k}_model': v['model'] for k, v in models.items()}, - **params.dict() - ) - - self.assertTrue('ob_id' in trace.keys()) - self.assertEqual(len(trace['ob_id']._unique()[0]), 2) - self.assertTrue('derived_00' in trace.keys()) - self.assertEqual(trace['derived_00'].chroma, 1) + # def test_object_map_workflow_with_derived_channels(self): + # acc_in = generate_file_accessor(self.fpi) + # params = RoiSetObjectMapParams( + # **self._pipeline_params(), + # pixel_classifier_derived_model_id='id_px', + # derived_channels_input_channel=0, + # derived_channels_output_channels=[1], + # ) + # models = self._get_models() + # models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier + # + # trace, _ = roiset_object_map_pipeline( + # {'accessor': acc_in}, + # {f'{k}_model': v['model'] for k, v in models.items()}, + # **params.dict() + # ) + # + # self.assertTrue('ob_id' in trace.keys()) + # self.assertEqual(len(trace['ob_id']._unique()[0]), 2) + # self.assertTrue('derived_00' in trace.keys()) + # self.assertEqual(trace['derived_00'].chroma, 1) class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts): -- GitLab