From c1cd28eb667555ba0b9fa03adc68945fdec5233d Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 13 Aug 2024 17:35:24 +0200 Subject: [PATCH] Test covers derived inputs pipeline method directly --- model_server/base/pipelines/roiset_obmap.py | 10 ++- tests/test_ilastik/test_roiset_workflow.py | 74 +-------------------- 2 files changed, 11 insertions(+), 73 deletions(-) diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 1bd9220a..3efa5189 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -139,7 +139,7 @@ def roiset_object_map_pipeline( rois = RoiSet.from_binary_mask(d['raw'], d.last, RoiSetMetaParams(**k['roi_params'])) # optionally derive additional inputs for object classification - if dpmod := models.get('pixel_classifier_derived'): + if dpmod := models.get('pixel_classifier_derived_model'): ic = k['derived_channels_input_channel'] ocs = k['derived_channels_output_channels'] assert ic < d['raw'].chroma @@ -157,7 +157,7 @@ def roiset_object_map_pipeline( else: derived_channel_handles = None - # optionally classify if an object classifier is passed + # optionally run an object classifier if specified if obmod := models.get('object_classifier_model'): obmod_name = k['object_classifier_model_id'] assert isinstance(obmod, InstanceSegmentationModel) @@ -167,8 +167,14 @@ def roiset_object_map_pipeline( obmod, derived_channel_functions=derived_channel_handles ) + # add derived channels to intermediate data + for di, dacc in enumerate(rois.accs_derived): + d[f'derived_{di:02d}'] = dacc + d[obmod_name] = rois.get_object_class_map(obmod_name) + + # TODO: subclass this boundary pipeline in separate module # if p.label_params and p.export_label_interm: # fp_la_interm = Path(output_folder_path) / 'mip_masks' / f'interm_{fstem}.png' diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py index 860d8089..27fd0d7f 100644 --- a/tests/test_ilastik/test_roiset_workflow.py +++ b/tests/test_ilastik/test_roiset_workflow.py @@ -1,15 +1,11 @@ -import shutil from pathlib import Path -from shutil import copyfile import unittest import numpy as np -import model_server.conf.testing as conf from model_server.base.accessors import generate_file_accessor from tests.base.test_model import DummyInstanceSegmentationModel -from model_server.base.roiset import RoiSetMetaParams, RoiSetExportParams import model_server.conf.testing as conf from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline @@ -121,6 +117,7 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue('ob_id' in trace.keys()) self.assertEqual(len(trace['ob_id']._unique()[0]), 2) + def test_object_map_workflow_with_derived_channels(self): acc_in = generate_file_accessor(self.fpi) params = RoiSetObjectMapParams( @@ -140,7 +137,8 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue('ob_id' in trace.keys()) self.assertEqual(len(trace['ob_id']._unique()[0]), 2) - self.assertTrue(all([Path(pa).exists() for pa in record['interm']['derived_channels']])) + self.assertTrue('derived_00' in trace.keys()) + self.assertEqual(trace['derived_00'].chroma, 1) class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts): @@ -255,69 +253,3 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd # self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1])) - -class TestRoiSetWithDerivedChannels(conf.TestServerBaseClass): - - app_name = 'api:app' - input_data = data['multichannel_zstack_raw'] - - # def test_object_map_workflow_with_derived_channels(self): - # models = _get_models() - # models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier - # where_out = output_path / 'roiset' / 'workflow' - # - # p = ClassifyZStackApiParams( - # input_filename=str(data['multichannel_zstack_raw']['path']), - # model_ids={ - # 'pixel_classifier_segmentation': 'id_px', - # 'object_classifier': 'id_ob', - # 'pixel_classifier_derived': 'id_px', - # }, - # derived_channels_input_channel=0, - # derived_channels_output_channels=[1], - # segmentation_channel=test_params['segmentation_channel'], - # patches_channel=test_params['patches_channel'], - # roi_params=_get_roi_params(), - # export_params=_get_export_params(), - # ) - # - # record = export_zstack_roiset( - # data['multichannel_zstack_raw']['path'], - # where_out, - # models, - # p - # ) - # self.assertTrue(all([Path(pa).exists() for pa in record['interm']['derived_channels']])) - - def test_derived_channels_api(self): - resp = self._put( - 'ilastik/seg/load/', - body={'project_file': _get_model_params()['pixel_classifier_segmentation']['project_file'].__str__()}, - ) - self.assertEqual(resp.status_code, 200) - mid_px = resp.json()['model_id'] - resp = self._put(f'models/dummy_instance/load') - mid_ob = resp.json()['model_id'] - - - p = ClassifyZStackApiParams( - input_filename=data['multichannel_zstack_raw']['path'].__str__(), - model_ids={ - 'pixel_classifier_segmentation': mid_px, - 'object_classifier': mid_ob, - 'pixel_classifier_derived': mid_px, - }, - derived_channels_input_channel=0, - derived_channels_output_channels=[1], - segmentation_channel=test_params['segmentation_channel'], - patches_channel=test_params['patches_channel'], - roi_params=_get_roi_params(), - export_params=_get_export_params(), - ) - resp = self._put('chaeo/classify_zstack/infer', body=p.dict()) - self.assertEqual(resp.status_code, 200, resp.json()) - omfp = resp.json()['object_map_filepath'] - self.assertTrue(Path(omfp).exists()) - acc_obmap = generate_file_accessor(omfp) - self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1])) - -- GitLab