Skip to content
Snippets Groups Projects
Commit 994f3e3e authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Tests pass, caught bug where a deserialized RoiSet would become 3d

parent 96a892c8
No related branches found
No related tags found
No related merge requests found
......@@ -234,7 +234,7 @@ def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Dat
def is_df_3d(df: pd.DataFrame) -> bool:
return 'z0' in df.columns and 'zi' in df.columns
return 'z0' in df.columns and 'z1' in df.columns
def make_df_from_object_ids(
......@@ -671,6 +671,10 @@ class RoiSet(object):
return zi_st
@property
def is_3d(self) -> bool:
return is_df_3d(self.get_df())
def classify_by(
self, name: str, channels: list[int],
object_classification_model: InstanceMaskSegmentationModel,
......@@ -844,7 +848,6 @@ class RoiSet(object):
return patches_df.apply(_export_patch, axis=1)
def get_patch_masks(self, pad_to: int = None, expanded: bool = False) -> pd.DataFrame:
is_3d = is_df_3d(self._df)
def _make_patch_mask(roi):
if expanded:
patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
......@@ -853,7 +856,7 @@ class RoiSet(object):
patch = (roi.binary_mask * 255).astype('uint8')
if pad_to:
patch = pad(patch, pad_to)
if is_3d:
if self.is_3d:
return patch
else:
return np.expand_dims(patch, 2)
......@@ -1087,14 +1090,9 @@ class RoiSet(object):
for k, kp in params.dict().items():
if kp is None:
continue
if k == 'patches_3d':
interm[k] = self.get_patches_acc(make_3d=True, **kp)
if k == 'annotated_patches_2d':
interm[k] = self.get_patches_acc(
make_3d=False, **kp
)
if k == 'patches_2d':
interm[k] = self.get_patches_acc(make_3d=False, **kp)
if k == 'patches':
for pk, pp in kp.items():
interm[f'patches_{pk}'] = self.get_patches_acc(**pp)
if k == 'annotated_zstacks':
interm[k] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kp))
if k == 'object_classes':
......@@ -1216,7 +1214,10 @@ class RoiSet(object):
df['binary_mask'] = df.apply(_read_binary_mask, axis=1)
id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
return cls.from_object_ids(acc_raw, id_mask)
if not is_3d and id_mask.nz > 1:
return cls.from_object_ids(acc_raw, id_mask.get_mip())
else:
return cls.from_object_ids(acc_raw, id_mask)
else: # assume bounding boxes, exclusively 2d objects
df['y'] = df['y0']
......
......@@ -516,14 +516,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
def test_run_exports(self):
p = RoiSetExportParams(**{
'patches': {
'annotated_patches_2d': {
'2d_annotated': {
'white_channel': 3,
'draw_bounding_box': True,
'rgb_overlay_channels': [3, None, None],
'rgb_overlay_weights': [0.2, 1.0, 1.0],
'pad_to': 512,
},
'patches_2d': {
'2d': {
'white_channel': 3,
'draw_bounding_box': False,
'draw_mask': False,
......@@ -556,7 +556,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
# test on paths in CSV
test_df = pd.read_csv(where / res['dataframe'])
for c in ['tight_patch_masks_path', 'patches_2d_path', 'annotated_patches_2d']:
for c in ['tight_patch_masks_path', 'patches_2d_path', 'patches_2d_annotated_path']:
self.assertTrue(c in test_df.columns)
for f in test_df[c]:
self.assertTrue((where / f).exists(), where / f)
......@@ -757,7 +757,7 @@ class TestRoiSetObjectDetection(unittest.TestCase):
self.seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path'])
def test_create_roiset_from_bounding_boxes(self):
from skimage.measure import label, regionprops, regionprops_table
from skimage.measure import label, regionprops_table
mask = self.seg_mask_3d
labels = label(mask.data_yxz, connectivity=3)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment