From 7ce760c7a246ef47ae2b871b3a06d5bf3acad603 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 3 Feb 2024 22:08:14 +0100 Subject: [PATCH] simplifying export calls --- model_server/extensions/chaeo/params.py | 2 +- model_server/extensions/chaeo/products.py | 14 ++++++-------- model_server/extensions/chaeo/zmask.py | 5 +---- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/model_server/extensions/chaeo/params.py b/model_server/extensions/chaeo/params.py index 0880ca12..71b9497d 100644 --- a/model_server/extensions/chaeo/params.py +++ b/model_server/extensions/chaeo/params.py @@ -36,7 +36,7 @@ class RoiSetExportParams(BaseModel): patches_3d: Union[PatchParams, None] = None patches_2d_for_annotation: Union[PatchParams, None] = None patches_2d_for_training: Union[PatchParams, None] = None - patch_masks: bool = False + patch_masks: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False dataframe: bool = False diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index 84535d73..3664ce8e 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -60,9 +60,9 @@ def write_patch_to_file(where, fname, yxcz): def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack: patches = [] - for ob in roiset.get_df().itertuples('Roi'): - patch = np.zeros((ob.ebb_h, ob.ebb_w, 1, 1), dtype='uint8') - patch[ob.relative_slice][:, :, 0, 0] = ob.mask * 255 + for roi in roiset.get_df().itertuples(): + patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') + patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 if pad_to: patch = pad(patch, pad_to) @@ -72,15 +72,13 @@ def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack: def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask') -> list: - patches_acc = roiset.get_patch_masks(pad_to=pad_to) + patches_acc = get_patch_masks(roiset, pad_to=pad_to) exported = [] - for i in range(0, roiset.count): - mi = roiset.zmask_meta[i] - obj = mi['info'] + for i, roi in enumerate(roiset.get_df().itertuples()): # assumes index of patches_acc is same as dataframe patch = patches_acc.iat_yxcz(i) ext = 'png' - fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' + fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' write_patch_to_file(where, fname, patch) exported.append(fname) return exported diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index b06c276f..781b6cbc 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -181,9 +181,6 @@ class RoiSet(object): def get_object_mask_by_class(self, class_id): return self.object_id_labels == class_id - def get_patch_masks(self, **kwargs) -> MonoPatchStack: - return get_patch_masks(self, **kwargs) - def export_patch_masks(self, where, **kwargs) -> list: return export_patch_masks(self, where, **kwargs) @@ -284,7 +281,7 @@ class RoiSet(object): self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1) if k == 'patch_masks': - self.export_patch_masks(subdir, prefix=pr) + self.export_patch_masks(subdir, prefix=pr, **params.patch_masks) if k == 'annotated_zstacks': annotated = InMemoryDataAccessor( draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) -- GitLab