diff --git a/extensions/chaeo/params.py b/extensions/chaeo/params.py index 5a226925f9ff1bb0d9798ee4fa5dd1885e60091c..0ba956fba3446c84955c1ac829c2dc3561a9ff29 100644 --- a/extensions/chaeo/params.py +++ b/extensions/chaeo/params.py @@ -39,5 +39,6 @@ class RoiSetExportParams(BaseModel): patches_2d_for_training: Union[PatchParams, None] = None patch_masks: bool = False annotated_zstacks: Union[AnnotatedZStackParams, None] = None + object_classes: bool = False diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py index 9e6e3a74db10f7d763505ecc6e1621a4231f2b54..7778370add1a2421480cfa9f5c460d6e1f63a5d7 100644 --- a/extensions/chaeo/tests/test_zstack.py +++ b/extensions/chaeo/tests/test_zstack.py @@ -229,7 +229,8 @@ class TestZStackDerivedDataProducts(unittest.TestCase): 'draw_mask': False, }, 'patch_masks': True, - 'annotated_zstacks': {} + 'annotated_zstacks': {}, + 'object_classes': True }) infer_object_map_from_zstack( multichannel_zstack['path'], diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index 4b3f380d7fe9d80e28e5b2a76cdd6c7251c2764e..7a0e09413fa30a2648ff8941589a9e5f1126880a 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -54,33 +54,25 @@ def infer_object_map_from_zstack( mip = InMemoryDataAccessor( zmask_data, ) - # pxmap, _ = pixel_classifier.infer(mip) + mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,) ti.click('classify_pixels') # make zmask - # rois = RoiSet(mip_mask, stack, mask_type=zmask_type, filters=zmask_filters, expand_box_by=meta.expand_box_by) rois = RoiSet(mip_mask, stack, params=roi_params) ti.click('generate_zmasks') - object_class_map = rois.classify_by(patches_channel, object_classifier) - - # TODO: add ZMaskObjectTable method to export object map - output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif')) - write_accessor_data_to_file( - output_path, - object_class_map - ) - ti.click('export_object_classes') + rois.classify_by(patches_channel, object_classifier) + ti.click('classify_objects') rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params) ti.click('export_roi_products') return { 'timer_results': ti.events, - 'dataframe': rois.df, + 'dataframe': rois.df, 'interm': {}, - 'output_path': output_path.__str__(), + 'output_path': output_folder_path, } diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py index 13c54b7d62db0d12d190ed96cc7dbdc40789b96a..1441e9ae278a845616e3d68b34677a8f83a71e72 100644 --- a/extensions/chaeo/zmask.py +++ b/extensions/chaeo/zmask.py @@ -32,6 +32,7 @@ class RoiSet(object): self.acc_raw = acc_raw self.count = len(self.zmask_meta) self.object_id_labels = self.interm['label_map'] + self.object_class_map = None def get_argmax(self): return self.interm.argmax @@ -67,7 +68,7 @@ class RoiSet(object): ) lamap = self.object_id_labels - output_map = np.zeros(lamap.shape, dtype=lamap.dtype) + om = np.zeros(lamap.shape, dtype=lamap.dtype) self.df['instance_class'] = np.nan # assign labels to object map: @@ -75,10 +76,10 @@ class RoiSet(object): object_id = self.zmask_meta[ii]['info'].label result_patch = mask_largest_object(obmap_patches.iat(ii)) object_class = np.unique(result_patch)[1] - output_map[self.object_id_labels == object_id] = object_class + om[self.object_id_labels == object_id] = object_class self.df[object_id, 'instance_class'] = object_class - return InMemoryDataAccessor(output_map) + self.object_class_map = InMemoryDataAccessor(om) # TODO: test def get_object_mask_by_id(self, obj_id): @@ -91,9 +92,6 @@ class RoiSet(object): def get_object_patch_by_id(self, obj_id): pass - def get_object_map(self, filters: RoiFilter): - pass - def run_exports(self, where, channel, prefix, params: RoiSetExportParams): if not self.count: return @@ -129,7 +127,8 @@ class RoiSet(object): draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) ) write_accessor_data_to_file(subdir / (pr + '.tif'), annotated) - + if k == 'object_classes': + write_accessor_data_to_file(subdir / (pr + '.tif'), self.object_class_map) diff --git a/model_server/models.py b/model_server/models.py index 06bef01c9831cb495b32662c4bbdb8970589ada7..8239d1cae3f4595b7ba09693f111c0d4dda1fca4 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -120,13 +120,15 @@ class DummyInstanceSegmentationModel(InstanceSegmentationModel): def infer( self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor ) -> (GenericImageDataAccessor, dict): - return mask + return img.__class__( + (mask.data / mask.data.max()).astype('uint16') + ) def label_instance_class( self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs ) -> GenericImageDataAccessor: """ - Returns a trivial segmentation, i.e. the input mask + Returns a trivial segmentation, i.e. the input mask with value 1 """ super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) return self.infer(img, mask)