diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 07eba484b24c7b64b729ee1bb71de708ecf1f934..bc74f919f3d2cc294398e921cffcdb7673754232 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -13,6 +13,8 @@ from model_server.base.process import is_mask class GenericImageDataAccessor(ABC): + axes = 'YXCZ' + @abstractmethod def __init__(self): """ @@ -25,6 +27,13 @@ class GenericImageDataAccessor(ABC): def chroma(self): return self.shape_dict['C'] + @staticmethod + def _derived_accessor(data): + """ + Create a new accessor given np.ndarray data; used for example in slicing operations + """ + return InMemoryDataAccessor(data) + @staticmethod def conform_data(data): if len(data.shape) > 4 or (0 in data.shape): @@ -38,12 +47,23 @@ class GenericImageDataAccessor(ABC): def is_mask(self): return is_mask(self._data) - def get_one_channel_data (self, channel: int, mip: bool = False): - c = int(channel) + def get_channels(self, channels: list, mip: bool = False): + carr = [int(c) for c in channels] if mip: - return InMemoryDataAccessor(self.data[:, :, c:(c+1), :].max(axis=-1)) + nda = self.data.take(indices=carr, axis=self._ga('C')).max(axis=self._ga('Z'), keepdims=True) + return self._derived_accessor(nda) else: - return InMemoryDataAccessor(self.data[:, :, c:(c+1), :]) + nda = self.data.take(indices=carr, axis=self._ga('C')) + return self._derived_accessor(nda) + + def get_one_channel_data(self, channel: int, mip: bool = False): + return self.get_channels([channel], mip=mip) + + def _gc(self, channels): + return self.get_channels(list(channels)) + + def _unique(self): + return np.unique(self.data, return_counts=True) @property def pixel_scale_in_micrometers(self): @@ -53,6 +73,12 @@ class GenericImageDataAccessor(ABC): def dtype(self): return self.data.dtype + def get_axis(self, ch): + return self.axes.index(ch.upper()) + + def _ga(self, arg): + return self.get_axis(arg) + @property def hw(self): """ @@ -162,6 +188,14 @@ class CziImageFileAccessor(GenericImageFileAccessor): except Exception: raise FileAccessorError(f'Unable to access CZI data in {fpath}') + try: + md = cf.metadata(raw=False) + compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod'] + except KeyError: + raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod') + if compmet.upper() != 'UNCOMPRESSED': + raise InvalidCziCompression(f'Unsupported compression method {compmet}') + sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes} if (sd.get('S') and (sd['S'] > 1)) or (sd.get('T') and (sd['T'] > 1)): raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}') @@ -250,40 +284,84 @@ def generate_file_accessor(fpath): class PatchStack(InMemoryDataAccessor): - def __init__(self, data): + axes = 'PYXCZ' + + def __init__(self, data, force_ydim_longest=False): """ A sequence of n (generally) color 3D images of the same size :param data: either a list of np.ndarrays of size YXCZ, or np.ndarray of size PYXCZ + :param force_ydmin_longest: if creating a PatchStack from a list of different-sized patches, rotate each + as needed so that height is always greater than or equal to width """ - + self._slices = [] if isinstance(data, list): # list of YXCZ patches n = len(data) - yxcz_shape = np.array([e.shape for e in data]).max(axis=0) + if force_ydim_longest: + psh = np.array([e.shape[0:2] for e in data]).max(axis=1).max() + psw = np.array([e.shape[0:2] for e in data]).min(axis=1).max() + psc, psz = np.array([e.shape[2:] for e in data]).max(axis=0) + yxcz_shape = np.array([psh, psw, psc, psz]) + else: + yxcz_shape = np.array([e.shape for e in data]).max(axis=0) nda = np.zeros( (n, *yxcz_shape), dtype=data[0].dtype ) for i in range(0, len(data)): - s = tuple([slice(0, c) for c in data[i].shape]) - nda[i][s] = data[i] + h, w = data[i].shape[0:2] + if force_ydim_longest and w > h: + patch = np.rot90(data[i], axes=(0, 1)) + else: + patch = data[i] + s = tuple([slice(0, c) for c in patch.shape]) + nda[i][s] = patch + self._slices.append(s) elif isinstance(data, np.ndarray) and len(data.shape) == 5: # interpret as PYXCZ nda = data + for i in range(0, len(data)): + self._slices.append(tuple([slice(0, c) for c in data[i].shape])) else: raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}') assert nda.ndim == 5 self._data = nda - def iat(self, i): - return InMemoryDataAccessor(self.data[i, :, :, :, :]) + @staticmethod + def _derived_accessor(data): + return PatchStack(data) + + def get_slice_at(self, i): + return self._slices[i] - def iat_yxcz(self, i): - return self.iat(i) + def iat(self, i, crop=False): + if crop: + return InMemoryDataAccessor(self.data[i, :, :, :, :][self._slices[i]]) + else: + return InMemoryDataAccessor(self.data[i, :, :, :, :]) + + def iat_yxcz(self, i, crop=False): + return self.iat(i, crop=crop) @property def count(self): return self.shape_dict['P'] + def export_pyxcz(self, fpath: Path): + tzcyx = np.moveaxis( + self.pyxcz, # yxcz + [0, 4, 3, 1, 2], + [0, 1, 2, 3, 4] + ) + + if self.is_mask(): + if self.dtype == 'bool': + data = (tzcyx * 255).astype('uint8') + else: + data = tzcyx.astype('uint8') + tifffile.imwrite(fpath, data, imagej=True) + else: + tifffile.imwrite(fpath, tzcyx, imagej=True) + @property def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) @@ -313,16 +391,32 @@ class PatchStack(InMemoryDataAccessor): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) -def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position +def make_patch_stack_from_file(fpath): # interpret t-dimension as patch position if not Path(fpath).exists(): raise FileNotFoundError(f'Could not find {fpath}') - pyxc = np.moveaxis( - generate_file_accessor(fpath).data, # yxcz - [0, 1, 2, 3], - [1, 2, 3, 0] + try: + tf = tifffile.TiffFile(fpath) + except Exception: + raise FileAccessorError(f'Unable to access data in {fpath}') + + if len(tf.series) != 1: + raise DataShapeError(f'Expect only one series in {fpath}') + + se = tf.series[0] + + axs = [a for a in se.axes if a in [*'TZCYX']] + sd = dict(zip(axs, se.shape)) + for a in [*'TZC']: + if a not in axs: + sd[a] = 1 + tzcyx = se.asarray().reshape([sd[k] for k in [*'TZCYX']]) + + pyxcz = np.moveaxis( + tzcyx, + [0, 3, 4, 2, 1], + [0, 1, 2, 3, 4], ) - pyxcz = np.expand_dims(pyxc, axis=3) return PatchStack(pyxcz) @@ -345,6 +439,9 @@ class FileWriteError(Error): class InvalidAxisKey(Error): pass +class InvalidCziCompression(Error): + pass + class InvalidDataShape(Error): pass diff --git a/model_server/base/api.py b/model_server/base/api.py index 8259a98c8f74a5ca25b6b2b3c6e4d8edc95fa547..987cc3817415651f0c69f0506e5880c5b6d3215d 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,4 +1,5 @@ from fastapi import FastAPI, HTTPException +from pydantic import BaseModel from model_server.base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel from model_server.base.session import Session, InvalidPathError @@ -20,9 +21,13 @@ def startup(): def read_root(): return {'success': True} +class BounceBackParams(BaseModel): + par1: str + par2: list + @app.put('/bounce_back') -def list_bounce_back(par1=None, par2=None): - return {'success': True, 'params': {'par1': par1, 'par2': par2}} +def list_bounce_back(params: BounceBackParams): + return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}} @app.get('/paths') def list_session_paths(): @@ -53,7 +58,7 @@ def watch_input_path(path: str): return change_path('inbound_images', path) @app.put('/paths/watch_output') -def watch_input_path(path: str): +def watch_output_path(path: str): return change_path('outbound_images', path) @app.get('/restart') diff --git a/model_server/base/models.py b/model_server/base/models.py index 8413f68d4420fe31add5e25818b4198bdc1a76eb..4313c6f573d733fe08a621f7b2df24889e2d8075 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -3,7 +3,7 @@ from math import floor import numpy as np -from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack class Model(ABC): @@ -70,6 +70,16 @@ class SemanticSegmentationModel(ImageToImageModel): """ pass + def label_patch_stack(self, img: PatchStack, **kwargs) -> PatchStack: + """ + Iterative over a patch stack, call pixel labeling (to boolean array) separately on each cropped patch + """ + data = np.zeros((img.count, *img.hw, 1, img.nz), dtype=bool) + for i in range(0, img.count): + sl = img.get_slice_at(i) + data[i][sl] = self.label_pixel_class(img.iat(i, crop=True), **kwargs) + return PatchStack(data) + class InstanceSegmentationModel(ImageToImageModel): """ @@ -85,9 +95,32 @@ class InstanceSegmentationModel(ImageToImageModel): """ if not mask.is_mask(): raise InvalidInputImageError('Expecting a binary mask') - if not img.shape == mask.shape: + if img.hw != mask.hw or img.nz != mask.nz: raise InvalidInputImageError('Expect input image and mask to be the same shape') + def label_patch_stack(self, img: PatchStack, mask: PatchStack, allow_multiple=True, force_single=False, **kwargs): + """ + Call inference on all patches in a PatchStack at once + :param img: raw image data + :param mask: binary masks of same shape as img + :param allow_multiple: allow multiple nonzero pixel values in inferred patches if True + :param force_single: if True, and if allow_multiple is False, convert all nonzero pixels in a patch to the most + used label; otherwise raise an exception + :return: PatchStack of labeled objects + """ + res = self.label_instance_class(img, mask, **kwargs) + data = res.data + for i in range(0, res.count): # interpret as PYXCZ + la_patch = data[i, :, :, :, :] + la, ct = np.unique(la_patch, return_counts=True) + if len(la[la > 0]) > 1 and not allow_multiple: + if force_single: + la_patch[la_patch > 0] = la[1:][ct[1:].argsort()[-1]] # most common nonzero value + else: + raise InvalidObjectLabelsError(f'Found more than one nonzero label: {la}, counts: {ct}') + data[i, :, :, :, :] = la_patch + return PatchStack(data) + class DummySemanticSegmentationModel(SemanticSegmentationModel): @@ -142,4 +175,7 @@ class ParameterExpectedError(Error): pass class InvalidInputImageError(Error): + pass + +class InvalidObjectLabelsError(Error): pass \ No newline at end of file diff --git a/model_server/base/process.py b/model_server/base/process.py index 992c9fb93140ab49859f0807b55a1f8e321cd80d..94c6cf6206d4a0f1dcc4cfb78635807d58efd0c7 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -6,7 +6,7 @@ from math import ceil, floor import numpy as np import skimage from skimage.exposure import rescale_intensity - +from skimage.measure import find_contours def is_mask(img): """ @@ -117,6 +117,17 @@ def mask_largest_object( else: return img +def get_safe_contours(mask): + """ + Return a list of contour coordinates even if a mask is only one pixel across + """ + if mask.shape[0] == 1 or mask.shape[1] == 1: + c0 = mask.shape[0] - 1 + c1 = mask.shape[1] - 1 + return [np.array([(0, 0), (c0, c1)])] + else: + return find_contours(mask) + class Error(Exception): pass diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index c44023db6622237b5fe40dd82b90d2f161187517..f3c0c9479c35d0cea4bb0fc660a9e692ca101686 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1,5 +1,6 @@ from math import sqrt, floor from pathlib import Path +import re from typing import List, Union from uuid import uuid4 @@ -15,9 +16,9 @@ from sklearn.linear_model import LinearRegression from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.models import InstanceSegmentationModel -from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb +from model_server.base.process import get_safe_contours, pad, rescale, resample_to_8bit, make_rgb from model_server.base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image -from model_server.base.accessors import PatchStack +from model_server.base.accessors import generate_file_accessor, PatchStack from model_server.base.process import mask_largest_object @@ -27,9 +28,10 @@ class PatchParams(BaseModel): draw_mask: bool = False rescale_clip: float = 0.001 focus_metric: str = 'max_sobel' - rgb_overlay_channels: List[Union[int, None]] = [None, None, None] + rgb_overlay_channels: Union[List[Union[int, None]], None] = None rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0] pad_to: int = 256 + expanded: bool = False class AnnotatedZStackParams(BaseModel): @@ -44,7 +46,6 @@ class RoiFilterRange(BaseModel): class RoiFilter(BaseModel): area: Union[RoiFilterRange, None] = None - solidity: Union[RoiFilterRange, None] = None class RoiSetMetaParams(BaseModel): @@ -53,20 +54,43 @@ class RoiSetMetaParams(BaseModel): class RoiSetExportParams(BaseModel): - pixel_probabilities: bool = False patches_3d: Union[PatchParams, None] = None annotated_patches_2d: Union[PatchParams, None] = None patches_2d: Union[PatchParams, None] = None - patch_masks: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False - dataframe: bool = False + derived_channels: bool = False - -def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: - return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')) +def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: + """ + Convert binary segmentation mask into either a 2D or 3D object identities map + :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions + :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False + :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False + :return: object identities map + """ + if allow_3d and connect_3d: + nda_la = label( + acc_seg_mask.data[:, :, 0, :] + ).astype('uint16') + return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) + elif allow_3d and not connect_3d: + nla = 0 + la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') + for zi in range(0, acc_seg_mask.nz): + la_2d = label(acc_seg_mask.data[:, :, 0, zi]).astype('uint16') + la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla + nla = la_2d.max() + la_3d[:, :, 0, zi] = la_2d + return InMemoryDataAccessor(la_3d) + else: + return InMemoryDataAccessor( + label( + acc_seg_mask.data[:, :, 0, :].max(axis=-1) + ).astype('uint16') + ) def _focus_metrics(): @@ -109,9 +133,9 @@ class RoiSet(object): :param params: optional arguments that influence the definition and representation of ROIs """ assert acc_obj_ids.chroma == 1 - assert acc_obj_ids.nz == 1 self.acc_obj_ids = acc_obj_ids self.acc_raw = acc_raw + self.accs_derived = [] self.params = params self._df = self.filter_df( @@ -135,26 +159,32 @@ class RoiSet(object): :param acc_raw: accessor to raw image data :param acc_obj_ids: accessor to map of object IDs :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) + # :param deproject: assign object's z-position based on argmax of raw data if True :return: pd.DataFrame """ # build dataframe of objects, assign z index to each object - argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') - df = ( - pd.DataFrame( - regionprops_table( - acc_obj_ids.data[:, :, 0, 0], - intensity_image=argmax, - properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid') - ) - ) - .rename( - columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', } - ) - ) - df['zi'] = df['intensity_mean'].round().astype('int') + + if acc_obj_ids.nz == 1: # deproject objects' z-coordinates from argmax of raw image + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data[:, :, 0, 0], + intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'), + properties=('label', 'area', 'intensity_mean', 'bbox', 'centroid') + )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) + df['zi'] = df['intensity_mean'].round().astype('int') + + else: # objects' z-coordinates come from arg of max count in object identities map + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data[:, :, 0, :], + properties=('label', 'area', 'bbox', 'centroid') + )).rename(columns={ + 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' + }) + df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax()) # compute expanded bounding boxes h, w, c, nz = acc_raw.shape + df['h'] = df['y1'] - df['y0'] + df['w'] = df['x1'] - df['x0'] ebxy, ebz = expand_box_by df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) @@ -176,6 +206,11 @@ class RoiSet(object): assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0'])) df['slice'] = df.apply( + lambda r: + np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], + axis=1 + ) + df['expanded_slice'] = df.apply( lambda r: np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], axis=1 @@ -185,19 +220,18 @@ class RoiSet(object): np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], axis=1 ) - df['mask'] = df.apply( - lambda r: (acc_obj_ids.data == r.label)[r.y0: r.y1, r.x0: r.x1, 0, 0], + df['binary_mask'] = df.apply( + lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], axis=1 ) return df - @staticmethod def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: query_str = 'label > 0' # always true if filters is not None: # parse filters for k, val in filters.dict(exclude_unset=True).items(): - assert k in ('area', 'solidity') + assert k in ('area') vmin = val['min'] vmax = val['max'] assert vmin >= 0 @@ -226,18 +260,18 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected - def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches - if channel: - patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) + def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack: # padded, un-annotated 2d patches + if channels and len(channels) == 1: + patches_df = self.get_patches(white_channel=channels[0], **kwargs) else: - patches_df = self.get_patches(pad_to=pad_to) - patches = list(patches_df['patch']) - return PatchStack(patches) + patches_df = self.get_patches(**kwargs) + return PatchStack(list(patches_df.patch)) - def export_annotated_zstack(self, where, prefix='zstack', **kwargs): + def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> str: annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) - success = write_accessor_data_to_file(where / (prefix + '.tif'), annotated) - return {'location': where.__str__(), 'filename': prefix + '.tif'} + fp = where / (prefix + '.tif') + write_accessor_data_to_file(fp, annotated) + return (prefix + '.tif') def get_zmask(self, mask_type='boxes'): """ @@ -271,17 +305,56 @@ class RoiSet(object): elif mask_type == 'boxes': for roi in self: - zi_st[roi.relative_slice] = 1 + zi_st[roi.slice] = True return zi_st - def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): + def classify_by( + self, name: str, channels: list[int], + object_classification_model: InstanceSegmentationModel, + derived_channel_functions: list[callable] = None + ): + """ + Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing + specified inputs through an instance segmentation classifier. Optionally derive additional inputs for object + classification by passing a raw input channel through one or more functions. + + :param name: name of column to insert + :param channels: list of nc raw input channels to send to classifier + :param object_classification_model: InstanceSegmentation model object + :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and + return a single-channel PatchStack accessor of the same shape + :return: None + """ + + raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels + if derived_channel_functions is not None: + mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)] + for fcn in derived_channel_functions: + der = fcn(raw_acc) # returns patch stack + if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']: + raise DerivedChannelError( + f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}' + ) + self.accs_derived.append(der) + + # combine channels + data_derived = [acc.data for acc in self.accs_derived] + input_acc = PatchStack( + np.concatenate( + [*mono_data, *data_derived], + axis=raw_acc._ga('C') + ) + ) + + else: + input_acc = raw_acc # do this on a patch basis, i.e. only one object per frame - obmap_patches = object_classification_model.label_instance_class( - self.get_raw_patches(channel=channel), - self.get_patch_masks() + obmap_patches = object_classification_model.label_patch_stack( + input_acc, + self.get_patch_masks_acc(expanded=False, pad_to=None) ) om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype) @@ -294,27 +367,35 @@ class RoiSet(object): mask_largest_object( obmap_patches.iat(i).data ) - )[1] + )[-1] self._df.loc[roi.Index, 'classify_by_' + name] = oc om[self.acc_obj_ids.data == roi.label] = oc self.object_class_maps[name] = InMemoryDataAccessor(om) - def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list: - patches_acc = self.get_patch_masks(pad_to=pad_to) - exported = [] - for i, roi in enumerate(self): # assumes index of patches_acc is same as dataframe - patch = patches_acc.iat_yxcz(i) + def export_dataframe(self, csv_path: Path) -> str: + csv_path.parent.mkdir(parents=True, exist_ok=True) + self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1).to_csv(csv_path, index=False) + return csv_path.name + + + def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> pd.DataFrame: + patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy() + + def _export_patch_mask(roi): + patch = InMemoryDataAccessor(roi.patch_mask) ext = 'png' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' write_accessor_data_to_file(where / fname, patch) - exported.append(fname) - return exported + return fname + + patches_df['patch_mask_path'] = patches_df.apply(_export_patch_mask, axis=1) + return patches_df - def export_patches(self, where: Path, prefix='patch', **kwargs): + def export_patches(self, where: Path, prefix='patch', **kwargs) -> pd.DataFrame: make_3d = kwargs.get('make_3d', False) - patches_df = self.get_patches(**kwargs) + patches_df = self.get_patches(**kwargs).copy() def _export_patch(roi): patch = InMemoryDataAccessor(roi.patch) @@ -326,41 +407,41 @@ class RoiSet(object): write_accessor_data_to_file(where / fname, resampled) else: write_accessor_data_to_file(where / fname, patch) + return fname - exported.append({ - 'df_index': roi.Index, - 'patch_filename': fname, - 'location': where.__str__(), - }) - - exported = [] - for roi in patches_df.itertuples(): # just used for label info - _export_patch(roi) + patches_df['patch_path'] = patches_df.apply(_export_patch, axis=1) + return patches_df - return exported - - def get_patch_masks(self, pad_to: int = 256) -> PatchStack: - patches = [] - for roi in self: - patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') - patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 + def get_patch_masks(self, pad_to: int = None, expanded: bool = False) -> pd.DataFrame: + def _make_patch_mask(roi): + if expanded: + patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') + patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255 + else: + patch = np.zeros((roi.y1 - roi.y0, roi.x1 - roi.x0, 1, 1), dtype='uint8') + patch[:, :, 0, 0] = roi.binary_mask * 255 if pad_to: patch = pad(patch, pad_to) + return patch - patches.append(patch) - return PatchStack(patches) + dfe = self._df.copy() + dfe['patch_mask'] = dfe.apply(lambda r: _make_patch_mask(r), axis=1) + return dfe + def get_patch_masks_acc(self, **kwargs) -> PatchStack: + return PatchStack(list(self.get_patch_masks(**kwargs).patch_mask)) def get_patches( self, rescale_clip: float = 0.0, - pad_to: int = 256, + pad_to: int = None, make_3d: bool = False, focus_metric: str = None, rgb_overlay_channels: list = None, rgb_overlay_weights: list = [1.0, 1.0, 1.0], white_channel: int = None, + expanded=False, **kwargs ) -> pd.DataFrame: @@ -388,7 +469,7 @@ class RoiSet(object): raw.data[:, :, ci, :] ) else: - if white_channel: # interpret as just a single channel + if white_channel is not None: # interpret as just a single channel assert white_channel < raw.chroma annotate_rgb = False for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']: @@ -407,12 +488,17 @@ class RoiSet(object): stack = raw.data def _make_patch(roi): - patch3d = stack[roi.slice] + if expanded: + patch3d = stack[roi.expanded_slice] + subpatch = patch3d[roi.relative_slice] + else: + patch3d = stack[roi.slice] + subpatch = patch3d + ph, pw, pc, pz = patch3d.shape - subpatch = patch3d[roi.relative_slice] # make a 3d patch - if make_3d: + if make_3d or not expanded: patch = patch3d # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately @@ -433,10 +519,16 @@ class RoiSet(object): assert len(patch.shape) == 4 + mask = np.zeros(patch3d.shape[0:2], dtype=bool) + if expanded: + mask[roi.relative_slice[0:2]] = roi.binary_mask + else: + mask = roi.binary_mask + if rescale_clip is not None: patch = rescale(patch, rescale_clip) - if kwargs.get('draw_bounding_box') is True: + if kwargs.get('draw_bounding_box') is True and expanded: bci = kwargs.get('bounding_box_channel', 0) assert bci < 3 if bci > 0: @@ -450,28 +542,25 @@ class RoiSet(object): if kwargs.get('draw_mask'): mci = kwargs.get('mask_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[roi.relative_slice[0:2]] = roi.mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] if kwargs.get('draw_contour'): mci = kwargs.get('contour_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[roi.relative_slice[0:2]] = roi.mask for zi in range(0, patch.shape[3]): + contours = get_safe_contours(mask) patch[:, :, mci, zi] = draw_contours_on_patch( patch[:, :, mci, zi], - find_contours(mask) + contours ) - if pad_to: + if pad_to and expanded: patch = pad(patch, pad_to) return patch - dfe = self._df - dfe['patch'] = self._df.apply(lambda r: _make_patch(r), axis=1) + dfe = self._df.copy() + dfe['patch'] = dfe.apply(lambda r: _make_patch(r), axis=1) return dfe def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict: @@ -481,12 +570,12 @@ class RoiSet(object): :param channel: color channel of products to export :param prefix: prefix of the name of each product's file or subfolder :param params: RoiSetExportParams object describing which products to export and with which parameters - :return: dict of Path objects describing the location of single-file export products + :return: nested dict of Path objects describing the location of export products """ record = {} if not self.count: return - raw_ch = self.acc_raw.get_one_channel_data(channel) + for k in params.dict().keys(): subdir = where / k pr = prefix @@ -494,37 +583,85 @@ class RoiSet(object): if kp is None: continue if k == 'patches_3d': - files = self.export_patches( + df_exp = self.export_patches( subdir, white_channel=channel, prefix=pr, make_3d=True, **kp ) + record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path] if k == 'annotated_patches_2d': - files = self.export_patches( + df_exp = self.export_patches( subdir, prefix=pr, make_3d=False, white_channel=channel, bounding_box_channel=1, bounding_box_linewidth=2, **kp, ) + record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path] if k == 'patches_2d': - files = self.export_patches( + df_exp = self.export_patches( subdir, white_channel=channel, prefix=pr, make_3d=False, **kp ) - df_patches = pd.DataFrame(files) - self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') + self._df = self._df.join(df_exp.patch_path.apply(lambda x: str(Path('patches_2d') / x))) self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1) - if k == 'patch_masks': - self.export_patch_masks(subdir, prefix=pr, **kp) + record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path] if k == 'annotated_zstacks': - self.export_annotated_zstack(subdir, prefix=pr, **kp) + record[k] = str(Path(k) / self.export_annotated_zstack(subdir, prefix=pr, **kp)) if k == 'object_classes': for kc, acc in self.object_class_maps.items(): fp = subdir / kc / (pr + '.tif') write_accessor_data_to_file(fp, acc) - record[f'{k}_{kc}'] = fp - if k == 'dataframe': - dfpa = subdir / (pr + '.csv') - dfpa.parent.mkdir(parents=True, exist_ok=True) - self._df.to_csv(dfpa, index=False) - record[k] = dfpa + record[f'{k}_{kc}'] = str(fp) + if k == 'derived_channels': + record[k] = [] + for di, dacc in enumerate(self.accs_derived): + fp = subdir / f'dc{di:01d}.tif' + fp.parent.mkdir(exist_ok=True, parents=True) + dacc.export_pyxcz(fp) + record[k].append(str(fp)) + + # export dataframe and patch masks + record = {**record, **self.serialize(where, prefix=prefix)} + return record + def serialize(self, where: Path, prefix='') -> dict: + """ + Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks + :param where: path of directory in which to write files + :param prefix: (optional) prefix + :return: nested dict of Path objects describing the locations of export products + """ + record = {} + df_exp = self.export_patch_masks( + where / 'tight_patch_masks', + prefix=prefix, + pad_to=None, + expanded=False + ) + se_pa = df_exp.patch_mask_path.apply( + lambda x: str(Path('tight_patch_masks') / x) + ).rename('tight_patch_masks_path') + self._df = self._df.join(se_pa) + df_fn = self.export_dataframe(where / 'dataframe' / (prefix + '.csv')) + record['dataframe'] = str(Path('dataframe') / df_fn) + record['tight_patch_masks'] = list(se_pa) + return record + + @staticmethod + def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''): + df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] + + id_mask = np.zeros((*acc_raw.hw, 1, acc_raw.nz), dtype='uint16') + def _label_obj(r): + sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1] + ext = 'png' + fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' + try: + ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname) + bool_mask = ma_acc.data / np.iinfo(ma_acc.data.dtype).max + id_mask[sl] = id_mask[sl] + r.label * bool_mask + except Exception as e: + raise DeserializeRoiSet(e) + + df.apply(_label_obj, axis=1) + return RoiSet(acc_raw, InMemoryDataAccessor(id_mask)) + def project_stack_from_focal_points( xx: np.ndarray, @@ -573,3 +710,12 @@ def project_stack_from_focal_points( ) + +class Error(Exception): + pass + +class DeserializeRoiSet(Error): + pass + +class DerivedChannelError(Error): + pass \ No newline at end of file diff --git a/model_server/base/workflows.py b/model_server/base/workflows.py index f3719433b727e770285a53dd3f2fb55c7517dbf7..912f6062c702794dfffc10946e777b75c0b2a1bc 100644 --- a/model_server/base/workflows.py +++ b/model_server/base/workflows.py @@ -1,6 +1,7 @@ """ Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session. """ +from collections import OrderedDict from pathlib import Path from time import perf_counter from typing import Dict @@ -14,7 +15,7 @@ class Timer(object): tfunc = perf_counter def __init__(self): - self.events = {} + self.events = OrderedDict() self.last = self.tfunc() def click(self, key): diff --git a/model_server/clients/imagej/adapter.py b/model_server/clients/imagej/adapter.py index 4f977c2c7fd48a837615b573aae847be09f9374c..b03ee25df903b439667beb06e7b61c828c0e9c23 100644 --- a/model_server/clients/imagej/adapter.py +++ b/model_server/clients/imagej/adapter.py @@ -6,18 +6,20 @@ import httplib import json import urllib +from ij import IJ from ij import ImagePlus HOST = '127.0.0.1' PORT = 6221 uri = 'http://{}:{}/'.format(HOST, PORT) -def hit_endpoint(method, endpoint, params=None): +def hit_endpoint(method, endpoint, params=None, body=None): """ Python 2.7 implementation of HTTP client :param method: (str) either 'GET' or 'PUT' :param endpoint: (str) endpoint of HTTP request - :param params: (dict) of parameters required by client request + :param params: (dict) of parameters that are embedded in client request URL + :param body: (dict) of parameters that JSON-encoded and attached as payload in request :return: (dict) of response status and content, formatted as dict if request is successful """ connection = httplib.HTTPConnection(HOST, PORT) @@ -27,7 +29,7 @@ def hit_endpoint(method, endpoint, params=None): url = endpoint + '?' + urllib.urlencode(params) else: url = endpoint - connection.request(method, url) + connection.request(method, url, body=json.dumps(body)) resp = connection.getresponse() resp_str = resp.read() try: @@ -36,6 +38,28 @@ def hit_endpoint(method, endpoint, params=None): content = {'str': str(resp_str)} return {'status': resp.status, 'content': content} + +def verify_server(popup=True): + try: + resp = hit_endpoint('GET', '/') + except Exception as e: + print(e) + msg = 'Could not find server at: ' + uri + IJ.log(msg) + if popup: + IJ.error(msg) + raise e + return False + if resp['status'] != 200: + msg = 'Unknown error verifying server at: ' + uri + if popup: + IJ.error(msg) + raise Exception(msg) + return False + else: + IJ.log('Verified server is online at: ' + uri) + return True + def run_request_sequence(imp, func, params): """ Execute a sequence of client requests in the ImageJ scripting environment diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index de25566a6571b6a07f8f6e1d453e49a1b13f9c72..2dea13329513d0256baf7e9f2c698d4731301893 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -1,3 +1,4 @@ +import json import os from pathlib import Path @@ -12,8 +13,17 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat class IlastikModel(Model): - def __init__(self, params, autoload=True): + def __init__(self, params, autoload=True, enforce_embedded=True): + """ + Base class for models that run via ilastik shell API + :param params: + project_file: path to ilastik project file + :param autoload: automatically load model into memory if true + :param enforce_embedded: + raise an error if all input data are not embedded in the project file, i.e. on the filesystem + """ self.project_file = Path(params['project_file']) + self.enforce_embedded = enforce_embedded params['project_file'] = self.project_file.__str__() if self.project_file.is_absolute(): pap = self.project_file @@ -42,14 +52,42 @@ class IlastikModel(Model): args.project = self.project_file_abspath.__str__() shell = app.main(args, init_logging=False) + # validate if inputs are embedded in project file + h5 = shell.projectManager.currentProjectFile + for lane in h5['Input Data/infos'].keys(): + for role in h5[f'Input Data/infos/{lane}'].keys(): + grp = h5[f'Input Data/infos/{lane}/{role}'] + if self.enforce_embedded and ('location' in grp.keys()) and grp['location'][()] != b'ProjectInternal': + raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem') + assert True if not isinstance(shell.workflow, self.get_workflow()): raise ParameterExpectedError( - f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' + f'Ilastik project file {self.project_file} does not describe an instance of {self.__class__}' ) self.shell = shell return True + @property + def model_shape_dict(self): + raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] + ax = raw_info['axistags'][()] + ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']] + shape = raw_info['shape'][()] + dd = dict(zip(ax_keys, shape)) + for ci in 'TCZ': + if ci not in dd.keys(): + dd[ci] = 1 + return dd + + @property + def model_chroma(self): + return self.model_shape_dict['C'] + + @property + def model_3d(self): + return self.model_shape_dict['Z'] > 1 + class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' @@ -60,7 +98,15 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): from ilastik.workflows import PixelClassificationWorkflow return PixelClassificationWorkflow - def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict): + @property + def labels(self): + h5 = self.shell.projectManager.currentProjectFile + return [l.decode() for l in h5['PixelClassification/LabelNames'][()]] + + def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict): + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') dsi = [ { @@ -78,6 +124,22 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ) return InMemoryDataAccessor(data=yxcz), {'success': True} + def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict): + """ + Iterative over a patch stack, call inference separately on each cropped patch + """ + from ilastik.applets.featureSelection.opFeatureSelection import FeatureSelectionConstraintError + + nc = len(self.labels) + data = np.zeros((img.count, *img.hw, nc, img.nz), dtype=float) # interpret as PYXCZ + for i in range(0, img.count): + sl = img.get_slice_at(i) + try: + data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True))[0].data + except FeatureSelectionConstraintError: # occurs occasionally on small patches + continue + return PatchStack(data), {'success': True} + def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs): pxmap, _ = self.infer(img) mask = pxmap.data[:, :, px_class, :] > px_prob_threshold @@ -87,22 +149,43 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): model_id = 'ilastik_object_classification_from_segmentation' + @staticmethod + def _make_8bit_mask(nda): + if nda.dtype == 'bool': + return 255 * nda.astype('uint8') + else: + return nda + @staticmethod def get_workflow(): from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + if self.model_chroma != input_img.chroma: + raise IlastikInputShapeError( + f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}' + ) + if self.model_3d != input_img.is_3d(): + if self.model_3d: + raise IlastikInputShapeError(f'Model is 3D but input image is 2D') + else: + raise IlastikInputShapeError(f'Model is 2D but input image is 3D') + assert segmentation_img.is_mask() - if segmentation_img.dtype == 'bool': - seg = 255 * segmentation_img.data.astype('uint8') + if isinstance(input_img, PatchStack): + assert isinstance(segmentation_img, PatchStack) + tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx') tagged_seg_data = vigra.taggedView( - 255 * segmentation_img.data.astype('uint8'), - 'yxcz' + self._make_8bit_mask(segmentation_img.pczyx), + 'tczyx' ) else: - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + tagged_seg_data = vigra.taggedView( + self._make_8bit_mask(segmentation_img.data), + 'yxcz' + ) dsi = [ { @@ -115,12 +198,21 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment assert len(obmaps) == 1, 'ilastik generated more than one object map' - yxcz = np.moveaxis( - obmaps[0], - [1, 2, 3, 0], - [0, 1, 2, 3] - ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + + if isinstance(input_img, PatchStack): + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) + return PatchStack(data=pyxcz), {'success': True} + else: + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=yxcz), {'success': True} def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) @@ -137,8 +229,16 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag return ObjectClassificationWorkflowPrediction def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): - tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + + if isinstance(input_img, PatchStack): + assert isinstance(pxmap_img, PatchStack) + tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx') + tagged_pxmap_data = vigra.taggedView(pxmap_img.pczyx, 'tczyx') + else: + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') dsi = [ { @@ -151,12 +251,20 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag assert len(obmaps) == 1, 'ilastik generated more than one object map' - yxcz = np.moveaxis( - obmaps[0], - [1, 2, 3, 0], - [0, 1, 2, 3] - ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + if isinstance(input_img, PatchStack): + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) + return PatchStack(data=pyxcz), {'success': True} + else: + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=yxcz), {'success': True} def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs): @@ -172,48 +280,40 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag """ if not img.shape == pxmap.shape: raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') - # TODO: check that pxmap is in-range pxch = kwargs.get('pixel_classification_channel', 0) - pxtr = kwargs('pixel_classification_threshold', 0.5) - mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) - # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) + pxtr = kwargs.get('pixel_classification_threshold', 0.5) + mask = img._derived_accessor(pxmap.get_one_channel_data(pxch).data > pxtr) obmap, _ = self.infer(img, mask) return obmap + def make_instance_segmentation_model(self, px_ch: int): + """ + Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a + second input. + :param px_ch: channel of pixel probability map to use + :return: + InstanceSegmentationModel object + """ + class _Mod(self.__class__, InstanceSegmentationModel): + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + if mask.dtype == 'bool': + norm_mask = 1.0 * mask.data + else: + norm_mask = mask.data / np.iinfo(mask.dtype).max + norm_mask_acc = mask._derived_accessor(norm_mask.astype('float32')) + return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch) + return _Mod(params={'project_file': self.project_file}) -class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): - """ - Wrap ilastik object classification for inputs comprising single-object series of raw images and binary - segmentation masks. - """ - - def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): - assert segmentation_acc.is_mask() - if not input_acc.chroma == 1: - raise InvalidInputImageError('Object classifier expects only monochrome patches') - if not input_acc.nz == 1: - raise InvalidInputImageError('Object classifier expects only 2d patches') - tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') - tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') - dsi = [ - { - 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), - } - ] - - obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] +class Error(Exception): + pass - assert len(obmaps) == 1, 'ilastik generated more than one object map' - - # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C - assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) - pyxcz = np.moveaxis( - obmaps[0], - [0, 1, 2, 3, 4], - [0, 4, 1, 2, 3] - ) +class IlastikInputEmbedding(Error): + pass - return PatchStack(data=pyxcz), {'success': True} \ No newline at end of file +class IlastikInputShapeError(Error): + """Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape""" + pass \ No newline at end of file diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 3411ea7ca158a46c40d5f94943094491ad793a41..c40679c02a95bb386aaad613c26cc29e28b48055 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -25,14 +25,10 @@ def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplica if not duplicate: existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) if existing_model_id is not None: + session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') return {'model_id': existing_model_id} - try: - result = session.load_model(model_class, {'project_file': project_file}) - except (FileNotFoundError, ParameterExpectedError): - raise HTTPException( - status_code=404, - detail=f'Could not load project file {project_file}', - ) + result = session.load_model(model_class, {'project_file': project_file}) + session.log_info(f'Loaded ilastik model {result} from {project_file}') return {'model_id': result} @router.put('/seg/load/') @@ -60,6 +56,7 @@ def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: st channel=channel, mip=mip, ) + session.log_info(f'Completed pixel and object classification of {input_filename}') except AssertionError: raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') return record \ No newline at end of file diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 32dd137683a1a65a0fcbfc48535dd8a88a221213..a1a4983efbded47e54804521c04394c55795439c 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -5,12 +5,16 @@ import unittest import numpy as np from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data -from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm +from model_server.base.models import InvalidObjectLabelsError from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels from tests.test_api import TestServerBaseClass +def _random_int(*args): + return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') + class TestIlastikPixelClassification(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) @@ -83,6 +87,66 @@ class TestIlastikPixelClassification(unittest.TestCase): self.mono_image = mono_image self.mask = mask + def test_pixel_classifier_enforces_input_shape(self): + model = ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']} + ) + self.assertEqual(model.model_chroma, 1) + self.assertEqual(model.model_3d, False) + + # correct data + self.assertIsInstance( + model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 1, 1) + ) + ), + InMemoryDataAccessor + ) + + # raise except with input of multiple channels + with self.assertRaises(ilm.IlastikInputShapeError): + mask = model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 3, 1) + ) + ) + + # raise except with input of multiple channels + with self.assertRaises(ilm.IlastikInputShapeError): + mask = model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 1, 15) + ) + ) + + def test_ilastik_infer_pxmap_from_patchstack(self): + + def _r(h): + return np.random.randint(0, 2 ** 8, size=(h, 512, 1, 1), dtype='uint8') + + acc = PatchStack([_r(256), _r(512), _r(256)]) + self.assertEqual(acc.hw, (512, 512)) + self.assertEqual(acc.iat(0, crop=True).hw, (256, 512)) + + model = ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']} + ) + + mask = model.label_patch_stack(acc) + self.assertEqual(mask.dtype, bool) + self.assertEqual(mask.chroma, 1) + self.assertEqual(mask.hw, acc.hw) + self.assertEqual(mask.nz, acc.nz) + self.assertEqual(mask.count, acc.count) + + pxmap, _ = model.infer_patch_stack(acc) + self.assertEqual(pxmap.dtype, float) + self.assertEqual(pxmap.chroma, len(model.labels)) + self.assertEqual(pxmap.hw, acc.hw) + self.assertEqual(pxmap.nz, acc.nz) + self.assertEqual(pxmap.count, acc.count) + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] @@ -97,7 +161,8 @@ class TestIlastikPixelClassification(unittest.TestCase): objmap, ) ) - self.assertEqual(objmap.data.max(), 3) + self.assertEqual(objmap.data.max(), 2) + def test_run_object_classifier_from_segmentation(self): self.test_run_pixel_classifier() @@ -113,7 +178,7 @@ class TestIlastikPixelClassification(unittest.TestCase): objmap, ) ) - self.assertEqual(objmap.data.max(), 3) + self.assertEqual(objmap.data.max(), 2) def test_ilastik_pixel_classification_as_workflow(self): result = classify_pixels( @@ -132,15 +197,15 @@ class TestIlastikOverApi(TestServerBaseClass): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = self._put( 'ilastik/seg/load/', - {'project_file': 'improper.ilp'}, + query={'project_file': 'improper.ilp'}, ) - self.assertEqual(resp_load.status_code, 404) + self.assertEqual(resp_load.status_code, 500) def test_load_ilastik_pixel_model(self): resp_load = self._put( 'ilastik/seg/load/', - {'project_file': str(ilastik_classifiers['px'])}, + query={'project_file': str(ilastik_classifiers['px'])}, ) self.assertEqual(resp_load.status_code, 200, resp_load.json()) model_id = resp_load.json()['model_id'] @@ -156,19 +221,19 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(len(resp_list_1st), 1, resp_list_1st) resp_load_2nd = self._put( 'ilastik/seg/load/', - {'project_file': str(ilastik_classifiers['px']), 'duplicate': True, }, + query={'project_file': str(ilastik_classifiers['px']), 'duplicate': True, }, ) resp_list_2nd = self._get('models').json() self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) resp_load_3rd = self._put( 'ilastik/seg/load/', - {'project_file': str(ilastik_classifiers['px']), 'duplicate': False}, + query={'project_file': str(ilastik_classifiers['px']), 'duplicate': False}, ) resp_list_3rd = self._get('models').json() self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) def test_no_duplicate_model_with_different_path_formats(self): - self._get('restart') + self._get('session/restart') resp_list_1 = self._get('models').json() self.assertEqual(len(resp_list_1), 0) ilp = ilastik_classifiers['px'] @@ -185,11 +250,11 @@ class TestIlastikOverApi(TestServerBaseClass): # load models with these paths resp1 = self._put( 'ilastik/seg/load/', - {'project_file': ilp_win, 'duplicate': False }, + query={'project_file': ilp_win, 'duplicate': False }, ) resp2 = self._put( 'ilastik/seg/load/', - {'project_file': ilp_posx, 'duplicate': False}, + query={'project_file': ilp_posx, 'duplicate': False}, ) self.assertEqual(resp1.json(), resp2.json()) @@ -202,7 +267,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_load_ilastik_pxmap_to_obj_model(self): resp_load = self._put( 'ilastik/pxmap_to_obj/load/', - {'project_file': str(ilastik_classifiers['pxmap_to_obj'])}, + query={'project_file': str(ilastik_classifiers['pxmap_to_obj'])}, ) model_id = resp_load.json()['model_id'] @@ -216,7 +281,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_load_ilastik_seg_to_obj_model(self): resp_load = self._put( 'ilastik/seg_to_obj/load/', - {'project_file': str(ilastik_classifiers['seg_to_obj'])}, + query={'project_file': str(ilastik_classifiers['seg_to_obj'])}, ) model_id = resp_load.json()['model_id'] @@ -233,10 +298,11 @@ class TestIlastikOverApi(TestServerBaseClass): resp_infer = self._put( f'workflows/segment', - {'model_id': model_id, 'input_filename': czifile['filename'], 'channel': 0}, + query={'model_id': model_id, 'input_filename': czifile['filename'], 'channel': 0}, ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + def test_ilastik_infer_px_then_ob(self): self.copy_input_file_to_server() px_model_id = self.test_load_ilastik_pixel_model() @@ -244,7 +310,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_infer = self._put( 'ilastik/pixel_then_object_classification/infer/', - { + query={ 'px_model_id': px_model_id, 'ob_model_id': ob_model_id, 'input_filename': czifile['filename'], @@ -253,6 +319,7 @@ class TestIlastikOverApi(TestServerBaseClass): ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + class TestIlastikObjectClassification(unittest.TestCase): def setUp(self): stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) @@ -269,16 +336,36 @@ class TestIlastikObjectClassification(unittest.TestCase): ) ) - self.object_classifier = ilm.PatchStackObjectClassifier( + self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel( params={'project_file': ilastik_classifiers['seg_to_obj']} ) + self.raw = self.roiset.get_patches_acc() + self.masks = self.roiset.get_patch_masks_acc() + def test_classify_patches(self): - raw_patches = self.roiset.get_raw_patches() - patch_masks = self.roiset.get_patch_masks() - res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks) - self.assertEqual(res_patches.count, self.roiset.count) - for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch - unique = np.unique(res_patches.iat(pi).data) - self.assertEqual(len(unique), 2) - self.assertEqual(unique[0], 0) + res = self.classifier.label_patch_stack(self.raw, self.masks) + self.assertEqual(res.count, self.roiset.count) + res.export_pyxcz(output_path / 'res_patches.tif') + for pi in range(0, res.count): # assert that there is only one nonzero label per patch + la, ct = np.unique(res.iat(pi).data, return_counts=True) + self.assertEqual(np.sum(ct > 1), 2) # exclude single-pixel anomaly + self.assertEqual(la[0], 0) + + def test_multiple_objects_in_patch(self): + # allow multiple labels in a classified patch + res1 = self.classifier.label_patch_stack(self.raw, self.masks, allow_multiple=True) + la1, cts1 = np.unique(res1.iat(2).data, return_counts=True) + self.assertGreater(len(la1[la1 > 0]), 1) + self.assertEqual(la1[0], 0) + self.assertTrue(np.all(cts1[1] > cts1[2:])) + + # raise exception if there are multiple labels in any patch + with self.assertRaises(InvalidObjectLabelsError): + res2 = self.classifier.label_patch_stack(self.raw, self.masks, allow_multiple=False, force_single=False) + + # convert all nonzero pixels to the label with highest occurrence + res3 = self.classifier.label_patch_stack(self.raw, self.masks, allow_multiple=False, force_single=True) + la3, cts3 = np.unique(res3.iat(2).data, return_counts=True) + self.assertEqual(len(la3[la3 > 0]), 1) + self.assertEqual(la3[1], la1[1]) \ No newline at end of file diff --git a/model_server/scripts/run_server.py b/model_server/scripts/run_server.py index b40ff2b399008f26f24d1d0ab5bf205650b03c95..2ca5e559019e3444c3f135450db749736e88bf29 100644 --- a/model_server/scripts/run_server.py +++ b/model_server/scripts/run_server.py @@ -27,6 +27,11 @@ def parse_args(): action='store_true', help='display extra information that is helpful for debugging' ) + parser.add_argument( + '--reload', + action='store_true', + help='automatically restart server when changes are noticed, for development purposes' + ) return parser.parse_args() @@ -41,8 +46,9 @@ def main(args, app_name='model_server.base.api:app') -> None: 'host': args.host, 'port': int(args.port), 'log_level': 'debug', + 'reload': args.reload, }, - daemon=True, + daemon=(args.reload is False), ) url = f'http://{args.host}:{int(args.port):04d}/status' print(url) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index bc5b4065f314eadd66692a806cb2356eccdc28e2..11eebf8f7034c252bd75e247d65aac18e278966a 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -7,6 +7,9 @@ from model_server.base.accessors import PatchStack, make_patch_stack_from_file, from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor +def _random_int(*args): + return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') + class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: @@ -23,6 +26,7 @@ class TestCziImageFileAccess(unittest.TestCase): self.assertEqual(len(tf.data.shape), 4) self.assertEqual(tf.shape[0], tifffile['h']) self.assertEqual(tf.shape[1], tifffile['w']) + self.assertEqual(tf.get_axis('x'), 1) def test_czifile_is_correct_shape(self): cf = CziImageFileAccessor(czifile['path']) @@ -40,7 +44,7 @@ class TestCziImageFileAccess(unittest.TestCase): nc = 4 nz = 11 c = 3 - cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) + cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) sc = cf.get_one_channel_data(c) self.assertEqual(sc.shape, (h, w, 1, nz)) @@ -70,7 +74,7 @@ class TestCziImageFileAccess(unittest.TestCase): def test_conform_data_shorter_than_xycz(self): h = 256 w = 512 - data = np.random.rand(h, w, 1) + data = _random_int(h, w, 1) acc = InMemoryDataAccessor(data) self.assertEqual( InMemoryDataAccessor.conform_data(data).shape, @@ -82,7 +86,7 @@ class TestCziImageFileAccess(unittest.TestCase): ) def test_conform_data_longer_than_xycz(self): - data = np.random.rand(256, 512, 12, 8, 3) + data = _random_int(256, 512, 12, 8, 3) with self.assertRaises(DataShapeError): acc = InMemoryDataAccessor(data) @@ -93,7 +97,7 @@ class TestCziImageFileAccess(unittest.TestCase): c = 3 nz = 10 - yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8') + yxcz = _random_int(h, w, c, nz) acc = InMemoryDataAccessor(yxcz) fp = output_path / f'rand3d.tif' self.assertTrue( @@ -138,16 +142,18 @@ class TestPatchStackAccessor(unittest.TestCase): w = 256 h = 512 n = 4 - acc = PatchStack(np.random.rand(n, h, w, 1, 1)) + acc = PatchStack(_random_int(n, h, w, 1, 1)) self.assertEqual(acc.count, n) self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) + self.assertEqual(acc.shape[1:], acc.iat(0, crop=True).shape) + def test_make_patch_stack_from_list(self): w = 256 h = 512 n = 4 - acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)]) + acc = PatchStack([_random_int(h, w, 1, 1) for _ in range(0, n)]) self.assertEqual(acc.count, n) self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) @@ -176,8 +182,8 @@ class TestPatchStackAccessor(unittest.TestCase): nz = 5 n = 4 - patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)] - patches.append(np.random.rand(h, 2 * w, c, nz)) + patches = [_random_int(h, w, c, nz) for _ in range(0, n)] + patches.append(_random_int(h, 2 * w, c, nz)) acc = PatchStack(patches) self.assertEqual(acc.count, n + 1) self.assertEqual(acc.hw, (h, 2 * w)) @@ -185,13 +191,70 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(acc.iat(0).shape, (h, 2 * w, c, nz)) self.assertEqual(acc.iat_yxcz(0).shape, (h, 2 * w, c, nz)) + # test that initial patches are maintained + for i in range(0, acc.count): + self.assertEqual(patches[i].shape, acc.iat(i, crop=True).shape) + self.assertEqual(acc.shape[1:], acc.iat(i, crop=False).shape) + + def test_make_3d_patch_stack_from_list_force_long_dim(self): + def _r(h, w): + return np.random.randint(0, 2 ** 8, size=(h, w, 1, 1), dtype='uint8') + patches = [_r(256, 128), _r(128, 256), _r(512, 10), _r(10, 512)] + + acc_ref = PatchStack(patches, force_ydim_longest=False) + self.assertEqual(acc_ref.hw, (512, 512)) + self.assertEqual(acc_ref.iat(-1, crop=False).hw, (512, 512)) + self.assertEqual(acc_ref.iat(-1, crop=True).hw, (10, 512)) + + acc_rot = PatchStack(patches, force_ydim_longest=True) + self.assertEqual(acc_rot.hw, (512, 128)) + self.assertEqual(acc_rot.iat(-1, crop=False).hw, (512, 128)) + self.assertEqual(acc_rot.iat(-1, crop=True).hw, (512, 10)) + + nda_rot_rot = np.rot90(acc_rot.iat(-1, crop=True).data, axes=(1, 0)) + nda_ref = acc_ref.iat(-1, crop=True).data + self.assertTrue(np.all(nda_ref == nda_rot_rot)) + + self.assertLess(acc_rot.data.size, acc_ref.data.size) + def test_pczyx(self): w = 256 h = 512 n = 4 nz = 15 - nc = 2 - acc = PatchStack(np.random.rand(n, h, w, nc, nz)) + nc = 3 + acc = PatchStack(_random_int(n, h, w, nc, nz)) self.assertEqual(acc.count, n) self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w)) self.assertEqual(acc.hw, (h, w)) + return acc + + def test_get_one_channel(self): + acc = self.test_pczyx() + mono = acc.get_one_channel_data(channel=1) + for a in 'PXYZ': + self.assertEqual(mono.shape_dict[a], acc.shape_dict[a]) + self.assertEqual(mono.shape_dict['C'], 1) + + def test_get_multiple_channels(self): + acc = self.test_pczyx() + channels = [0, 1] + mcacc = acc.get_channels(channels=channels) + for a in 'PXYZ': + self.assertEqual(mcacc.shape_dict[a], acc.shape_dict[a]) + self.assertEqual(mcacc.shape_dict['C'], len(channels)) + + def test_get_one_channel_mip(self): + acc = self.test_pczyx() + mono_mip = acc.get_one_channel_data(channel=1, mip=True) + for a in 'PXY': + self.assertEqual(mono_mip.shape_dict[a], acc.shape_dict[a]) + for a in 'CZ': + self.assertEqual(mono_mip.shape_dict[a], 1) + + def test_export_pczyx_patch_hyperstack(self): + acc = self.test_pczyx() + fp = output_path / 'patch_hyperstack.tif' + acc.export_pyxcz(fp) + acc2 = make_patch_stack_from_file(fp) + self.assertEqual(acc.shape, acc2.shape) \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 1b201440da4504e16f10573d31747603926e3b17..13c62c36466ad86ec220199fef95b7a46ece11b9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,3 +1,5 @@ +import json + from multiprocessing import Process from pathlib import Path import requests @@ -36,8 +38,12 @@ class TestServerBaseClass(unittest.TestCase): def _get(self, endpoint): return self._get_sesh().get(self.uri + endpoint) - def _put(self, endpoint, params=None): - return self._get_sesh().put(self.uri + endpoint, params=params) + def _put(self, endpoint, query=None, body=None): + return self._get_sesh().put( + self.uri + endpoint, + params=query, + data=json.dumps(body) + ) def copy_input_file_to_server(self): from shutil import copyfile @@ -59,10 +65,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.status_code, 200) def test_bounceback_parameters(self): - resp = self._put('bounce_back', {'par1': 'hello'}) + resp = self._put('bounce_back', body={'par1': 'hello', 'par2': ['ab', 'cd']}) self.assertEqual(resp.status_code, 200, resp.json()) self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json()) - self.assertEqual(resp.json()['params']['par2'], None, resp.json()) + self.assertEqual(resp.json()['params']['par2'], ['ab', 'cd'], resp.json()) def test_default_session_paths(self): import model_server.conf.defaults @@ -103,7 +109,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp = self._put( f'infer/from_image_file', - {'model_id': model_id, 'input_filename': 'not_a_real_file.name'} + query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'} ) self.assertEqual(resp.status_code, 404, resp.content.decode()) @@ -112,7 +118,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): model_id = 'not_a_real_model' resp = self._put( f'workflows/segment', - {'model_id': model_id, 'input_filename': 'not_a_real_file.name'} + query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'} ) self.assertEqual(resp.status_code, 409, resp.content.decode()) @@ -121,7 +127,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.copy_input_file_to_server() resp_infer = self._put( f'workflows/segment', - { + query={ 'model_id': model_id, 'input_filename': czifile['filename'], 'channel': 2, @@ -145,7 +151,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp_inpath = self._get('paths') resp_change = self._put( f'paths/watch_output', - {'path': resp_inpath.json()['inbound_images']} + query={'path': resp_inpath.json()['inbound_images']} ) self.assertEqual(resp_change.status_code, 200) resp_check = self._get('paths') @@ -156,7 +162,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): fakepath = 'c:/fake/path/to/nowhere' resp_change = self._put( f'paths/watch_output', - {'path': fakepath} + query={'path': fakepath} ) self.assertEqual(resp_change.status_code, 404) self.assertIn(fakepath, resp_change.json()['detail']) @@ -167,7 +173,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp_inpath = self._get('paths') resp_change = self._put( f'paths/watch_output', - {'path': resp_inpath.json()['outbound_images']} + query={'path': resp_inpath.json()['outbound_images']} ) self.assertEqual(resp_change.status_code, 200) resp_check = self._get('paths') diff --git a/tests/test_process.py b/tests/test_process.py index 569ac1b797437381df4d80e3da749447ea52e91f..d2fb33b9cc9a8b6af04c8eacd99b6b87d7d64527 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,9 +1,10 @@ import unittest import numpy as np +from skimage.measure import find_contours -from model_server.base.process import mask_largest_object -from model_server.base.process import pad +from model_server.base.annotators import draw_contours_on_patch +from model_server.base.process import get_safe_contours, mask_largest_object, pad class TestProcessingUtilityMethods(unittest.TestCase): def setUp(self) -> None: @@ -56,3 +57,26 @@ class TestMaskLargestObject(unittest.TestCase): self.assertTrue(np.all(np.unique(masked) == [0, 255])) self.assertTrue(np.all(masked[:, 3:5] == 0)) self.assertTrue(np.all(masked[3:5, :] == 0)) + + +class TestSafeContours(unittest.TestCase): + def setUp(self) -> None: + self.patch = np.ones((10, 20), dtype='uint8') + self.mask_ref = np.zeros((10, 20), dtype=bool) + self.mask_ref[0:5, 0:10] = True + self.mask_test = np.ones((1, 20), dtype=bool) + + def test_contours_on_compliant_mask(self): + con = get_safe_contours(self.mask_ref) + patch = self.patch.copy() + self.assertEqual((patch == 0).sum(), 0) + patch = draw_contours_on_patch(patch, con) + self.assertEqual((patch == 0).sum(), 14) + + def test_contours_on_noncompliant_mask(self): + con = get_safe_contours(self.mask_test) + patch = self.patch.copy() + self.assertEqual((patch == 0).sum(), 0) + patch = draw_contours_on_patch(self.patch, con) + self.assertEqual((patch == 0).sum(), 20) + self.assertEqual((patch[0, :] == 0).sum(), 20) \ No newline at end of file diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 3040e6d251305731d0c3dcad5b3727a1bb48b614..4785688835c71d1f26ae1b95362b1d3403d83262 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -1,13 +1,17 @@ +import os +import re import unittest import numpy as np from pathlib import Path +import pandas as pd + from model_server.conf.testing import output_path, roiset_test_data from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams from model_server.base.roiset import _get_label_ids, RoiSet -from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack from model_server.base.models import DummyInstanceSegmentationModel class BaseTestRoiSetMonoProducts(object): @@ -29,31 +33,36 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): params=RoiSetMetaParams( mask_type=mask_type, filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), - expand_box_by=(64, 2) + expand_box_by=(128, 2) ) ) return roiset def test_roi_mask_shape(self, **kwargs): roiset = self._make_roi_set(**kwargs) + + # all masks' bounding boxes are at least as big as ROI area + for roi in roiset.get_df().itertuples(): + self.assertEqual(roi.binary_mask.dtype, 'bool') + sh = roi.binary_mask.shape + self.assertEqual(sh, (roi.h, roi.w)) + self.assertGreaterEqual(sh[0] * sh[1], roi.area) + + def test_roi_zmask(self, **kwargs): + roiset = self._make_roi_set(**kwargs) zmask = roiset.get_zmask() zmask_acc = InMemoryDataAccessor(zmask) self.assertTrue(zmask_acc.is_mask()) # assert dimensionality of zmask - self.assertGreater(zmask_acc.shape_dict['Z'], 1) - self.assertEqual(zmask_acc.shape_dict['C'], 1) + self.assertEqual(zmask_acc.nz, roiset.acc_raw.nz) + self.assertEqual(zmask_acc.chroma, 1) write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc) # mask values are not just all True or all False self.assertTrue(np.any(zmask)) self.assertFalse(np.all(zmask)) - # assert non-trivial meta info in boxes - self.assertGreater(roiset.count, 1) - sh = roiset.get_df().iloc[1]['mask'].shape - ar = roiset.get_df().iloc[1]['area'] - self.assertGreaterEqual(sh[0] * sh[1], ar) def test_roiset_from_non_zstacks(self, **kwargs): acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) @@ -73,37 +82,79 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) + def test_dataframe_and_mask_array_in_iterator(self): + roiset = self._make_roi_set() + for roi in roiset: + ma = roi.binary_mask + self.assertEqual(ma.dtype, 'bool') + self.assertEqual(ma.shape, (roi.h, roi.w)) + def test_rel_slices_are_valid(self): roiset = self._make_roi_set() for roi in roiset: - ebb = roiset.acc_raw.data[roi.slice] + ebb = roiset.acc_raw.data[roi.expanded_slice] self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) rbb = ebb[roi.relative_slice] self.assertEqual(len(rbb.shape), 4) self.assertTrue(np.all([si >= 1 for si in rbb.shape])) - def test_make_2d_patches(self): + + def test_make_expanded_2d_patches(self): roiset = self._make_roi_set() - files = roiset.export_patches( - output_path / '2d_patches', + where = output_path / 'expanded_2d_patches' + df_res = roiset.export_patches( + where, draw_bounding_box=True, + expanded=True, + pad_to=256, ) - self.assertGreaterEqual(len(files), 1) - - def test_make_3d_patches(self): + df = roiset.get_df() + for f in df_res.patch_path: + acc = generate_file_accessor(where / f) + la = int(re.search(r'la([\d]+)', str(f)).group(1)) + roi_q = df.loc[df.label == la, :] + self.assertEqual(len(roi_q), 1) + self.assertEqual((256, 256), acc.hw) + + def test_make_tight_2d_patches(self): + roiset = self._make_roi_set() + where = output_path / 'tight_2d_patches' + df_res = roiset.export_patches( + where, + draw_bounding_box=True, + expanded=False + ) + df = roiset.get_df() + for f in df_res.patch_path: # all exported files are same shape as bounding boxes in RoiSet's datatable + acc = generate_file_accessor(where / f) + la = int(re.search(r'la([\d]+)', str(f)).group(1)) + roi_q = df.loc[df.label == la, :] + self.assertEqual(len(roi_q), 1) + roi = roi_q.iloc[0] + self.assertEqual((roi.h, roi.w), acc.hw) + + def test_make_expanded_3d_patches(self): roiset = self._make_roi_set() - files = roiset.export_patches( - output_path / '3d_patches', - make_3d=True) - self.assertGreaterEqual(len(files), 1) + where = output_path / '3d_patches' + df_res = roiset.export_patches( + where, + make_3d=True, + expanded=True + ) + self.assertGreaterEqual(len(df_res), 1) + for f in df_res.patch_path: + acc = generate_file_accessor(where / f) + self.assertGreater(acc.nz, 1) + def test_export_annotated_zstack(self): roiset = self._make_roi_set() + where = output_path / 'annotated_zstack' file = roiset.export_annotated_zstack( - output_path / 'annotated_zstack', + where, ) - result = generate_file_accessor(Path(file['location']) / file['filename']) + result = generate_file_accessor(where / file) self.assertEqual(result.shape, roiset.acc_raw.shape) def test_flatten_image(self): @@ -131,16 +182,58 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_binary_masks(self): roiset = self._make_roi_set() - files = roiset.export_patch_masks(output_path / '2d_mask_patches', ) - self.assertGreaterEqual(len(files), 1) + df_res = roiset.export_patch_masks(output_path / '2d_mask_patches', ) + + df = roiset.get_df() + for f in df_res.patch_mask_path: # all exported files are same shape as bounding boxes in RoiSet's datatable + acc = generate_file_accessor(output_path / '2d_mask_patches' / f) + la = int(re.search(r'la([\d]+)', str(f)).group(1)) + roi_q = df.loc[df.label == la, :] + self.assertEqual(len(roi_q), 1) + roi = roi_q.iloc[0] + self.assertEqual((roi.h, roi.w), acc.hw) def test_classify_by(self): roiset = self._make_roi_set() - roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel()) + roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) return roiset + def test_classify_by_multiple_channels(self): + roiset = self._make_roi_set() + roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) + self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) + self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) + return roiset + + def test_classify_by_with_derived_channel(self): + class ModelWithDerivedInputs(DummyInstanceSegmentationModel): + def infer(self, img, mask): + return PatchStack(super().infer(img, mask).data * img.chroma) + + roiset = self._make_roi_set() + roiset.classify_by( + 'multiple_input_model', + [0, 1], + ModelWithDerivedInputs(), + derived_channel_functions=[ + lambda acc: PatchStack(2 * acc.data), + lambda acc: PatchStack((0.5 * acc.data).astype('uint8')) + ] + ) + self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [3])) + self.assertTrue(all(np.unique(roiset.object_class_maps['multiple_input_model'].data) == [0, 3])) + + self.assertEqual(len(roiset.accs_derived), 2) + for di in roiset.accs_derived: + self.assertEqual(roiset.get_patches_acc().shape, di.shape) + + dpas = roiset.run_exports(output_path / 'derived_channels', 0, 'der', RoiSetExportParams(derived_channels=True)) + for fp in dpas['derived_channels']: + assert Path(fp).exists() + return roiset + def test_export_object_classes(self): record = self.test_classify_by().run_exports( output_path / 'object_class_maps', @@ -156,17 +249,21 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_raw_patches_are_correct_shape(self): roiset = self._make_roi_set() - patches = roiset.get_raw_patches() + patches = roiset.get_patches_acc() np, h, w, nc, nz = patches.shape self.assertEqual(np, roiset.count) self.assertEqual(nc, roiset.acc_raw.chroma) + self.assertEqual(nz, 1) def test_patch_masks_are_correct_shape(self): roiset = self._make_roi_set() - patch_masks = roiset.get_patch_masks() - np, h, w, nc, nz = patch_masks.shape - self.assertEqual(np, roiset.count) - self.assertEqual(nc, 1) + df_patch_masks = roiset.get_patch_masks() + for roi in df_patch_masks.itertuples(): + h, w, nc, nz = roi.patch_mask.shape + self.assertEqual(nc, 1) + self.assertEqual(nz, 1) + self.assertEqual(h, roi.h) + self.assertEqual(w, roi.w) class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): @@ -185,71 +282,307 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ) def test_multichannel_to_mono_2d_patches(self): - files = self.roiset.export_patches( - output_path / 'multichannel' / 'mono_2d_patches', - white_channel=3, + where = output_path / 'multichannel' / 'mono_2d_patches' + df_res = self.roiset.export_patches( + where, + white_channel=0, draw_bounding_box=True, + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(where / df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 1) def test_multichannnel_to_mono_2d_patches_rgb_bbox(self): - files = self.roiset.export_patches( - output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox', + where = output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox' + df_res = self.roiset.export_patches( + where, white_channel=3, draw_bounding_box=True, bounding_box_channel=1, + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(where / df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_bbox(self): - files = self.roiset.export_patches( - output_path / 'multichannel' / 'rgb_2d_patches_bbox', + where = output_path / 'multichannel' / 'rgb_2d_patches_bbox' + df_res = self.roiset.export_patches( + where, + white_channel=4, + rgb_overlay_channels=(3, None, None), + draw_mask=False, + draw_bounding_box=True, + bounding_box_channel=1, + rgb_overlay_weights=(0.1, 1.0, 1.0), + expanded=True, + pad_to=256, + ) + result = generate_file_accessor(where / df_res.patch_path.iloc[0]) + self.assertEqual(result.chroma, 3) + + def test_multichannnel_to_rgb_2d_patches_mask(self): + where = output_path / 'multichannel' / 'rgb_2d_patches_mask' + df_res = self.roiset.export_patches( + where, white_channel=4, rgb_overlay_channels=(3, None, None), draw_mask=True, mask_channel=0, - rgb_overlay_weights=(0.1, 1.0, 1.0) + rgb_overlay_weights=(0.1, 1.0, 1.0), + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(where / df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_contour(self): - files = self.roiset.export_patches( - output_path / 'multichannel' / 'rgb_2d_patches_contour', + where = output_path / 'multichannel' / 'rgb_2d_patches_contour' + df_res = self.roiset.export_patches( + where, rgb_overlay_channels=(3, None, None), draw_contour=True, contour_channel=1, - rgb_overlay_weights=(0.1, 1.0, 1.0) + rgb_overlay_weights=(0.1, 1.0, 1.0), + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(where / df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 3) self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black def test_multichannel_to_multichannel_tif_patches(self): - files = self.roiset.export_patches( - output_path / 'multichannel' / 'multichannel_tif_patches', + where = output_path / 'multichannel' / 'multichannel_tif_patches' + df_res = self.roiset.export_patches( + where, + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(where / df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 5) + self.assertEqual(result.nz, 1) def test_multichannel_annotated_zstack(self): + where = output_path / 'multichannel' / 'annotated_zstack' file = self.roiset.export_annotated_zstack( - output_path / 'multichannel' / 'annotated_zstack', + where, 'test_multichannel_annotated_zstack', + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(file['location']) / file['filename']) + result = generate_file_accessor(where / file) self.assertEqual(result.chroma, self.stack.chroma) self.assertEqual(result.nz, self.stack.nz) def test_export_single_channel_annotated_zstack(self): + where = output_path / 'annotated_zstack' file = self.roiset.export_annotated_zstack( - output_path / 'annotated_zstack', + where, channel=3, + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(file['location']) / file['filename']) + result = generate_file_accessor(where / file) self.assertEqual(result.hw, self.roiset.acc_raw.hw) self.assertEqual(result.nz, self.roiset.acc_raw.nz) self.assertEqual(result.chroma, 1) + def test_run_exports(self): + p = RoiSetExportParams(**{ + 'patches_3d': {}, + 'annotated_patches_2d': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': [3, None, None], + 'rgb_overlay_weights': [0.2, 1.0, 1.0], + 'pad_to': 512, + }, + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + }, + 'patch_masks': { + 'pad_to': 256, + }, + 'annotated_zstacks': {}, + 'object_classes': True, + 'dataframe': True, + }) + + where = output_path / 'run_exports' + res = self.roiset.run_exports( + where, + channel=3, + prefix='test', + params=p + ) + + # test on return paths + for k, v in res.items(): + if isinstance(v, list): + for f in v: + self.assertFalse(Path(f).is_absolute()) + self.assertTrue((where / f).exists()) + else: + self.assertFalse(Path(v).is_absolute()) + self.assertTrue((where / v).exists()) + + # test on paths in CSV + test_df = pd.read_csv(where / res['dataframe']) + for c in ['tight_patch_masks_path', 'patch_path']: + self.assertTrue(c in test_df.columns) + for f in test_df[c]: + self.assertTrue((where / f).exists(), where / f) + + def test_run_export_expanded_2d_patch(self): + p = RoiSetExportParams(**{ + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + 'expanded': True, + 'pad_to': 256, + }, + }) + self.assertTrue(hasattr(p.patches_2d, 'pad_to')) + self.assertTrue(hasattr(p.patches_2d, 'expanded')) + + where = output_path / 'run_exports_expanded_2d_patch' + res = self.roiset.run_exports( + where, + channel=-1, + prefix='test', + params=p + ) + + # test that exported patches are padded dimension + for fn in res['patches_2d']: + pa = where / fn + self.assertTrue(pa.exists()) + pacc = generate_file_accessor(pa) + self.assertEqual(pacc.hw, (256, 256)) + print('res') + + def test_run_export_mono_2d_patch(self): + p = RoiSetExportParams(**{ + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + 'expanded': True, + 'pad_to': 256, + 'rgb_overlay_channels': None, + }, + }) + self.assertTrue(hasattr(p.patches_2d, 'pad_to')) + self.assertTrue(hasattr(p.patches_2d, 'expanded')) + + where = output_path / 'run_exports_mono_2d_patch' + res = self.roiset.run_exports( + where, + channel=-1, + prefix='test', + params=p + ) + + # test that exported patches are padded dimension + for fn in res['patches_2d']: + pa = where / fn + self.assertTrue(pa.exists()) + pacc = generate_file_accessor(pa) + self.assertEqual(pacc.chroma, 1) + print('res') + + + +class TestRoiSetSerialization(unittest.TestCase): + + def setUp(self) -> None: + # set up test raw data and segmentation from file + self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['segmentation_channel']) + self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d']) + + @staticmethod + def _label_is_2d(id_map, la): # single label's zmask has same counts as its MIP + mask_3d = (id_map == la) + mask_mip = mask_3d.max(axis=-1) + return mask_3d.sum() == mask_mip.sum() + + def test_id_map_connects_z(self): + id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True) + labels = np.unique(id_map.data)[1:] + is_2d = all([self._label_is_2d(id_map.data, la) for la in labels]) + self.assertFalse(is_2d) + + def test_id_map_disconnects_z(self): + id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) + labels = np.unique(id_map.data)[1:] + is_2d = all([self._label_is_2d(id_map.data, la) for la in labels]) + self.assertTrue(is_2d) + + def test_create_roiset_from_3d_obj_ids(self): + id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) + self.assertEqual(self.stack_ch_pa.shape, id_map.shape) + + roiset = RoiSet( + self.stack_ch_pa, + id_map, + params=RoiSetMetaParams(mask_type='contours') + ) + self.assertEqual(roiset.count, id_map.data.max()) + self.assertGreater(len(roiset.get_df()['zi'].unique()), 1) + + def test_create_roiset_from_2d_obj_ids(self): + id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False) + self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3]) + self.assertEqual(id_map.nz, 1) + + roiset = RoiSet( + self.stack_ch_pa, + id_map, + params=RoiSetMetaParams(mask_type='contours') + ) + self.assertEqual(roiset.count, id_map.data.max()) + self.assertGreater(len(roiset.get_df()['zi'].unique()), 1) + return roiset + + def test_create_roiset_from_df_and_patch_masks(self): + ref_roiset = self.test_create_roiset_from_2d_obj_ids() + where_ser = output_path / 'serialize' + ref_roiset.serialize(where_ser, prefix='ref') + where_df = where_ser / 'dataframe' / 'ref.csv' + self.assertTrue(where_df.exists()) + df_test = pd.read_csv(where_df) + + # check that patches are correct size + where_patch_masks = where_ser / 'tight_patch_masks' + patch_filenames = [] + for pmf in where_patch_masks.iterdir(): + self.assertTrue(pmf.suffix.upper() == '.PNG') + la = int(re.search(r'la([\d]+)', str(pmf)).group(1)) + roi_q = df_test.loc[df_test.label == la, :] + self.assertEqual(len(roi_q), 1) + roi = roi_q.iloc[0] + m_acc = generate_file_accessor(pmf) + self.assertEqual((roi.h, roi.w), m_acc.hw) + patch_filenames.append(pmf.name) + + # make another RoiSet from just the data table, raw images, and (tight) patch masks + test_roiset = RoiSet.deserialize(self.stack_ch_pa, where_ser, prefix='ref') + self.assertEqual(ref_roiset.get_zmask().shape, test_roiset.get_zmask().shape,) + self.assertTrue((ref_roiset.get_zmask() == test_roiset.get_zmask()).all()) + self.assertTrue(np.all(test_roiset.get_df().label == ref_roiset.get_df().label)) + cols = ['label', 'y1', 'y0', 'x1', 'x0', 'zi'] + self.assertTrue((test_roiset.get_df()[cols] == ref_roiset.get_df()[cols]).all().all()) + + # re-serialize and check that patch masks are the same + where_dser = output_path / 'deserialize' + test_roiset.serialize(where_dser, prefix='test') + for fr in patch_filenames: + pr = (where_ser / 'tight_patch_masks' / fr) + self.assertTrue(pr.exists()) + pt = (where_dser / 'tight_patch_masks' / fr.replace('ref', 'test')) + self.assertTrue(pt.exists()) + r_acc = generate_file_accessor(pr) + t_acc = generate_file_accessor(pt) + self.assertTrue(np.all(r_acc.data == t_acc.data)) +