From 9965d5ea2b9486ae705ea8da60ebfd9b272af3be Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 3 Feb 2024 10:22:21 +0100 Subject: [PATCH] Pass RoiSet object to products, much easier --- model_server/extensions/chaeo/products.py | 40 ++++++++++++++++ .../extensions/chaeo/tests/test_zstack.py | 2 +- model_server/extensions/chaeo/zmask.py | 46 +++---------------- 3 files changed, 47 insertions(+), 41 deletions(-) diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index ee773a20..d68b4c36 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -57,6 +57,46 @@ def write_patch_to_file(where, fname, yxcz): raise Exception(f'Unsupported file extension: {ext}') +def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack: + + patches = [] + for mi in roiset.zmask_meta: + sl = mi['slice'] + + rbb = mi['relative_bounding_box'] + x0 = rbb['x0'] + y0 = rbb['y0'] + x1 = rbb['x1'] + y1 = rbb['y1'] + + sp_sl = np.s_[y0: y1, x0: x1, :, :] + + h, w = roiset.acc_raw.data[sl].shape[0:2] + patch = np.zeros((h, w, 1, 1), dtype='uint8') + patch[sp_sl][:, :, 0, 0] = mi['mask'] * 255 + + if pad_to: + patch = pad(patch, pad_to) + + patches.append(patch) + return MonoPatchStack(patches) + + +def export_patch_masks_from_zstack(roiset, where: Path, pad_to: int = 256, prefix='mask') -> list: + patches_acc = roiset.get_patch_masks(pad_to=pad_to) + + exported = [] + for i in range(0, roiset.count): + mi = roiset.zmask_meta[i] + obj = mi['info'] + patch = patches_acc.iat_yxcz(i) + ext = 'png' + fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' + write_patch_to_file(where, fname, patch) + exported.append(fname) + return exported + + def get_patches_from_zmask_meta( stack: GenericImageDataAccessor, zmask_meta: list, diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index 66d25594..a925a503 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -199,7 +199,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase): filters={'area': {'min': 1e3, 'max': 1e4}}, expand_box_by=(128, 2) ) - files = roiset.export_patch_masks_from_zstack(output_path / '2d_mask_patches', ) + files = roiset.export_patch_masks(output_path / '2d_mask_patches', ) self.assertGreaterEqual(len(files), 1) def test_object_map_workflow(self): diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index e30b2c18..1111510c 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -4,7 +4,6 @@ from math import floor import numpy as np import pandas as pd from pathlib import Path -from typing import List from skimage.measure import find_contours, label, regionprops_table from sklearn.preprocessing import PolynomialFeatures @@ -19,7 +18,7 @@ from model_server.extensions.chaeo.products import export_patches_from_zstack, e from extensions.chaeo.params import RoiSetMetaParams, RoiSetExportParams from model_server.extensions.chaeo.accessors import MonoPatchStack from model_server.extensions.chaeo.process import mask_largest_object -from model_server.extensions.chaeo.products import get_patches_from_zmask_meta, write_patch_to_file +from model_server.extensions.chaeo.products import get_patches_from_zmask_meta, get_patch_masks, export_patch_masks_from_zstack def get_label_ids(acc_seg_mask): @@ -62,30 +61,11 @@ class RoiSet(object): def get_object_mask_by_class(self, class_id): return self.object_id_labels == class_id - def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack: - - patches = [] - for mi in self.zmask_meta: - sl = mi['slice'] - - rbb = mi['relative_bounding_box'] - x0 = rbb['x0'] - y0 = rbb['y0'] - x1 = rbb['x1'] - y1 = rbb['y1'] - - sp_sl = np.s_[y0: y1, x0: x1, :, :] - - h, w = self.acc_raw.data[sl].shape[0:2] - patch = np.zeros((h, w, 1, 1), dtype='uint8') - patch[sp_sl][:, :, 0, 0] = mi['mask'] * 255 - - if pad_to: - patch = pad(patch, pad_to) - - patches.append(patch) - return MonoPatchStack(patches) + def get_patch_masks(self, **kwargs) -> MonoPatchStack: + return get_patch_masks(self, **kwargs) + def export_patch_masks(self, where, **kwargs) -> list: + return export_patch_masks_from_zstack(self, where, **kwargs) def get_raw_patches(self, channel): return get_patches_from_zmask_meta( @@ -99,20 +79,6 @@ class RoiSet(object): def get_zmask(self): # TODO: on-the-fly generation of zmask array return self.zmask - def export_patch_masks_from_zstack(self, where: Path, pad_to: int = 256, prefix='mask') -> List: - patches_acc = self.get_patch_masks(pad_to=pad_to) - - exported = [] - for i in range(0, self.count): - mi = self.zmask_meta[i] - obj = mi['info'] - patch = patches_acc.iat_yxcz(i) - ext = 'png' - fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' - write_patch_to_file(where, fname, patch) - exported.append(fname) - return exported - def classify_by(self, channel, object_classification_model: InstanceSegmentationModel): # do this on a patch basis, i.e. only one object per frame obmap_patches = object_classification_model.label_instance_class( @@ -161,7 +127,7 @@ class RoiSet(object): self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1) if k == 'patch_masks': - self.export_patch_masks_from_zstack(subdir, prefix=pr) + self.export_patch_masks(subdir, prefix=pr) if k == 'annotated_zstacks': annotated = InMemoryDataAccessor( draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) -- GitLab