Skip to content
Snippets Groups Projects
Commit ce711cba authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

slice and mask array objects now in RoiSet's dataframe

parent e39fcf73
No related branches found
No related tags found
No related merge requests found
......@@ -61,9 +61,8 @@ def write_patch_to_file(where, fname, yxcz):
def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
patches = []
for ob in roiset.get_df().itertuples('Roi'):
sp_sl = roiset.get_rel_slice_at(ob.Index)
patch = np.zeros((ob.ebb_h, ob.ebb_w, 1, 1), dtype='uint8')
patch[sp_sl][:, :, 0, 0] = roiset.get_mask_at(ob.Index) * 255
patch[ob.relative_slice][:, :, 0, 0] = ob.mask * 255
if pad_to:
patch = pad(patch, pad_to)
......
......@@ -96,20 +96,18 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
def test_zmask_slices_are_valid(self):
roiset = self.test_zmask_makes_correct_boxes()
slices = [roiset.get_slice_at(i) for i in roiset.get_df().index]
for s in slices:
for s in roiset.get_slices():
ebb = roiset.acc_raw.data[s]
self.assertEqual(len(ebb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
def test_zmask_rel_slices_are_valid(self):
roiset = self.test_zmask_makes_correct_boxes()
slices = [roiset.get_slice_at(i) for i in roiset.get_df().index]
rel_slices = [roiset.get_rel_slice_at(i) for i in roiset.get_df().index]
for i, s in enumerate(slices):
for i, s in enumerate(roiset.get_slices()):
ebb = roiset.acc_raw.data[s]
self.assertEqual(len(ebb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
rel_slices = roiset.get_df()['relative_slice']
rbb = ebb[rel_slices[i]]
self.assertEqual(len(rbb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in rbb.shape]))
......
......@@ -36,7 +36,9 @@ class RoiSet(object):
self.acc_raw = acc_raw
self._df = self.filter_df(
self.make_df(self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by),
self.make_df(
self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by
),
params.filters,
)
......@@ -78,7 +80,14 @@ class RoiSet(object):
self.object_class_map = None
@staticmethod
def make_df(acc_raw, acc_obj_ids, expand_box_by):
def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
"""
Build dataframe associate object IDs with summary stats
:param acc_raw: accessor to raw image data
:param acc_obj_ids: accessor to map of object IDs
:param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
:return: pd.DataFrame
"""
# build dataframe of objects, assign z index to each object
argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
df = (
......@@ -117,36 +126,22 @@ class RoiSet(object):
assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0']))
assert np.all(df['rel_y1'] <= (df['ebb_x1'] - df['ebb_x0']))
df['slice'] = df.apply(
lambda r:
np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1],
axis=1
)
df['relative_slice'] = df.apply(
lambda r:
np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :],
axis=1
)
df['mask'] = df.apply(
lambda r: (acc_obj_ids == r.label)[r.y0: r.y1, r.x0: r.x1],
axis=1
)
return df
def get_slice_at(self, idx) -> tuple:
"""
Return slice object in np.s_ format that defines expanded bounding box of object
:param idx: object index (Index in DataFrame, does not necessarily start at zero)
:return: slice object
"""
ob = self.get_df().loc[idx, :].astype('int64')
return np.s_[ob.ebb_y0: ob.ebb_y1, ob.ebb_x0: ob.ebb_x1, :, ob.ebb_z0: ob.ebb_z1 + 1]
def get_rel_slice_at(self, idx) -> tuple:
"""
Return slice object in np.s_ format that defines bounding box of an object within its expanded bounding box
:param idx: object index (Index in DataFrame, does not necessarily start at zero)
:return: slice object
"""
ob = self.get_df().loc[idx, :].astype('int64')
return np.s_[ob.rel_y0: ob.rel_y1, ob.rel_x0: ob.rel_x1, :, :]
def get_mask_at(self, idx) -> np.ndarray:
"""
Return 2D array describing object mask that fills (unexpanded) bounding box at index idx
:param idx: object index (Index in DataFrame, does not necessarily start at zero)
:return: np.ndarray boolean mask
"""
ob = self.get_df().loc[idx, :].astype('int64')
obmask = (self.acc_obj_ids == ob.label)
return obmask[ob.y0: ob.y1, ob.x0: ob.x1]
@staticmethod
def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
......@@ -158,12 +153,14 @@ class RoiSet(object):
vmax = val['max']
assert vmin >= 0
query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
# df.loc[df.query(query_str).index, 'keeper'] = True
return df.loc[df.query(query_str).index, :]
def get_df(self) -> pd.DataFrame: # TODO: exclude columns that refer to objects
return self._df
def get_slices(self) -> pd.Series:
return self.get_df()['slice']
def add_df_col(self, name, se: pd.Series) -> None:
self._df[name] = se
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment