From 1af591200248727678b2507b364892b809428b24 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 7 Feb 2024 15:51:35 +0100 Subject: [PATCH] Rearranged methods, condensed into roiset package --- model_server/base/process.py | 15 +- model_server/extensions/chaeo/accessors.py | 32 +++ model_server/extensions/chaeo/products.py | 247 ----------------- .../extensions/chaeo/{zmask.py => roiset.py} | 251 ++++++++++++++++-- .../extensions/chaeo/tests/test_zstack.py | 42 ++- model_server/extensions/chaeo/workflows.py | 4 +- 6 files changed, 293 insertions(+), 298 deletions(-) rename model_server/extensions/chaeo/{zmask.py => roiset.py} (53%) diff --git a/model_server/base/process.py b/model_server/base/process.py index 48128223..3ee227d7 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -71,4 +71,17 @@ def rescale(nda, clip=0.0): clip_pct = (100.0 * clip, 100.0 * (1.0 - clip)) cmin, cmax = np.percentile(nda, clip_pct) rescaled = rescale_intensity(nda, in_range=(cmin, cmax)) - return rescaled \ No newline at end of file + return rescaled + + +def make_rgb(nda): + """ + Convert a YXCZ stack array to RGB, and error if more than three channels + :param nda: np.ndarray (YXCZ dimensions) + :return: np.ndarray of 3-channel stack + """ + h, w, c, nz = nda.shape + assert c <= 3 + outdata = np.zeros((h, w, 3, nz), dtype=nda.dtype) + outdata[:, :, 0:c, :] = nda[:, :, :, :] + return outdata diff --git a/model_server/extensions/chaeo/accessors.py b/model_server/extensions/chaeo/accessors.py index bc65c6f3..1f1278a3 100644 --- a/model_server/extensions/chaeo/accessors.py +++ b/model_server/extensions/chaeo/accessors.py @@ -1,7 +1,10 @@ from pathlib import Path import numpy as np +from skimage.io import imsave +from tifffile import imwrite +from base.process import make_rgb from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor class MonoPatchStack(InMemoryDataAccessor): @@ -127,6 +130,33 @@ class Multichannel3dPatchStack(InMemoryDataAccessor): def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) + +def write_patch_to_file(where, fname, yxcz): + ext = fname.split('.')[-1].upper() + where.mkdir(parents=True, exist_ok=True) + + if ext == 'PNG': + assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}' + assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs' + assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs' + if yxcz.shape[2] == 1: + outdata = yxcz[:, :, 0, 0] + elif yxcz.shape[2] == 2: # add a blank blue channel + outdata = make_rgb(yxcz) + else: # preserve RGB order + outdata = yxcz[:, :, :, 0] + imsave(where / fname, outdata, check_contrast=False) + return True + + elif ext in ['TIF', 'TIFF']: + zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3]) + imwrite(where / fname, zcyx, imagej=True) + return True + + else: + raise Exception(f'Unsupported file extension: {ext}') + + class Error(Exception): pass @@ -135,3 +165,5 @@ class InvalidDataForPatchStackError(Error): class FileNotFoundError(Error): pass + + diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index 484756fa..b28b04f6 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -1,250 +1,3 @@ -from math import floor, sqrt -from pathlib import Path -import numpy as np -import pandas as pd -from scipy.stats import moment -from skimage.filters import sobel -from skimage.io import imsave -from skimage.measure import find_contours, shannon_entropy -from tifffile import imwrite -from model_server.extensions.chaeo.accessors import MonoPatchStack -from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch -from model_server.base.accessors import InMemoryDataAccessor -from model_server.base.process import pad, rescale, resample_to_8bit -def _make_rgb(zs): - h, w, c, nz = zs.shape - assert c <= 3 - outdata = np.zeros((h, w, 3, nz), dtype=zs.dtype) - outdata[:, :, 0:c, :] = zs[:, :, :, :] - return outdata - -def _focus_metrics(): - return { - 'max_intensity': lambda x: np.max(x), - 'stdev': lambda x: np.std(x), - 'max_sobel': lambda x: np.max(sobel(x)), - 'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)), - 'entropy': lambda x: shannon_entropy(x), - 'moment': lambda x: moment(x.flatten(), moment=2), - } - -def _safe_add(a, g, b): - assert a.dtype == b.dtype - assert a.shape == b.shape - assert g >= 0.0 - - return np.clip( - a.astype('uint32') + g * b.astype('uint32'), - 0, - np.iinfo(a.dtype).max - ).astype(a.dtype) - -def write_patch_to_file(where, fname, yxcz): - ext = fname.split('.')[-1].upper() - where.mkdir(parents=True, exist_ok=True) - - if ext == 'PNG': - assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}' - assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs' - assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs' - if yxcz.shape[2] == 1: - outdata = yxcz[:, :, 0, 0] - elif yxcz.shape[2] == 2: # add a blank blue channel - outdata = _make_rgb(yxcz) - else: # preserve RGB order - outdata = yxcz[:, :, :, 0] - imsave(where / fname, outdata, check_contrast=False) - return True - - elif ext in ['TIF', 'TIFF']: - zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3]) - imwrite(where / fname, zcyx, imagej=True) - return True - - else: - raise Exception(f'Unsupported file extension: {ext}') - - - -def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack: - patches = [] - for roi in roiset: - patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') - patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 - - if pad_to: - patch = pad(patch, pad_to) - - patches.append(patch) - return MonoPatchStack(patches) - - -def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list: - patches_acc = get_patch_masks(roiset, pad_to=pad_to) - - exported = [] - for i, roi in enumerate(roiset): # assumes index of patches_acc is same as dataframe - patch = patches_acc.iat_yxcz(i) - ext = 'png' - fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' - write_patch_to_file(where, fname, patch) - exported.append(fname) - return exported - - -def get_roiset_patches( - roiset, - rescale_clip: float = 0.0, - pad_to: int = 256, - make_3d: bool = False, - focus_metric: str = None, - rgb_overlay_channels: list = None, - rgb_overlay_weights: list = [1.0, 1.0, 1.0], - white_channel: int = None, - **kwargs -) -> pd.DataFrame: - - # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data - raw = roiset.acc_raw - if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)): - assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None]) - assert len(rgb_overlay_channels) == 3 - assert len(rgb_overlay_weights) == 3 - - if white_channel: - assert white_channel < raw.chroma - stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] - else: - stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype) - - for ii, ci in enumerate(rgb_overlay_channels): - if ci is None: - continue - assert isinstance(ci, int) - assert ci < raw.chroma - stack[:, :, ii, :] = _safe_add( - stack[:, :, ii, :], # either black or grayscale channel - rgb_overlay_weights[ii], - raw.data[:, :, ci, :] - ) - else: - if white_channel: # interpret as just a single channel - assert white_channel < raw.chroma - annotate_rgb = False - for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']: - ca = kwargs.get(k) - if ca is None: - continue - assert(ca < raw.chroma) - if ca != white_channel: - annotate_rgb = True - break - if annotate_rgb: # make RGB patches anyway to include annotation color - stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] - else: # make monochrome patches - stack = raw.data[:, :, [white_channel], :] - else: - stack = raw.data - - def _make_patch(roi): - patch3d = stack[roi.slice] - ph, pw, pc, pz = patch3d.shape - subpatch = patch3d[roi.relative_slice] - - # make a 3d patch - if make_3d: - patch = patch3d - - # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately - elif focus_metric is not None: - foc = _focus_metrics()[focus_metric] - - patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype) - - for ci in range(0, pc): - me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)] - zif = np.argmax(me) - patch[:, :, ci, 0] = patch3d[:, :, ci, zif] - - # make a 2d patch from middle of z-stack - else: - zim = floor(pz / 2) - patch = patch3d[:, :, :, [zim]] - - assert len(patch.shape) == 4 - - if rescale_clip is not None: - patch = rescale(patch, rescale_clip) - - if kwargs.get('draw_bounding_box') is True: - bci = kwargs.get('bounding_box_channel', 0) - assert bci < 3 - if bci > 0: - patch = _make_rgb(patch) - for zi in range(0, patch.shape[3]): - patch[:, :, bci, zi] = draw_box_on_patch( - patch[:, :, bci, zi], - ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)), - linewidth=kwargs.get('bounding_box_linewidth', 1) - ) - - if kwargs.get('draw_mask'): - mci = kwargs.get('mask_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[roi.relative_slice[0:2]] = roi.mask - for zi in range(0, patch.shape[3]): - patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] - - if kwargs.get('draw_contour'): - mci = kwargs.get('contour_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[roi.relative_slice[0:2]] = roi.mask - - for zi in range(0, patch.shape[3]): - patch[:, :, mci, zi] = draw_contours_on_patch( - patch[:, :, mci, zi], - find_contours(mask) - ) - - if pad_to: - patch = pad(patch, pad_to) - return patch - - dfe = roiset.get_df() - dfe['patch'] = roiset.get_df().apply(lambda r: _make_patch(r), axis=1) - return dfe - - -def export_patches_from_zstack( - where: Path, - roiset, - prefix='patch', - **kwargs -): - make_3d = kwargs.get('make_3d', False) - patches_df = get_roiset_patches(roiset, **kwargs) - - def _export_patch(roi): - patch = InMemoryDataAccessor(roi.patch) - ext = 'tif' if make_3d or patch.chroma > 3 else 'png' - fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' - - if patch.dtype is np.dtype('uint16'): - write_patch_to_file(where, fname, resample_to_8bit(patch.data)) - else: - write_patch_to_file(where, fname, patch) - - exported.append({ - 'df_index': roi.Index, - 'patch_filename': fname, - 'location': where.__str__(), - }) - - exported = [] - for roi in patches_df.itertuples(): # just used for label info - _export_patch(roi) - - return exported \ No newline at end of file diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/roiset.py similarity index 53% rename from model_server/extensions/chaeo/zmask.py rename to model_server/extensions/chaeo/roiset.py index 7bb1d99c..c478ff25 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/roiset.py @@ -1,34 +1,68 @@ +from math import sqrt, floor +from pathlib import Path from uuid import uuid4 import numpy as np import pandas as pd +from scipy.stats import moment +from skimage.filters import sobel -from skimage.measure import label, regionprops_table +from skimage.measure import label, regionprops_table, shannon_entropy, find_contours from sklearn.preprocessing import PolynomialFeatures from sklearn.linear_model import LinearRegression from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.models import InstanceSegmentationModel - -from model_server.extensions.chaeo.annotators import draw_boxes_on_3d_image -from model_server.extensions.chaeo.products import export_patches_from_zstack -from extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams -from model_server.extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack +from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb +from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image +from model_server.extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams +from model_server.extensions.chaeo.accessors import write_patch_to_file, MonoPatchStack, Multichannel3dPatchStack from model_server.extensions.chaeo.process import mask_largest_object -from model_server.extensions.chaeo.products import get_roiset_patches, get_patch_masks, export_patch_masks -def get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: +def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')) + +def _focus_metrics(): + return { + 'max_intensity': lambda x: np.max(x), + 'stdev': lambda x: np.std(x), + 'max_sobel': lambda x: np.max(sobel(x)), + 'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)), + 'entropy': lambda x: shannon_entropy(x), + 'moment': lambda x: moment(x.flatten(), moment=2), + } + + +def _safe_add(a, g, b): + assert a.dtype == b.dtype + assert a.shape == b.shape + assert g >= 0.0 + + return np.clip( + a.astype('uint32') + g * b.astype('uint32'), + 0, + np.iinfo(a.dtype).max + ).astype(a.dtype) + + class RoiSet(object): def __init__( self, - acc_obj_ids: GenericImageDataAccessor, acc_raw: GenericImageDataAccessor, + acc_obj_ids: GenericImageDataAccessor, params: RoiSetMetaParams = RoiSetMetaParams(), ): + """ + A set of regions of interest, referenced by their positions and contours in the YXCZ space of stack acc_raw. + RoiSet contains their internal state, which may be exported as patches, maps, and other products by export methods. + :param acc_raw: accessor to a generally a multichannel z-stack + :param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity + labels its membership in a connected object + :param params: optional arguments that influence the definition and representation of ROIs + """ assert acc_obj_ids.chroma == 1 assert acc_obj_ids.nz == 1 self.acc_obj_ids = acc_obj_ids @@ -147,17 +181,11 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected - def get_patch_masks(self, **kwargs): - return get_patch_masks(self, **kwargs) - - def export_patch_masks(self, where, **kwargs) -> list: - return export_patch_masks(self, where, **kwargs) - def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches if channel: - patches_df = get_roiset_patches(self, white_channel=channel, pad_to=pad_to) + patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) else: - patches_df = get_roiset_patches(self, pad_to=pad_to) + patches_df = self.get_patches(pad_to=pad_to) patches = list(patches_df['patch']) if channel is not None or self.acc_raw.chroma == 1: return MonoPatchStack(patches) @@ -229,6 +257,179 @@ class RoiSet(object): om[self.acc_obj_ids.data == roi.label] = oc self.object_class_maps[name] = InMemoryDataAccessor(om) + def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list: + patches_acc = self.get_patch_masks(pad_to=pad_to) + + exported = [] + for i, roi in enumerate(self): # assumes index of patches_acc is same as dataframe + patch = patches_acc.iat_yxcz(i) + ext = 'png' + fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' + write_patch_to_file(where, fname, patch) + exported.append(fname) + return exported + + + def export_patches(self, where: Path, prefix='patch', **kwargs): + make_3d = kwargs.get('make_3d', False) + patches_df = self.get_patches(**kwargs) + + def _export_patch(roi): + patch = InMemoryDataAccessor(roi.patch) + ext = 'tif' if make_3d or patch.chroma > 3 else 'png' + fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' + + if patch.dtype is np.dtype('uint16'): + write_patch_to_file(where, fname, resample_to_8bit(patch.data)) + else: + write_patch_to_file(where, fname, patch) + + exported.append({ + 'df_index': roi.Index, + 'patch_filename': fname, + 'location': where.__str__(), + }) + + exported = [] + for roi in patches_df.itertuples(): # just used for label info + _export_patch(roi) + + return exported + + def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack: + patches = [] + for roi in self: + patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') + patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 + + if pad_to: + patch = pad(patch, pad_to) + + patches.append(patch) + return MonoPatchStack(patches) + + def get_patches( + self, + rescale_clip: float = 0.0, + pad_to: int = 256, + make_3d: bool = False, + focus_metric: str = None, + rgb_overlay_channels: list = None, + rgb_overlay_weights: list = [1.0, 1.0, 1.0], + white_channel: int = None, + **kwargs + ) -> pd.DataFrame: + + # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data + raw = self.acc_raw + if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)): + assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None]) + assert len(rgb_overlay_channels) == 3 + assert len(rgb_overlay_weights) == 3 + + if white_channel: + assert white_channel < raw.chroma + stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] + else: + stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype) + + for ii, ci in enumerate(rgb_overlay_channels): + if ci is None: + continue + assert isinstance(ci, int) + assert ci < raw.chroma + stack[:, :, ii, :] = _safe_add( + stack[:, :, ii, :], # either black or grayscale channel + rgb_overlay_weights[ii], + raw.data[:, :, ci, :] + ) + else: + if white_channel: # interpret as just a single channel + assert white_channel < raw.chroma + annotate_rgb = False + for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']: + ca = kwargs.get(k) + if ca is None: + continue + assert (ca < raw.chroma) + if ca != white_channel: + annotate_rgb = True + break + if annotate_rgb: # make RGB patches anyway to include annotation color + stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] + else: # make monochrome patches + stack = raw.data[:, :, [white_channel], :] + else: + stack = raw.data + + def _make_patch(roi): + patch3d = stack[roi.slice] + ph, pw, pc, pz = patch3d.shape + subpatch = patch3d[roi.relative_slice] + + # make a 3d patch + if make_3d: + patch = patch3d + + # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately + elif focus_metric is not None: + foc = _focus_metrics()[focus_metric] + + patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype) + + for ci in range(0, pc): + me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)] + zif = np.argmax(me) + patch[:, :, ci, 0] = patch3d[:, :, ci, zif] + + # make a 2d patch from middle of z-stack + else: + zim = floor(pz / 2) + patch = patch3d[:, :, :, [zim]] + + assert len(patch.shape) == 4 + + if rescale_clip is not None: + patch = rescale(patch, rescale_clip) + + if kwargs.get('draw_bounding_box') is True: + bci = kwargs.get('bounding_box_channel', 0) + assert bci < 3 + if bci > 0: + patch = make_rgb(patch) + for zi in range(0, patch.shape[3]): + patch[:, :, bci, zi] = draw_box_on_patch( + patch[:, :, bci, zi], + ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)), + linewidth=kwargs.get('bounding_box_linewidth', 1) + ) + + if kwargs.get('draw_mask'): + mci = kwargs.get('mask_channel', 0) + mask = np.zeros(patch.shape[0:2], dtype=bool) + mask[roi.relative_slice[0:2]] = roi.mask + for zi in range(0, patch.shape[3]): + patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] + + if kwargs.get('draw_contour'): + mci = kwargs.get('contour_channel', 0) + mask = np.zeros(patch.shape[0:2], dtype=bool) + mask[roi.relative_slice[0:2]] = roi.mask + + for zi in range(0, patch.shape[3]): + patch[:, :, mci, zi] = draw_contours_on_patch( + patch[:, :, mci, zi], + find_contours(mask) + ) + + if pad_to: + patch = pad(patch, pad_to) + return patch + + dfe = self._df + dfe['patch'] = self._df.apply(lambda r: _make_patch(r), axis=1) + return dfe + def run_exports(self, where, channel, prefix, params: RoiSetExportParams): if not self.count: return @@ -240,17 +441,17 @@ class RoiSet(object): if kp is None: continue if k == 'patches_3d': - files = export_patches_from_zstack( - subdir, self, white_channel=channel, prefix=pr, make_3d=True, **kp + files = self.export_patches( + subdir, white_channel=channel, prefix=pr, make_3d=True, **kp ) if k == 'annotated_patches_2d': - files = export_patches_from_zstack( - subdir, self, prefix=pr, make_3d=False, white_channel=channel, + files = self.export_patches( + subdir, prefix=pr, make_3d=False, white_channel=channel, bounding_box_channel=1, bounding_box_linewidth=2, **kp, ) if k == 'patches_2d': - files = export_patches_from_zstack( - subdir, self, white_channel=channel, prefix=pr, make_3d=False, **kp + files = self.export_patches( + subdir, white_channel=channel, prefix=pr, make_3d=False, **kp ) df_patches = pd.DataFrame(files) self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') @@ -313,3 +514,7 @@ def project_stack_from_focal_points( ), axis=3 ) + + + + diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index 9631dd04..f3b82997 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -7,9 +7,8 @@ from model_server.conf.testing import output_path from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams -from model_server.extensions.chaeo.products import export_patches_from_zstack from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack -from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet +from model_server.extensions.chaeo.roiset import _get_label_ids, RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.extensions.ilastik.models import IlastikPixelClassifierModel from model_server.base.models import DummyInstanceSegmentationModel @@ -42,10 +41,10 @@ class BaseTestRoiSetMonoProducts(object): class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def _make_roi_set(self, mask_type='boxes', **kwargs): - id_map = get_label_ids(self.seg_mask) + id_map = _get_label_ids(self.seg_mask) roiset = RoiSet( - id_map, self.stack_ch_pa, + id_map, params=RoiSetMetaParams( mask_type=mask_type, filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), @@ -78,9 +77,9 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_roiset_from_non_zstacks(self, **kwargs): acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) self.assertEqual(acc_zstack_slice.nz, 1) - id_map = get_label_ids(self.seg_mask) + id_map = _get_label_ids(self.seg_mask) - roiset = RoiSet(id_map, acc_zstack_slice, params=RoiSetMetaParams(mask_type='boxes')) + roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes')) zmask = roiset.get_zmask() zmask_acc = InMemoryDataAccessor(zmask) @@ -105,18 +104,16 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_2d_patches(self): roiset = self._make_roi_set() - files = export_patches_from_zstack( + files = roiset.export_patches( output_path / '2d_patches', - roiset, draw_bounding_box=True, ) self.assertGreaterEqual(len(files), 1) def test_make_3d_patches(self): roiset = self._make_roi_set() - files = export_patches_from_zstack( + files = roiset.export_patches( output_path / '3d_patches', - roiset, make_3d=True) self.assertGreaterEqual(len(files), 1) @@ -129,12 +126,12 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(result.shape, roiset.acc_raw.shape) def test_flatten_image(self): - id_map = get_label_ids(self.seg_mask) + id_map = _get_label_ids(self.seg_mask) - roiset = RoiSet(id_map, self.stack_ch_pa, params=RoiSetMetaParams(mask_type='boxes')) + roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes')) df = roiset.get_df() - from model_server.extensions.chaeo.zmask import project_stack_from_focal_points + from model_server.extensions.chaeo.roiset import project_stack_from_focal_points img = project_stack_from_focal_points( df['centroid-0'].to_numpy(), @@ -227,10 +224,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa def setUp(self) -> None: super().setUp() - id_map = get_label_ids(self.seg_mask) + id_map = _get_label_ids(self.seg_mask) self.roiset = RoiSet( - id_map, self.stack, + id_map, params=RoiSetMetaParams( expand_box_by=(128, 2), mask_type='boxes', @@ -239,9 +236,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ) def test_multichannel_to_mono_2d_patches(self): - files = export_patches_from_zstack( + files = self.roiset.export_patches( output_path / 'multichannel' / 'mono_2d_patches', - self.roiset, white_channel=3, draw_bounding_box=True, ) @@ -249,9 +245,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.chroma, 1) def test_multichannnel_to_mono_2d_patches_rgb_bbox(self): - files = export_patches_from_zstack( + files = self.roiset.export_patches( output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox', - self.roiset, white_channel=3, draw_bounding_box=True, bounding_box_channel=1, @@ -260,9 +255,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_bbox(self): - files = export_patches_from_zstack( + files = self.roiset.export_patches( output_path / 'multichannel' / 'rgb_2d_patches_bbox', - self.roiset, white_channel=4, rgb_overlay_channels=(3, None, None), draw_mask=True, @@ -273,9 +267,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_contour(self): - files = export_patches_from_zstack( + files = self.roiset.export_patches( output_path / 'multichannel' / 'rgb_2d_patches_contour', - self.roiset, rgb_overlay_channels=(3, None, None), draw_contour=True, contour_channel=1, @@ -286,9 +279,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black def test_multichannel_to_multichannel_tif_patches(self): - files = export_patches_from_zstack( + files = self.roiset.export_patches( output_path / 'multichannel' / 'multichannel_tif_patches', - self.roiset, ) result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) self.assertEqual(result.chroma, 5) diff --git a/model_server/extensions/chaeo/workflows.py b/model_server/extensions/chaeo/workflows.py index 7e90436b..08d06adf 100644 --- a/model_server/extensions/chaeo/workflows.py +++ b/model_server/extensions/chaeo/workflows.py @@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams from model_server.extensions.chaeo.process import mask_largest_object -from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet +from model_server.extensions.chaeo.roiset import _get_label_ids, RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel @@ -48,7 +48,7 @@ def infer_object_map_from_zstack( ti.click('classify_pixels') # make zmask - rois = RoiSet(get_label_ids(mip_mask), stack, params=roi_params) + rois = RoiSet(stack, _get_label_ids(mip_mask), params=roi_params) ti.click('generate_zmasks') rois.classify_by( -- GitLab