Skip to content
Snippets Groups Projects
Commit 10700c09 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Adapted folder structure from master

parent 14d76dfa
No related branches found
No related tags found
No related merge requests found
......@@ -209,6 +209,19 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
DummyInstanceSegmentationModel(),
]
models = {
'pixel_classifier': {
'model': self.pxmodel,
'params': {
'px_class': 0,
'px_prob_threshold': 0.6,
}
},
'object_classifier': {
'model': DummyInstanceSegmentationModel(),
}
}
roi_params = RoiSetMetaParams(**{
'mask_type': 'boxes',
'filters': {
......@@ -240,8 +253,6 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
multichannel_zstack['path'],
output_path / 'roiset' / 'workflow',
models,
pixel_class=pp['pxmap_channel'],
pixel_probability_threshold=pp['pxmap_threshold'],
segmentation_channel=pp['segmentation_channel'],
patches_channel=pp['patches_channel'],
export_params=export_params,
......
......@@ -26,17 +26,12 @@ def infer_object_map_from_zstack(
segmentation_channel: int,
patches_channel: int,
zmask_zindex: int = None, # None for MIP,
zmask_clip: int = None,
roi_params: RoiSetMetaParams = RoiSetMetaParams(),
export_params: RoiSetExportParams = RoiSetExportParams(),
pixel_class=0,
pixel_probability_threshold=0.6,
) -> Dict:
assert len(models) == 2
pixel_classifier = models[0]
assert isinstance(pixel_classifier, SemanticSegmentationModel)
object_classifier = models[1]
assert isinstance(object_classifier, InstanceSegmentationModel)
assert isinstance(models['pixel_classifier']['model'], SemanticSegmentationModel)
assert isinstance(models['object_classifier']['model'], InstanceSegmentationModel)
ti = Timer()
stack = generate_file_accessor(Path(input_file_path))
......@@ -49,20 +44,16 @@ def infer_object_map_from_zstack(
zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex]
else:
zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True)
if zmask_clip:
zmask_data = rescale(zmask_data, zmask_clip)
mip = InMemoryDataAccessor(
zmask_data,
)
mip = InMemoryDataAccessor(zmask_data)
mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,)
mip_mask = models['pixel_classifier']['model'].label_pixel_class(mip, **models['pixel_classifier']['params'])
ti.click('classify_pixels')
# make zmask
rois = RoiSet(mip_mask, stack, params=roi_params)
ti.click('generate_zmasks')
rois.classify_by(patches_channel, object_classifier)
rois.classify_by(patches_channel, models['object_classifier']['model'])
ti.click('classify_objects')
rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params)
......
......@@ -77,9 +77,9 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_pixel_class(self, img: GenericImageDataAccessor, pixel_class: int = 0, pixel_probability_threshold=0.5):
def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs):
pxmap, _ = self.infer(img)
mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold
mask = pxmap.data[:, :, px_class, :] > px_prob_threshold
return InMemoryDataAccessor(mask)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment