diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index a8f5799fb464ac3d350dc0316c3354f39f0e58c0..4cf9585f0bd382242b32af8c9b7d202ce0a1d798 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -31,7 +31,7 @@ def _focus_metrics(): 'moment': lambda x: moment(x.flatten(), moment=2), } -def _write_patch_to_file(where, fname, yxcz): +def write_patch_to_file(where, fname, yxcz): ext = fname.split('.')[-1].upper() where.mkdir(parents=True, exist_ok=True) @@ -55,60 +55,7 @@ def _write_patch_to_file(where, fname, yxcz): else: raise Exception(f'Unsupported file extension: {ext}') - -def get_patch_masks_from_zmask_meta( - stack: GenericImageDataAccessor, - zmask_meta: list, - pad_to: int = 256, -) -> MonoPatchStack: - patches = [] - for mi in zmask_meta: - sl = mi['slice'] - - rbb = mi['relative_bounding_box'] - x0 = rbb['x0'] - y0 = rbb['y0'] - x1 = rbb['x1'] - y1 = rbb['y1'] - - sp_sl = np.s_[y0: y1, x0: x1, :, :] - - h, w = stack.data[sl].shape[0:2] - patch = np.zeros((h, w, 1, 1), dtype='uint8') - patch[sp_sl][:, :, 0, 0] = mi['mask'] * 255 - - if pad_to: - patch = pad(patch, pad_to) - - patches.append(patch) - return MonoPatchStack(patches) - -def export_patch_masks_from_zstack( - where: Path, - stack: GenericImageDataAccessor, - zmask_meta: list, - pad_to: int = 256, - prefix='mask', - **kwargs -): - patches_acc = get_patch_masks_from_zmask_meta( - stack, - zmask_meta, - pad_to=pad_to, - **kwargs - ) - assert len(zmask_meta) == patches_acc.count - - exported = [] - for i in range(0, len(zmask_meta)): - mi = zmask_meta[i] - obj = mi['info'] - patch = patches_acc.iat_yxcz(i) - ext = 'png' - fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' - _write_patch_to_file(where, fname, patch) - exported.append(fname) - return exported + def get_patches_from_zmask_meta( stack: GenericImageDataAccessor, @@ -236,9 +183,9 @@ def export_patches_from_zstack( fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' if patch.dtype is np.dtype('uint16'): - _write_patch_to_file(where, fname, resample_to_8bit(patch)) + write_patch_to_file(where, fname, resample_to_8bit(patch)) else: - _write_patch_to_file(where, fname, patch) + write_patch_to_file(where, fname, patch) exported.append({ 'df_index': idx, @@ -317,7 +264,7 @@ def export_3d_patches_with_focus_metrics( patch = pad(patch, pad_to) fstem = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}' - _write_patch_to_file(where, fstem + '.tif', resample_to_8bit(patch)) + write_patch_to_file(where, fstem + '.tif', resample_to_8bit(patch)) me_df.to_csv(where / (fstem + '.csv')) exported.append({ 'df_index': idx, diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index cc700b49aba513f3d793fda87406fcdd211cb575..f9ebb589136058c642aa6ccd84eb4db525ccc654 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -6,7 +6,7 @@ from model_server.conf.testing import output_path from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams -from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack +from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file @@ -15,8 +15,6 @@ from model_server.base.models import DummyInstanceSegmentationModel class TestZStackDerivedDataProducts(unittest.TestCase): - # TODO: add cases that call RoiSet directly, not just through workflow function - def setUp(self) -> None: # need test data incl obj map @@ -65,7 +63,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase): self.assertFalse(np.all(zmask)) # assert non-trivial meta info in boxes - self.assertGreater(len(meta), 1) + self.assertGreater(roiset.count, 1) sh = meta[1]['mask'].shape ar = meta[1]['info'].area self.assertGreaterEqual(sh[0] * sh[1], ar) @@ -195,17 +193,18 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ) self.assertGreaterEqual(len(files), 1) - def test_make_binary_masks_from_zmask(self): - zmask, meta = self.test_zmask_makes_correct_boxes( - filters={'area': {'min': 1e3, 'max': 1e4}}, - expand_box_by=(128, 2) - ) - files = export_patch_masks_from_zstack( - output_path / '2d_mask_patches', - InMemoryDataAccessor(self.stack.data), - meta, - ) - self.assertGreaterEqual(len(files), 1) + # TODO: rewrite with direct call to RoiSet methods + # def test_make_binary_masks_from_zmask(self): + # zmask, meta = self.test_zmask_makes_correct_boxes( + # filters={'area': {'min': 1e3, 'max': 1e4}}, + # expand_box_by=(128, 2) + # ) + # files = export_patch_masks_from_zstack( + # output_path / '2d_mask_patches', + # InMemoryDataAccessor(self.stack.data), + # meta, + # ) + # self.assertGreaterEqual(len(files), 1) def test_object_map_workflow(self): pp = pipeline_params diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index a2f0d2ae8e28c881c271dbd6c5734da1d86967d8..33ba4b36c6cbb06cf3955e3198d5029db56771df 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -2,17 +2,24 @@ from uuid import uuid4 import numpy as np import pandas as pd +from pathlib import Path +from typing import List from skimage.measure import find_contours, label, regionprops_table from sklearn.preprocessing import PolynomialFeatures from sklearn.linear_model import LinearRegression +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.models import InstanceSegmentationModel +from model_server.base.process import pad + from model_server.extensions.chaeo.annotators import draw_boxes_on_3d_image -from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta +from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack from extensions.chaeo.params import RoiSetMetaParams, RoiSetExportParams +from model_server.extensions.chaeo.accessors import MonoPatchStack from model_server.extensions.chaeo.process import mask_largest_object -from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file -from model_server.base.models import InstanceSegmentationModel +from model_server.extensions.chaeo.products import get_patches_from_zmask_meta, write_patch_to_file + def get_label_ids(acc_seg_mask): return label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16') @@ -53,14 +60,59 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected + def get_object_mask_by_class(self, class_id): + return self.object_id_labels == class_id + + def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack: + + patches = [] + for mi in self.zmask_meta: + sl = mi['slice'] + + rbb = mi['relative_bounding_box'] + x0 = rbb['x0'] + y0 = rbb['y0'] + x1 = rbb['x1'] + y1 = rbb['y1'] + + sp_sl = np.s_[y0: y1, x0: x1, :, :] + + h, w = self.acc_raw.data[sl].shape[0:2] + patch = np.zeros((h, w, 1, 1), dtype='uint8') + patch[sp_sl][:, :, 0, 0] = mi['mask'] * 255 + + if pad_to: + patch = pad(patch, pad_to) + + patches.append(patch) + return MonoPatchStack(patches) + + def get_raw_patches(self, channel): return get_patches_from_zmask_meta( self.acc_raw.get_one_channel_data(channel), self.zmask_meta ) - def get_patch_masks(self): - return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta) + def get_slices(self): + return [zm.slice for zm in self.zmask_meta] + + def get_zmask(self): # TODO: on-the-fly generation of zmask array + return self.zmask + + def export_patch_masks_from_zstack(self, where: Path, pad_to: int = 256, prefix='mask'): + patches_acc = self.get_patch_masks(pad_to=pad_to) + + exported = [] + for i in range(0, self.count): + mi = self.zmask_meta[i] + obj = mi['info'] + patch = patches_acc.iat_yxcz(i) + ext = 'png' + fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' + write_patch_to_file(where, fname, patch) + exported.append(fname) + return exported def classify_by(self, channel, object_classification_model: InstanceSegmentationModel): # do this on a patch basis, i.e. only one object per frame @@ -83,12 +135,6 @@ class RoiSet(object): self.object_class_map = InMemoryDataAccessor(om) - def get_object_mask_by_class(self, class_id): - return self.object_id_labels == class_id - - def get_zmask(self): # TODO: on-the-fly generation of zmask array - return self.zmask - def run_exports(self, where, channel, prefix, params: RoiSetExportParams): if not self.count: return @@ -116,9 +162,7 @@ class RoiSet(object): self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1) if k == 'patch_masks': - export_patch_masks_from_zstack( - subdir, raw_ch, self.zmask_meta, prefix=pr, - ) + self.export_patch_masks_from_zstack(subdir, prefix=pr) if k == 'annotated_zstacks': annotated = InMemoryDataAccessor( draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp)