Skip to content
Snippets Groups Projects
Commit c744a6d7 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Add option to get MIP of patch mask

parent f0547637
No related branches found
No related tags found
No related merge requests found
......@@ -44,6 +44,10 @@ class PatchParams(BaseModel):
False,
description='If generating 2d patches with a 3d focus metric, update the RoiSet patch focus zi accordingly'
)
mask_mip: bool = Field(
False,
description='If generating 2d patch masks, use the MIP of the 3d patch mask instead of existing zi focus'
)
class AnnotatedPatchParams(PatchParams):
bounding_box_channel: int = 1
......@@ -814,8 +818,8 @@ class RoiSet(object):
self.get_serializable_dataframe().to_csv(csv_path, index=False)
return csv_path.name
def export_patch_masks(self, where: Path, prefix='mask', expanded=False, make_3d=True, **kwargs) -> pd.DataFrame:
patches_df = self.get_patch_masks(pad_to=None, expanded=expanded, make_3d=make_3d).copy()
def export_patch_masks(self, where: Path, prefix='mask', expanded=False, make_3d=True, mask_mip=False, **kwargs) -> pd.DataFrame:
patches_df = self.get_patch_masks(pad_to=None, expanded=expanded, make_3d=make_3d, mask_mip=mask_mip).copy()
if 'nz' in patches_df.columns and any(patches_df['nz'] > 1):
ext = 'tif'
else:
......@@ -856,7 +860,7 @@ class RoiSet(object):
return patches_df.apply(_export_patch, axis=1)
def get_patch_masks(self, pad_to: int = None, expanded: bool = False, make_3d=True) -> pd.DataFrame:
def get_patch_masks(self, pad_to: int = None, expanded: bool = False, make_3d=True, mask_mip=False) -> pd.DataFrame:
def _make_patch_mask(roi):
if expanded:
patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
......@@ -867,6 +871,8 @@ class RoiSet(object):
patch = pad(patch, pad_to)
if self.is_3d and make_3d:
return patch
elif self.is_3d and mask_mip:
return np.max(patch, axis=-1)
elif self.is_3d:
rzi = roi.zi - roi.z0
return patch[:, :, rzi: rzi + 1]
......
......@@ -351,6 +351,45 @@ class TestRoiSet3dProducts(unittest.TestCase):
updated_zi = roiset.get_df()['zi']
self.assertTrue((starting_zi != updated_zi).any())
def test_mip_patch_masks(self):
p = RoiSetExportParams(**{
'patches': {
'3d_masks': {
'make_3d': True,
'is_patch_mask': True,
},
'2d_masks': {
'make_3d': False,
'is_patch_mask': True,
'mask_mip': True,
},
},
})
res = self.test_create_roiset_from_3d_obj_ids().run_exports(
self.where,
prefix='test',
params=p
)
# test that exported patches are 3d
def _get_nz_from_file(fname):
pa = self.where / fname
pacc = generate_file_accessor(pa)
return pacc.nz
n_3d = len(res['patches_3d_masks'])
n_2d = len(res['patches_2d_masks'])
self.assertEqual(n_3d, n_2d)
for pi in range(0, n_3d):
acc_3d = generate_file_accessor(self.where / res['patches_3d_masks'][pi])
acc_2d = generate_file_accessor(self.where / res['patches_2d_masks'][pi])
self.assertTrue(
np.all(
acc_3d.get_mip().data == acc_2d.data
)
)
def test_run_export_mono_3d_labels_overlay(self):
# export via .run_exports() method
res = self.test_create_roiset_from_3d_obj_ids().run_exports(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment