Skip to content
Snippets Groups Projects
Commit 3e7e34c7 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

get_patches now returns an extended DataFrame with a patch column, so that...

get_patches now returns an extended DataFrame with a patch column, so that exporters have access to patch metadata with no risk of mixing up indexing
parent f32b87c8
No related branches found
No related tags found
No related merge requests found
...@@ -85,31 +85,20 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask') -> ...@@ -85,31 +85,20 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask') ->
def get_patches_from_zmask_meta( def get_patches_from_zmask_meta(
# stack: GenericImageDataAccessor,
# zmask_meta: list,
roiset, roiset,
rescale_clip: float = 0.0, rescale_clip: float = 0.0,
pad_to: int = 256, pad_to: int = 256,
make_3d: bool = False, make_3d: bool = False,
focus_metric: str = None, focus_metric: str = None,
**kwargs **kwargs
) -> MonoPatchStack: ) -> pd.DataFrame:
patches = [] # patches = []
# for mi in zmask_meta: # for mi in zmask_meta:
for i, roi in enumerate(roiset.get_df().itertuples()): # TODO: call RoiSet.iter() when implemented # dfe = roiset.get_df().assign(patch=object())
# for i, roi in enumerate(dfe.itertuples()):
# sl = roi['slice']
# rbb = mi['relative_bounding_box'] # TODO: call rel_ fields in DF
# idx = mi['df_index']
#
# x0 = rbb['x0']
# y0 = rbb['y0']
# x1 = rbb['x1']
# y1 = rbb['y1']
#
# sp_sl = np.s_[y0: y1, x0: x1, :, :]
def _make_patch(roi):
patch3d = roiset.acc_raw.data[roi.slice] patch3d = roiset.acc_raw.data[roi.slice]
ph, pw, pc, pz = patch3d.shape ph, pw, pc, pz = patch3d.shape
subpatch = patch3d[roi.relative_slice] subpatch = patch3d[roi.relative_slice]
...@@ -173,12 +162,17 @@ def get_patches_from_zmask_meta( ...@@ -173,12 +162,17 @@ def get_patches_from_zmask_meta(
if pad_to: if pad_to:
patch = pad(patch, pad_to) patch = pad(patch, pad_to)
patches.append(patch) # patches.append(patch)
return patch
# dfe.loc[i, 'patch'] = patch
dfe = roiset.get_df()
dfe['patch'] = roiset.get_df().apply(lambda r: _make_patch(r), axis=1)
return dfe
if not make_3d and pc == 1: # if not make_3d and pc == 1:
return MonoPatchStack(patches) # return MonoPatchStack(patches)
else: # else:
return Multichannel3dPatchStack(patches) # return Multichannel3dPatchStack(patches)
def export_patches_from_zstack( def export_patches_from_zstack(
where: Path, where: Path,
...@@ -192,7 +186,7 @@ def export_patches_from_zstack( ...@@ -192,7 +186,7 @@ def export_patches_from_zstack(
focus_metric: str = None, focus_metric: str = None,
**kwargs **kwargs
): ):
patches_acc = get_patches_from_zmask_meta( patches_df = get_patches_from_zmask_meta(
roiset, roiset,
# stack, # stack,
# zmask_meta, # zmask_meta,
...@@ -203,6 +197,13 @@ def export_patches_from_zstack( ...@@ -203,6 +197,13 @@ def export_patches_from_zstack(
**kwargs **kwargs
) )
pc = roiset.acc_raw.chroma
patches = list(patches_df['patch'])
if not make_3d and pc == 1:
patches_acc = MonoPatchStack(patches)
else:
patches_acc = Multichannel3dPatchStack(patches)
exported = [] exported = []
for i, roi in enumerate(roiset.get_df().itertuples()): # just used for label info for i, roi in enumerate(roiset.get_df().itertuples()): # just used for label info
# for i in range(0, len(zmask_meta)): # for i in range(0, len(zmask_meta)):
......
...@@ -80,6 +80,7 @@ class RoiSet(object): ...@@ -80,6 +80,7 @@ class RoiSet(object):
self.object_id_labels = self.interm['label_map'] self.object_id_labels = self.interm['label_map']
self.object_class_map = None self.object_class_map = None
@staticmethod @staticmethod
def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
""" """
...@@ -185,11 +186,8 @@ class RoiSet(object): ...@@ -185,11 +186,8 @@ class RoiSet(object):
def export_patch_masks(self, where, **kwargs) -> list: def export_patch_masks(self, where, **kwargs) -> list:
return export_patch_masks(self, where, **kwargs) return export_patch_masks(self, where, **kwargs)
def get_raw_patches(self, channel): def get_raw_patches(self, channel): # tight, un-annotated 2d patches
return get_patches_from_zmask_meta( return get_patches_from_zmask_meta(self, pad_to=None)
self.acc_raw.get_one_channel_data(channel),
self.zmask_meta
)
def get_zmask(self, mask_type='boxes'): def get_zmask(self, mask_type='boxes'):
""" """
...@@ -285,7 +283,7 @@ class RoiSet(object): ...@@ -285,7 +283,7 @@ class RoiSet(object):
self.export_patch_masks(subdir, prefix=pr, **params.patch_masks) self.export_patch_masks(subdir, prefix=pr, **params.patch_masks)
if k == 'annotated_zstacks': if k == 'annotated_zstacks':
annotated = InMemoryDataAccessor( annotated = InMemoryDataAccessor(
draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) # TODO remove zmask_meta ref
) )
write_accessor_data_to_file(subdir / (pr + '.tif'), annotated) write_accessor_data_to_file(subdir / (pr + '.tif'), annotated)
if k == 'object_classes': if k == 'object_classes':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment