From 6d9f27b0f3129a97a9a4c4bed1401e89aedb47d7 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 21 Dec 2023 14:51:03 +0100 Subject: [PATCH] Dummy object classification exports map --- extensions/chaeo/params.py | 1 + extensions/chaeo/tests/test_zstack.py | 3 ++- extensions/chaeo/workflows.py | 18 +++++------------- extensions/chaeo/zmask.py | 13 ++++++------- model_server/models.py | 6 ++++-- 5 files changed, 18 insertions(+), 23 deletions(-) diff --git a/extensions/chaeo/params.py b/extensions/chaeo/params.py index 5a226925..0ba956fb 100644 --- a/extensions/chaeo/params.py +++ b/extensions/chaeo/params.py @@ -39,5 +39,6 @@ class RoiSetExportParams(BaseModel): patches_2d_for_training: Union[PatchParams, None] = None patch_masks: bool = False annotated_zstacks: Union[AnnotatedZStackParams, None] = None + object_classes: bool = False diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py index 9e6e3a74..7778370a 100644 --- a/extensions/chaeo/tests/test_zstack.py +++ b/extensions/chaeo/tests/test_zstack.py @@ -229,7 +229,8 @@ class TestZStackDerivedDataProducts(unittest.TestCase): 'draw_mask': False, }, 'patch_masks': True, - 'annotated_zstacks': {} + 'annotated_zstacks': {}, + 'object_classes': True }) infer_object_map_from_zstack( multichannel_zstack['path'], diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index 4b3f380d..7a0e0941 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -54,33 +54,25 @@ def infer_object_map_from_zstack( mip = InMemoryDataAccessor( zmask_data, ) - # pxmap, _ = pixel_classifier.infer(mip) + mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,) ti.click('classify_pixels') # make zmask - # rois = RoiSet(mip_mask, stack, mask_type=zmask_type, filters=zmask_filters, expand_box_by=meta.expand_box_by) rois = RoiSet(mip_mask, stack, params=roi_params) ti.click('generate_zmasks') - object_class_map = rois.classify_by(patches_channel, object_classifier) - - # TODO: add ZMaskObjectTable method to export object map - output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif')) - write_accessor_data_to_file( - output_path, - object_class_map - ) - ti.click('export_object_classes') + rois.classify_by(patches_channel, object_classifier) + ti.click('classify_objects') rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params) ti.click('export_roi_products') return { 'timer_results': ti.events, - 'dataframe': rois.df, + 'dataframe': rois.df, 'interm': {}, - 'output_path': output_path.__str__(), + 'output_path': output_folder_path, } diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py index 13c54b7d..1441e9ae 100644 --- a/extensions/chaeo/zmask.py +++ b/extensions/chaeo/zmask.py @@ -32,6 +32,7 @@ class RoiSet(object): self.acc_raw = acc_raw self.count = len(self.zmask_meta) self.object_id_labels = self.interm['label_map'] + self.object_class_map = None def get_argmax(self): return self.interm.argmax @@ -67,7 +68,7 @@ class RoiSet(object): ) lamap = self.object_id_labels - output_map = np.zeros(lamap.shape, dtype=lamap.dtype) + om = np.zeros(lamap.shape, dtype=lamap.dtype) self.df['instance_class'] = np.nan # assign labels to object map: @@ -75,10 +76,10 @@ class RoiSet(object): object_id = self.zmask_meta[ii]['info'].label result_patch = mask_largest_object(obmap_patches.iat(ii)) object_class = np.unique(result_patch)[1] - output_map[self.object_id_labels == object_id] = object_class + om[self.object_id_labels == object_id] = object_class self.df[object_id, 'instance_class'] = object_class - return InMemoryDataAccessor(output_map) + self.object_class_map = InMemoryDataAccessor(om) # TODO: test def get_object_mask_by_id(self, obj_id): @@ -91,9 +92,6 @@ class RoiSet(object): def get_object_patch_by_id(self, obj_id): pass - def get_object_map(self, filters: RoiFilter): - pass - def run_exports(self, where, channel, prefix, params: RoiSetExportParams): if not self.count: return @@ -129,7 +127,8 @@ class RoiSet(object): draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) ) write_accessor_data_to_file(subdir / (pr + '.tif'), annotated) - + if k == 'object_classes': + write_accessor_data_to_file(subdir / (pr + '.tif'), self.object_class_map) diff --git a/model_server/models.py b/model_server/models.py index 06bef01c..8239d1ca 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -120,13 +120,15 @@ class DummyInstanceSegmentationModel(InstanceSegmentationModel): def infer( self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor ) -> (GenericImageDataAccessor, dict): - return mask + return img.__class__( + (mask.data / mask.data.max()).astype('uint16') + ) def label_instance_class( self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs ) -> GenericImageDataAccessor: """ - Returns a trivial segmentation, i.e. the input mask + Returns a trivial segmentation, i.e. the input mask with value 1 """ super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) return self.infer(img, mask) -- GitLab