diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py index 45256258d66cd628590615fd804c5b3397c99272..541d0404d241e61f2be82d5c523bfb58c8847b63 100644 --- a/extensions/chaeo/tests/test_zstack.py +++ b/extensions/chaeo/tests/test_zstack.py @@ -209,6 +209,19 @@ class TestZStackDerivedDataProducts(unittest.TestCase): DummyInstanceSegmentationModel(), ] + models = { + 'pixel_classifier': { + 'model': self.pxmodel, + 'params': { + 'px_class': 0, + 'px_prob_threshold': 0.6, + } + }, + 'object_classifier': { + 'model': DummyInstanceSegmentationModel(), + } + } + roi_params = RoiSetMetaParams(**{ 'mask_type': 'boxes', 'filters': { @@ -240,8 +253,6 @@ class TestZStackDerivedDataProducts(unittest.TestCase): multichannel_zstack['path'], output_path / 'roiset' / 'workflow', models, - pixel_class=pp['pxmap_channel'], - pixel_probability_threshold=pp['pxmap_threshold'], segmentation_channel=pp['segmentation_channel'], patches_channel=pp['patches_channel'], export_params=export_params, diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index 7a0e09413fa30a2648ff8941589a9e5f1126880a..619f7c95be2868327d2654a327a913cef7dba81f 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -26,17 +26,12 @@ def infer_object_map_from_zstack( segmentation_channel: int, patches_channel: int, zmask_zindex: int = None, # None for MIP, - zmask_clip: int = None, roi_params: RoiSetMetaParams = RoiSetMetaParams(), export_params: RoiSetExportParams = RoiSetExportParams(), - pixel_class=0, - pixel_probability_threshold=0.6, ) -> Dict: assert len(models) == 2 - pixel_classifier = models[0] - assert isinstance(pixel_classifier, SemanticSegmentationModel) - object_classifier = models[1] - assert isinstance(object_classifier, InstanceSegmentationModel) + assert isinstance(models['pixel_classifier']['model'], SemanticSegmentationModel) + assert isinstance(models['object_classifier']['model'], InstanceSegmentationModel) ti = Timer() stack = generate_file_accessor(Path(input_file_path)) @@ -49,20 +44,16 @@ def infer_object_map_from_zstack( zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex] else: zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True) - if zmask_clip: - zmask_data = rescale(zmask_data, zmask_clip) - mip = InMemoryDataAccessor( - zmask_data, - ) + mip = InMemoryDataAccessor(zmask_data) - mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,) + mip_mask = models['pixel_classifier']['model'].label_pixel_class(mip, **models['pixel_classifier']['params']) ti.click('classify_pixels') # make zmask rois = RoiSet(mip_mask, stack, params=roi_params) ti.click('generate_zmasks') - rois.classify_by(patches_channel, object_classifier) + rois.classify_by(patches_channel, models['object_classifier']['model']) ti.click('classify_objects') rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params) diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index f39a78f87363d290db226cfe9ee811e13af778c8..ac8d253b6ac44c1aaddb98b8350b180eef49b437 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -77,9 +77,9 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ) return InMemoryDataAccessor(data=yxcz), {'success': True} - def label_pixel_class(self, img: GenericImageDataAccessor, pixel_class: int = 0, pixel_probability_threshold=0.5): + def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs): pxmap, _ = self.infer(img) - mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold + mask = pxmap.data[:, :, px_class, :] > px_prob_threshold return InMemoryDataAccessor(mask)