Skip to content
Snippets Groups Projects
Commit 6d9f27b0 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Dummy object classification exports map

parent 2491ddf1
No related branches found
No related tags found
No related merge requests found
...@@ -39,5 +39,6 @@ class RoiSetExportParams(BaseModel): ...@@ -39,5 +39,6 @@ class RoiSetExportParams(BaseModel):
patches_2d_for_training: Union[PatchParams, None] = None patches_2d_for_training: Union[PatchParams, None] = None
patch_masks: bool = False patch_masks: bool = False
annotated_zstacks: Union[AnnotatedZStackParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None
object_classes: bool = False
...@@ -229,7 +229,8 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ...@@ -229,7 +229,8 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
'draw_mask': False, 'draw_mask': False,
}, },
'patch_masks': True, 'patch_masks': True,
'annotated_zstacks': {} 'annotated_zstacks': {},
'object_classes': True
}) })
infer_object_map_from_zstack( infer_object_map_from_zstack(
multichannel_zstack['path'], multichannel_zstack['path'],
......
...@@ -54,33 +54,25 @@ def infer_object_map_from_zstack( ...@@ -54,33 +54,25 @@ def infer_object_map_from_zstack(
mip = InMemoryDataAccessor( mip = InMemoryDataAccessor(
zmask_data, zmask_data,
) )
# pxmap, _ = pixel_classifier.infer(mip)
mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,) mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,)
ti.click('classify_pixels') ti.click('classify_pixels')
# make zmask # make zmask
# rois = RoiSet(mip_mask, stack, mask_type=zmask_type, filters=zmask_filters, expand_box_by=meta.expand_box_by)
rois = RoiSet(mip_mask, stack, params=roi_params) rois = RoiSet(mip_mask, stack, params=roi_params)
ti.click('generate_zmasks') ti.click('generate_zmasks')
object_class_map = rois.classify_by(patches_channel, object_classifier) rois.classify_by(patches_channel, object_classifier)
ti.click('classify_objects')
# TODO: add ZMaskObjectTable method to export object map
output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif'))
write_accessor_data_to_file(
output_path,
object_class_map
)
ti.click('export_object_classes')
rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params) rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params)
ti.click('export_roi_products') ti.click('export_roi_products')
return { return {
'timer_results': ti.events, 'timer_results': ti.events,
'dataframe': rois.df, 'dataframe': rois.df,
'interm': {}, 'interm': {},
'output_path': output_path.__str__(), 'output_path': output_folder_path,
} }
......
...@@ -32,6 +32,7 @@ class RoiSet(object): ...@@ -32,6 +32,7 @@ class RoiSet(object):
self.acc_raw = acc_raw self.acc_raw = acc_raw
self.count = len(self.zmask_meta) self.count = len(self.zmask_meta)
self.object_id_labels = self.interm['label_map'] self.object_id_labels = self.interm['label_map']
self.object_class_map = None
def get_argmax(self): def get_argmax(self):
return self.interm.argmax return self.interm.argmax
...@@ -67,7 +68,7 @@ class RoiSet(object): ...@@ -67,7 +68,7 @@ class RoiSet(object):
) )
lamap = self.object_id_labels lamap = self.object_id_labels
output_map = np.zeros(lamap.shape, dtype=lamap.dtype) om = np.zeros(lamap.shape, dtype=lamap.dtype)
self.df['instance_class'] = np.nan self.df['instance_class'] = np.nan
# assign labels to object map: # assign labels to object map:
...@@ -75,10 +76,10 @@ class RoiSet(object): ...@@ -75,10 +76,10 @@ class RoiSet(object):
object_id = self.zmask_meta[ii]['info'].label object_id = self.zmask_meta[ii]['info'].label
result_patch = mask_largest_object(obmap_patches.iat(ii)) result_patch = mask_largest_object(obmap_patches.iat(ii))
object_class = np.unique(result_patch)[1] object_class = np.unique(result_patch)[1]
output_map[self.object_id_labels == object_id] = object_class om[self.object_id_labels == object_id] = object_class
self.df[object_id, 'instance_class'] = object_class self.df[object_id, 'instance_class'] = object_class
return InMemoryDataAccessor(output_map) self.object_class_map = InMemoryDataAccessor(om)
# TODO: test # TODO: test
def get_object_mask_by_id(self, obj_id): def get_object_mask_by_id(self, obj_id):
...@@ -91,9 +92,6 @@ class RoiSet(object): ...@@ -91,9 +92,6 @@ class RoiSet(object):
def get_object_patch_by_id(self, obj_id): def get_object_patch_by_id(self, obj_id):
pass pass
def get_object_map(self, filters: RoiFilter):
pass
def run_exports(self, where, channel, prefix, params: RoiSetExportParams): def run_exports(self, where, channel, prefix, params: RoiSetExportParams):
if not self.count: if not self.count:
return return
...@@ -129,7 +127,8 @@ class RoiSet(object): ...@@ -129,7 +127,8 @@ class RoiSet(object):
draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp)
) )
write_accessor_data_to_file(subdir / (pr + '.tif'), annotated) write_accessor_data_to_file(subdir / (pr + '.tif'), annotated)
if k == 'object_classes':
write_accessor_data_to_file(subdir / (pr + '.tif'), self.object_class_map)
......
...@@ -120,13 +120,15 @@ class DummyInstanceSegmentationModel(InstanceSegmentationModel): ...@@ -120,13 +120,15 @@ class DummyInstanceSegmentationModel(InstanceSegmentationModel):
def infer( def infer(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor
) -> (GenericImageDataAccessor, dict): ) -> (GenericImageDataAccessor, dict):
return mask return img.__class__(
(mask.data / mask.data.max()).astype('uint16')
)
def label_instance_class( def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor: ) -> GenericImageDataAccessor:
""" """
Returns a trivial segmentation, i.e. the input mask Returns a trivial segmentation, i.e. the input mask with value 1
""" """
super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs)
return self.infer(img, mask) return self.infer(img, mask)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment