Skip to content
Snippets Groups Projects
Commit 10700c09 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Adapted folder structure from master

parent 14d76dfa
No related branches found
No related tags found
No related merge requests found
...@@ -209,6 +209,19 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ...@@ -209,6 +209,19 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
DummyInstanceSegmentationModel(), DummyInstanceSegmentationModel(),
] ]
models = {
'pixel_classifier': {
'model': self.pxmodel,
'params': {
'px_class': 0,
'px_prob_threshold': 0.6,
}
},
'object_classifier': {
'model': DummyInstanceSegmentationModel(),
}
}
roi_params = RoiSetMetaParams(**{ roi_params = RoiSetMetaParams(**{
'mask_type': 'boxes', 'mask_type': 'boxes',
'filters': { 'filters': {
...@@ -240,8 +253,6 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ...@@ -240,8 +253,6 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
multichannel_zstack['path'], multichannel_zstack['path'],
output_path / 'roiset' / 'workflow', output_path / 'roiset' / 'workflow',
models, models,
pixel_class=pp['pxmap_channel'],
pixel_probability_threshold=pp['pxmap_threshold'],
segmentation_channel=pp['segmentation_channel'], segmentation_channel=pp['segmentation_channel'],
patches_channel=pp['patches_channel'], patches_channel=pp['patches_channel'],
export_params=export_params, export_params=export_params,
......
...@@ -26,17 +26,12 @@ def infer_object_map_from_zstack( ...@@ -26,17 +26,12 @@ def infer_object_map_from_zstack(
segmentation_channel: int, segmentation_channel: int,
patches_channel: int, patches_channel: int,
zmask_zindex: int = None, # None for MIP, zmask_zindex: int = None, # None for MIP,
zmask_clip: int = None,
roi_params: RoiSetMetaParams = RoiSetMetaParams(), roi_params: RoiSetMetaParams = RoiSetMetaParams(),
export_params: RoiSetExportParams = RoiSetExportParams(), export_params: RoiSetExportParams = RoiSetExportParams(),
pixel_class=0,
pixel_probability_threshold=0.6,
) -> Dict: ) -> Dict:
assert len(models) == 2 assert len(models) == 2
pixel_classifier = models[0] assert isinstance(models['pixel_classifier']['model'], SemanticSegmentationModel)
assert isinstance(pixel_classifier, SemanticSegmentationModel) assert isinstance(models['object_classifier']['model'], InstanceSegmentationModel)
object_classifier = models[1]
assert isinstance(object_classifier, InstanceSegmentationModel)
ti = Timer() ti = Timer()
stack = generate_file_accessor(Path(input_file_path)) stack = generate_file_accessor(Path(input_file_path))
...@@ -49,20 +44,16 @@ def infer_object_map_from_zstack( ...@@ -49,20 +44,16 @@ def infer_object_map_from_zstack(
zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex] zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex]
else: else:
zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True) zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True)
if zmask_clip: mip = InMemoryDataAccessor(zmask_data)
zmask_data = rescale(zmask_data, zmask_clip)
mip = InMemoryDataAccessor(
zmask_data,
)
mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,) mip_mask = models['pixel_classifier']['model'].label_pixel_class(mip, **models['pixel_classifier']['params'])
ti.click('classify_pixels') ti.click('classify_pixels')
# make zmask # make zmask
rois = RoiSet(mip_mask, stack, params=roi_params) rois = RoiSet(mip_mask, stack, params=roi_params)
ti.click('generate_zmasks') ti.click('generate_zmasks')
rois.classify_by(patches_channel, object_classifier) rois.classify_by(patches_channel, models['object_classifier']['model'])
ti.click('classify_objects') ti.click('classify_objects')
rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params) rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params)
......
...@@ -77,9 +77,9 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ...@@ -77,9 +77,9 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
) )
return InMemoryDataAccessor(data=yxcz), {'success': True} return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_pixel_class(self, img: GenericImageDataAccessor, pixel_class: int = 0, pixel_probability_threshold=0.5): def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs):
pxmap, _ = self.infer(img) pxmap, _ = self.infer(img)
mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold mask = pxmap.data[:, :, px_class, :] > px_prob_threshold
return InMemoryDataAccessor(mask) return InMemoryDataAccessor(mask)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment