diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 712bbacf19296201775e1d7d49ef45b356ac1d1d..4b91eb5cda12dc79c3b1bd300c5630d1370d2c91 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -3,11 +3,12 @@ import os from pathlib import Path import numpy as np -from skimage.io import imread +from skimage.io import imread, imsave import czifile import tifffile +from base.process import make_rgb from model_server.base.process import is_mask class GenericImageDataAccessor(ABC): @@ -183,29 +184,57 @@ class CziImageFileAccessor(GenericImageFileAccessor): return sc -def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor, mkdir=True) -> bool: +def write_accessor_data_to_file(fpath: Path, acc: GenericImageDataAccessor, mkdir=True) -> bool: + """ + Export an image accessor to file. + :param fpath: complete path including filename and extension + :param acc: image accessor to be written + :param mkdir: create any needed subdirectories in fpath if True + :return: True + """ + if 'P' in acc.shape_dict.keys(): + raise FileWriteError(f'Can only write single-position accessor to file') + ext = fpath.suffix.upper() + if mkdir: fpath.parent.mkdir(parents=True, exist_ok=True) - try: + + if ext == '.PNG': + if acc.dtype != 'uint8': + raise FileWriteError(f'Invalid data type {acc.dtype}') + if acc.chroma == 1: + data = acc.data[:, :, 0, 0] + elif acc.chroma == 2: # add a blank blue channel + data = make_rgb(acc.data) + else: # preserve RGB order + data = acc.data[:, :, :, 0] + imsave(fpath, data, check_contrast=False) + return True + + elif ext in ['.TIF', '.TIFF']: zcyx= np.moveaxis( - accessor.data, # yxcz + acc.data, # yxcz [3, 2, 0, 1], [0, 1, 2, 3] ) - if accessor.is_mask(): - if accessor.dtype == 'bool': + if acc.is_mask(): + if acc.dtype == 'bool': data = (zcyx * 255).astype('uint8') else: data = zcyx.astype('uint8') tifffile.imwrite(fpath, data, imagej=True) else: tifffile.imwrite(fpath, zcyx, imagej=True) - except: - raise FileWriteError(f'Unable to write data to file') + else: + raise FileWriteError(f'Unable to write data to file of extension {ext}') return True def generate_file_accessor(fpath): + """ + Given an image file path, return an image accessor, assuming the file is a supported format and represents + a single position array, which may be single or multi-channel, single plane or z-stack. + """ if str(fpath).upper().endswith('.TIF') or str(fpath).upper().endswith('.TIFF'): return TifSingleSeriesFileAccessor(fpath) elif str(fpath).upper().endswith('.CZI'): @@ -215,6 +244,86 @@ def generate_file_accessor(fpath): else: raise FileAccessorError(f'Could not match a file accessor with {fpath}') + +class PatchStack(InMemoryDataAccessor): + + def __init__(self, data): + """ + A sequence of n (generally) color 3D images of the same size + :param data: either a list of np.ndarrays of size YXCZ, or np.ndarray of size PYXCZ + """ + + if isinstance(data, list): # list of YXCZ patches + n = len(data) + yxcz_shape = np.array([e.shape for e in data]).max(axis=0) + nda = np.zeros( + (n, *yxcz_shape), dtype=data[0].dtype + ) + for i in range(0, len(data)): + s = tuple([slice(0, c) for c in data[i].shape]) + nda[i][s] = data[i] + + elif isinstance(data, np.ndarray) and len(data.shape) == 5: # interpret as PYXCZ + nda = data + else: + raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}') + + assert nda.ndim == 5 + self._data = nda + + def iat(self, i): + return InMemoryDataAccessor(self.data[i, :, :, :, :]) + + def iat_yxcz(self, i): + return self.iat(i) + + @property + def count(self): + return self.shape_dict['P'] + + @property + def shape_dict(self): + return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) + + def get_list(self): + n = self.nz + return [self.data[:, :, 0, zi] for zi in range(0, n)] + + @property + def pyxcz(self): + return self.data + + @property + def pczyx(self): + return np.moveaxis( + self.data, + [0, 3, 4, 1, 2], + [0, 1, 2, 3, 4] + ) + + @property + def shape(self): + return self.data.shape + + @property + def shape_dict(self): + return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) + + +def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position + if not Path(fpath).exists(): + raise FileNotFoundError(f'Could not find {fpath}') + + pyxc = np.moveaxis( + generate_file_accessor(fpath).data, # yxcz + [0, 1, 2, 3], + [1, 2, 3, 0] + ) + pyxcz = np.expand_dims(pyxc, axis=3) + return PatchStack(pyxcz) + + + class Error(Exception): pass @@ -234,4 +343,10 @@ class InvalidAxisKey(Error): pass class InvalidDataShape(Error): - pass \ No newline at end of file + pass + +class InvalidDataForPatchStackError(Error): + pass + + + diff --git a/model_server/base/annotators.py b/model_server/base/annotators.py new file mode 100644 index 0000000000000000000000000000000000000000..b2df418d46fc5801f39e34f4512863dee69e7cfb --- /dev/null +++ b/model_server/base/annotators.py @@ -0,0 +1,52 @@ +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from model_server.base.process import rescale + +def draw_boxes_on_3d_image(roiset, draw_full_depth=False, **kwargs): + _, _, chroma, nz = roiset.acc_raw.shape + font_size = kwargs.get('font_size', 18) + linewidth = kwargs.get('linewidth', 4) + + annotated = np.zeros(roiset.acc_raw.shape, dtype=roiset.acc_raw.dtype) + + for zi in range(0, nz): + if draw_full_depth: + subset = roiset.get_df() + else: + subset = roiset.get_df().query(f'zi == {zi}') + for c in range(0, chroma): + pilimg = Image.fromarray(roiset.acc_raw.data[:, :, c, zi]) + draw = ImageDraw.Draw(pilimg) + draw.font = ImageFont.truetype(font="arial.ttf", size=font_size) + + for roi in subset.itertuples('Roi'): + xm = round((roi.x0 + roi.x1) / 2) + draw.rectangle([(roi.x0, roi.y0), (roi.x1, roi.y1)], outline='white', width=linewidth) + + if kwargs.get('draw_label') is True: + draw.text((xm, roi.y0), f'{roi.label:04d}', fill='white', anchor='mb') + annotated[:, :, c, zi] = pilimg + + + if clip := kwargs.get('rescale_clip'): + assert clip >= 0.0 and clip <= 1.0 + annotated = rescale(annotated, clip=clip) + + return annotated + +def draw_box_on_patch(patch, bbox, linewidth=1): + assert len(patch.shape) == 2 + ((x0, y0), (x1, y1)) = bbox + pilimg = Image.fromarray(patch) # drawing modifies array in-place + draw = ImageDraw.Draw(pilimg) + draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth) + return np.array(pilimg) + +def draw_contours_on_patch(patch, contours, linewidth=1): + assert len(patch.shape) == 2 + pilimg = Image.fromarray(patch) # drawing modifies array in-place + draw = ImageDraw.Draw(pilimg) + for co in contours: + draw.line([(p[1], p[0]) for p in co], width=linewidth, joint='curve') + return np.array(pilimg) \ No newline at end of file diff --git a/model_server/base/api.py b/model_server/base/api.py index 613b05cb647d2b6eaddf129d98665008b4cc906e..2ab37ad3da614c3a54ec1c01250a60d5dbf4fec4 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,9 +1,10 @@ from fastapi import FastAPI, HTTPException -from model_server.base.models import DummyImageToImageModel +from model_server.base.models import DummySemanticSegmentationModel from model_server.base.session import Session, InvalidPathError from model_server.base.validators import validate_workflow_inputs -from model_server.base.workflows import infer_image_to_image +from model_server.base.workflows import classify_pixels +from model_server.extensions.ilastik.workflows import infer_px_then_ob_model app = FastAPI(debug=True) session = Session() @@ -66,13 +67,13 @@ def list_active_models(): @app.put('/models/dummy/load/') def load_dummy_model() -> dict: - return {'model_id': session.load_model(DummyImageToImageModel)} + return {'model_id': session.load_model(DummySemanticSegmentationModel)} -@app.put('/infer/from_image_file') +@app.put('/workflows/segment') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: inpath = session.paths['inbound_images'] / input_filename validate_workflow_inputs([model_id], [inpath]) - record = infer_image_to_image( + record = classify_pixels( inpath, session.models[model_id]['object'], session.paths['outbound_images'], diff --git a/model_server/base/models.py b/model_server/base/models.py index 014957be8ada6bc059a9da7852e561a6ad4e63b0..8413f68d4420fe31add5e25818b4198bdc1a76eb 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -35,7 +35,12 @@ class Model(ABC): pass @abstractmethod - def infer(self, img: GenericImageDataAccessor) -> (object, dict): # return json describing inference result + def infer(self, *args) -> (object, dict): + """ + Abstract method that carries out the computationally intensive step of running data through a model + :param args: + :return: + """ pass def reload(self): @@ -51,7 +56,40 @@ class ImageToImageModel(Model): def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): pass -class DummyImageToImageModel(ImageToImageModel): + +class SemanticSegmentationModel(ImageToImageModel): + """ + Base model that exposes a method that returns a binary mask for a given input image and pixel class + """ + + @abstractmethod + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + """ + Given an image, return an image of the same shape where each pixel is assigned to one or more integer classes + """ + pass + + +class InstanceSegmentationModel(ImageToImageModel): + """ + Base model that exposes a method that returns an instance classification map for a given input image and mask + """ + + @abstractmethod + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Given an image and a mask of the same size, return a map where each connected object is assigned a class + """ + if not mask.is_mask(): + raise InvalidInputImageError('Expecting a binary mask') + if not img.shape == mask.shape: + raise InvalidInputImageError('Expect input image and mask to be the same shape') + + +class DummySemanticSegmentationModel(SemanticSegmentationModel): model_id = 'dummy_make_white_square' @@ -66,6 +104,34 @@ class DummyImageToImageModel(ImageToImageModel): result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 return InMemoryDataAccessor(data=result), {'success': True} + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + mask, _ = self.infer(img) + return mask + +class DummyInstanceSegmentationModel(InstanceSegmentationModel): + + model_id = 'dummy_pass_input_mask' + + def load(self): + return True + + def infer( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor + ) -> (GenericImageDataAccessor, dict): + return img.__class__( + (mask.data / mask.data.max()).astype('uint16') + ) + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Returns a trivial segmentation, i.e. the input mask with value 1 + """ + super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) + class Error(Exception): pass @@ -73,4 +139,7 @@ class CouldNotLoadModelError(Error): pass class ParameterExpectedError(Error): + pass + +class InvalidInputImageError(Error): pass \ No newline at end of file diff --git a/model_server/base/process.py b/model_server/base/process.py index 481282237964ef147821fe99f79f025cf768c948..992c9fb93140ab49859f0807b55a1f8e321cd80d 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -4,6 +4,7 @@ Image processing utility functions from math import ceil, floor import numpy as np +import skimage from skimage.exposure import rescale_intensity @@ -71,4 +72,57 @@ def rescale(nda, clip=0.0): clip_pct = (100.0 * clip, 100.0 * (1.0 - clip)) cmin, cmax = np.percentile(nda, clip_pct) rescaled = rescale_intensity(nda, in_range=(cmin, cmax)) - return rescaled \ No newline at end of file + return rescaled + + +def make_rgb(nda): + """ + Convert a YXCZ stack array to RGB, and error if more than three channels + :param nda: np.ndarray (YXCZ dimensions) + :return: np.ndarray of 3-channel stack + """ + h, w, c, nz = nda.shape + assert c <= 3 + outdata = np.zeros((h, w, 3, nz), dtype=nda.dtype) + outdata[:, :, 0:c, :] = nda[:, :, :, :] + return outdata + + +def mask_largest_object( + img: np.ndarray, + max_allowed: int = 10, + verbose: bool = True +) -> np.ndarray: + """ + Where more than one connected component is found in an image, return the largest object by area + :param img: (np.ndarray) containing object labels or binary mask + :param max_allowed: raise an error if more than this number of objects is found + :param verbose: print a message each time more than one object is found + :return: np.ndarray of same size as img + """ + if is_mask(img): # assign object labels + ob_id = skimage.measure.label(img) + else: # assume img is contains object labels + ob_id = img + + num_obj = len(np.unique(ob_id)) - 1 + if num_obj > max_allowed: + raise TooManyObjectError(f'Found {num_obj} objects in frame') + if num_obj > 1: + if verbose: + print(f'Found {num_obj} nonzero unique values in object map; keeping the one with the largest area') + val, cts = np.unique(ob_id, return_counts=True) + mask = ob_id == val[1 + cts[1:].argmax()] + return mask * img + else: + return img + + +class Error(Exception): + pass + + +class TooManyObjectError(Exception): + pass + + diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py new file mode 100644 index 0000000000000000000000000000000000000000..a66f97414c61b163403abfb75d96c119f616400b --- /dev/null +++ b/model_server/base/roiset.py @@ -0,0 +1,563 @@ +from math import sqrt, floor +from pathlib import Path +from typing import List, Union +from uuid import uuid4 + +import numpy as np +import pandas as pd +from pydantic import BaseModel +from scipy.stats import moment +from skimage.filters import sobel + +from skimage.measure import label, regionprops_table, shannon_entropy, find_contours +from sklearn.preprocessing import PolynomialFeatures +from sklearn.linear_model import LinearRegression + +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.models import InstanceSegmentationModel +from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb +from base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image +from base.accessors import PatchStack +from base.process import mask_largest_object + + +class PatchParams(BaseModel): + draw_bounding_box: bool = False + draw_contour: bool = False + draw_mask: bool = False + rescale_clip: float = 0.001 + focus_metric: str = 'max_sobel' + rgb_overlay_channels: List[Union[int, None]] = [None, None, None] + rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0] + pad_to: int = 256 + + +class AnnotatedZStackParams(BaseModel): + draw_label: bool = False + + +class RoiFilterRange(BaseModel): + min: float + max: float + + +class RoiFilter(BaseModel): + area: Union[RoiFilterRange, None] = None + solidity: Union[RoiFilterRange, None] = None + + +class RoiSetMetaParams(BaseModel): + filters: Union[RoiFilter, None] = None + expand_box_by: List[int] = [128, 0] + + +class RoiSetExportParams(BaseModel): + pixel_probabilities: bool = False + patches_3d: Union[PatchParams, None] = None + annotated_patches_2d: Union[PatchParams, None] = None + patches_2d: Union[PatchParams, None] = None + patch_masks: Union[PatchParams, None] = None + annotated_zstacks: Union[AnnotatedZStackParams, None] = None + object_classes: bool = False + dataframe: bool = False + + + + +def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: + return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')) + + +def _focus_metrics(): + return { + 'max_intensity': lambda x: np.max(x), + 'stdev': lambda x: np.std(x), + 'max_sobel': lambda x: np.max(sobel(x)), + 'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)), + 'entropy': lambda x: shannon_entropy(x), + 'moment': lambda x: moment(x.flatten(), moment=2), + } + + +def _safe_add(a, g, b): + assert a.dtype == b.dtype + assert a.shape == b.shape + assert g >= 0.0 + + return np.clip( + a.astype('uint32') + g * b.astype('uint32'), + 0, + np.iinfo(a.dtype).max + ).astype(a.dtype) + + +class RoiSet(object): + + def __init__( + self, + acc_raw: GenericImageDataAccessor, + acc_obj_ids: GenericImageDataAccessor, + params: RoiSetMetaParams = RoiSetMetaParams(), + ): + """ + A set of regions of interest, referenced by their positions and contours in the YXCZ space of stack acc_raw. + RoiSet contains their internal state, which may be exported as patches, maps, and other products by export methods. + :param acc_raw: accessor to a generally a multichannel z-stack + :param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity + labels its membership in a connected object + :param params: optional arguments that influence the definition and representation of ROIs + """ + assert acc_obj_ids.chroma == 1 + assert acc_obj_ids.nz == 1 + self.acc_obj_ids = acc_obj_ids + self.acc_raw = acc_raw + self.params = params + + self._df = self.filter_df( + self.make_df( + self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by + ), + params.filters, + ) + + self.count = len(self._df) + self.object_class_maps = {} # classification results + + def __iter__(self): + """Expose ROI meta information via the Pandas.DataFrame API""" + return self._df.itertuples(name='Roi') + + @staticmethod + def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: + """ + Build dataframe associate object IDs with summary stats + :param acc_raw: accessor to raw image data + :param acc_obj_ids: accessor to map of object IDs + :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) + :return: pd.DataFrame + """ + # build dataframe of objects, assign z index to each object + argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') + df = ( + pd.DataFrame( + regionprops_table( + acc_obj_ids.data[:, :, 0, 0], + intensity_image=argmax, + properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid') + ) + ) + .rename( + columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', } + ) + ) + df['zi'] = df['intensity_mean'].round().astype('int') + + # compute expanded bounding boxes + h, w, c, nz = acc_raw.shape + ebxy, ebz = expand_box_by + df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) + df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) + df['ebb_x0'] = (df.x0 - ebxy).apply(lambda x: max(x, 0)) + df['ebb_x1'] = (df.x1 + ebxy).apply(lambda x: min(x, w)) + df['ebb_z0'] = (df.zi - ebz).apply(lambda x: max(x, 0)) + df['ebb_z1'] = (df.zi + ebz).apply(lambda x: min(x, nz)) + df['ebb_h'] = df['ebb_y1'] - df['ebb_y0'] + df['ebb_w'] = df['ebb_x1'] - df['ebb_x0'] + df['ebb_nz'] = df['ebb_z1'] - df['ebb_z0'] + 1 + + # compute relative bounding boxes + df['rel_y0'] = df.y0 - df.ebb_y0 + df['rel_y1'] = df.y1 - df.ebb_y0 + df['rel_x0'] = df.x0 - df.ebb_x0 + df['rel_x1'] = df.x1 - df.ebb_x0 + + assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0'])) + assert np.all(df['rel_y1'] <= (df['ebb_x1'] - df['ebb_x0'])) + + df['slice'] = df.apply( + lambda r: + np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], + axis=1 + ) + df['relative_slice'] = df.apply( + lambda r: + np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], + axis=1 + ) + df['mask'] = df.apply( + lambda r: (acc_obj_ids.data == r.label)[r.y0: r.y1, r.x0: r.x1, 0, 0], + axis=1 + ) + return df + + + @staticmethod + def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: + query_str = 'label > 0' # always true + if filters is not None: # parse filters + for k, val in filters.dict(exclude_unset=True).items(): + assert k in ('area', 'solidity') + vmin = val['min'] + vmax = val['max'] + assert vmin >= 0 + query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' + return df.loc[df.query(query_str).index, :] + + def get_df(self) -> pd.DataFrame: # TODO: exclude columns that refer to objects + return self._df + + def get_slices(self) -> pd.Series: + return self.get_df()['slice'] + + def add_df_col(self, name, se: pd.Series) -> None: + self._df[name] = se + + def get_multichannel_projection(self): # TODO: document and test + if self.count: + projected = project_stack_from_focal_points( + self._df['centroid-0'].to_numpy(), + self._df['centroid-1'].to_numpy(), + self._df['zi'].to_numpy(), + self.acc_raw, + degree=4, + ) + else: # else just return MIP + projected = self.acc_raw.data.max(axis=-1) + return projected + + # TODO: remove, since padding is implicit in PatchStack + # TODO: test case where patch channel is restricted + def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches + if channel: + patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) + else: + patches_df = self.get_patches(pad_to=pad_to) + patches = list(patches_df['patch']) + return PatchStack(patches) + + def export_annotated_zstack(self, where, prefix='zstack', **kwargs): + annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) + success = write_accessor_data_to_file(where / (prefix + '.tif'), annotated) + return {'location': where.__str__(), 'filename': prefix + '.tif'} + + def get_zmask(self, mask_type='boxes'): + """ + Return a mask of same dimensionality as raw data + + :param kwargs: variable-length keyword arguments + mask_type: if 'boxes', zmask is True in each object's complete bounding box; otherwise 'contours' + """ + + assert mask_type in ('contours', 'boxes') + zi_st = np.zeros(self.acc_raw.shape, dtype='bool') + lamap = self.acc_obj_ids.data + + # make an object map where label is replaced by focus position in stack and background is -1 + lut = np.zeros(lamap.max() + 1) - 1 + df = self.get_df() + lut[df.label] = df.zi + + if mask_type == 'contours': + zi_map = (lut[lamap] + 1.0).astype('int') + idxs = np.array(zi_map) - 1 + np.put_along_axis( + zi_st, + np.expand_dims(idxs, (2, 3)), + 1, + axis=3 + ) + + # change background level from to 0 in final frame + zi_st[:, :, :, -1][lamap == 0] = 0 + + elif mask_type == 'boxes': + for roi in self: + zi_st[roi.relative_slice] = 1 + + return zi_st + + + def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): + + # do this on a patch basis, i.e. only one object per frame + obmap_patches = object_classification_model.label_instance_class( + self.get_raw_patches(channel=channel), + self.get_patch_masks() + ) + + om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype) + + self._df['classify_by_' + name] = pd.Series(dtype='Int64') + + # assign labels to object map: + for i, roi in enumerate(self): + oc = np.unique( + mask_largest_object( + obmap_patches.iat(i).data + ) + )[1] + self._df.loc[roi.Index, 'classify_by_' + name] = oc + om[self.acc_obj_ids.data == roi.label] = oc + self.object_class_maps[name] = InMemoryDataAccessor(om) + + def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list: + patches_acc = self.get_patch_masks(pad_to=pad_to) + + exported = [] + for i, roi in enumerate(self): # assumes index of patches_acc is same as dataframe + patch = patches_acc.iat_yxcz(i) + ext = 'png' + fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' + write_accessor_data_to_file(where / fname, patch) + exported.append(fname) + return exported + + + def export_patches(self, where: Path, prefix='patch', **kwargs): + make_3d = kwargs.get('make_3d', False) + patches_df = self.get_patches(**kwargs) + + def _export_patch(roi): + patch = InMemoryDataAccessor(roi.patch) + ext = 'tif' if make_3d or patch.chroma > 3 else 'png' + fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' + + if patch.dtype is np.dtype('uint16'): + resampled = InMemoryDataAccessor(resample_to_8bit(patch.data)) + write_accessor_data_to_file(where / fname, resampled) + else: + write_accessor_data_to_file(where / fname, patch) + + exported.append({ + 'df_index': roi.Index, + 'patch_filename': fname, + 'location': where.__str__(), + }) + + exported = [] + for roi in patches_df.itertuples(): # just used for label info + _export_patch(roi) + + return exported + + def get_patch_masks(self, pad_to: int = 256) -> PatchStack: + patches = [] + for roi in self: + patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') + patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 + + if pad_to: + patch = pad(patch, pad_to) + + patches.append(patch) + return PatchStack(patches) + + + def get_patches( + self, + rescale_clip: float = 0.0, + pad_to: int = 256, + make_3d: bool = False, + focus_metric: str = None, + rgb_overlay_channels: list = None, + rgb_overlay_weights: list = [1.0, 1.0, 1.0], + white_channel: int = None, + **kwargs + ) -> pd.DataFrame: + + # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data + raw = self.acc_raw + if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)): + assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None]) + assert len(rgb_overlay_channels) == 3 + assert len(rgb_overlay_weights) == 3 + + if white_channel: + assert white_channel < raw.chroma + stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] + else: + stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype) + + for ii, ci in enumerate(rgb_overlay_channels): + if ci is None: + continue + assert isinstance(ci, int) + assert ci < raw.chroma + stack[:, :, ii, :] = _safe_add( + stack[:, :, ii, :], # either black or grayscale channel + rgb_overlay_weights[ii], + raw.data[:, :, ci, :] + ) + else: + if white_channel: # interpret as just a single channel + assert white_channel < raw.chroma + annotate_rgb = False + for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']: + ca = kwargs.get(k) + if ca is None: + continue + assert (ca < raw.chroma) + if ca != white_channel: + annotate_rgb = True + break + if annotate_rgb: # make RGB patches anyway to include annotation color + stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] + else: # make monochrome patches + stack = raw.data[:, :, [white_channel], :] + else: + stack = raw.data + + def _make_patch(roi): + patch3d = stack[roi.slice] + ph, pw, pc, pz = patch3d.shape + subpatch = patch3d[roi.relative_slice] + + # make a 3d patch + if make_3d: + patch = patch3d + + # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately + elif focus_metric is not None: + foc = _focus_metrics()[focus_metric] + + patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype) + + for ci in range(0, pc): + me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)] + zif = np.argmax(me) + patch[:, :, ci, 0] = patch3d[:, :, ci, zif] + + # make a 2d patch from middle of z-stack + else: + zim = floor(pz / 2) + patch = patch3d[:, :, :, [zim]] + + assert len(patch.shape) == 4 + + if rescale_clip is not None: + patch = rescale(patch, rescale_clip) + + if kwargs.get('draw_bounding_box') is True: + bci = kwargs.get('bounding_box_channel', 0) + assert bci < 3 + if bci > 0: + patch = make_rgb(patch) + for zi in range(0, patch.shape[3]): + patch[:, :, bci, zi] = draw_box_on_patch( + patch[:, :, bci, zi], + ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)), + linewidth=kwargs.get('bounding_box_linewidth', 1) + ) + + if kwargs.get('draw_mask'): + mci = kwargs.get('mask_channel', 0) + mask = np.zeros(patch.shape[0:2], dtype=bool) + mask[roi.relative_slice[0:2]] = roi.mask + for zi in range(0, patch.shape[3]): + patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] + + if kwargs.get('draw_contour'): + mci = kwargs.get('contour_channel', 0) + mask = np.zeros(patch.shape[0:2], dtype=bool) + mask[roi.relative_slice[0:2]] = roi.mask + + for zi in range(0, patch.shape[3]): + patch[:, :, mci, zi] = draw_contours_on_patch( + patch[:, :, mci, zi], + find_contours(mask) + ) + + if pad_to: + patch = pad(patch, pad_to) + return patch + + dfe = self._df + dfe['patch'] = self._df.apply(lambda r: _make_patch(r), axis=1) + return dfe + + def run_exports(self, where, channel, prefix, params: RoiSetExportParams): + if not self.count: + return + raw_ch = self.acc_raw.get_one_channel_data(channel) + for k in params.dict().keys(): + subdir = where / k + pr = prefix + kp = params.dict()[k] + if kp is None: + continue + if k == 'patches_3d': + files = self.export_patches( + subdir, white_channel=channel, prefix=pr, make_3d=True, **kp + ) + if k == 'annotated_patches_2d': + files = self.export_patches( + subdir, prefix=pr, make_3d=False, white_channel=channel, + bounding_box_channel=1, bounding_box_linewidth=2, **kp, + ) + if k == 'patches_2d': + files = self.export_patches( + subdir, white_channel=channel, prefix=pr, make_3d=False, **kp + ) + df_patches = pd.DataFrame(files) + self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') + self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1) + if k == 'patch_masks': + self.export_patch_masks(subdir, prefix=pr, **kp) + if k == 'annotated_zstacks': + self.export_annotated_zstack(subdir, prefix=pr, **kp) + if k == 'object_classes': + for k, acc in self.object_class_maps.items(): + write_accessor_data_to_file(subdir / k / (pr + '.tif'), acc) + if k == 'dataframe': + dfpa = subdir / (pr + '.csv') + dfpa.parent.mkdir(parents=True, exist_ok=True) + self._df.to_csv(dfpa, index=False) + + +def project_stack_from_focal_points( + xx: np.ndarray, + yy: np.ndarray, + zz: np.ndarray, + stack: GenericImageDataAccessor, + degree: int = 2, +) -> np.ndarray: + """ + Given a set of 3D points, project a multichannel z-stack based on a surface fit of the provided points + :param xx: vector of point x-coordinates + :param yy: vector of point y-coordinates + :param zz: vector of point z-coordinates + :param stack: z-stack to project + :param degree: order of polynomial to fit + :return: multichannel 2d projected image array + """ + assert xx.shape == yy.shape + assert xx.shape == zz.shape + + poly = PolynomialFeatures(degree=degree) + X = np.stack([xx, yy]).T + features = poly.fit_transform(X, zz) + model = LinearRegression(fit_intercept=False) + model.fit(features, zz) + + xy_indices = np.indices(stack.hw).reshape(2, -1).T + xy_features = np.dot( + poly.fit_transform(xy_indices, zz), + model.coef_ + ) + zi_image = xy_features.reshape( + stack.hw + ).round().clip( + 0, (stack.nz - 1) + ).astype('uint16') + + return np.take_along_axis( + stack.data, + np.repeat( + np.expand_dims(zi_image, (2, 3)), + stack.chroma, + axis=2 + ), + axis=3 + ) + + diff --git a/model_server/base/workflows.py b/model_server/base/workflows.py index e3eedb68d9ed2574180e47c1a16e65da5b50372e..f3719433b727e770285a53dd3f2fb55c7517dbf7 100644 --- a/model_server/base/workflows.py +++ b/model_server/base/workflows.py @@ -6,7 +6,7 @@ from time import perf_counter from typing import Dict from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file -from model_server.base.models import Model +from model_server.base.models import SemanticSegmentationModel from pydantic import BaseModel @@ -29,11 +29,12 @@ class WorkflowRunRecord(BaseModel): timer_results: Dict[str, float] -def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) -> WorkflowRunRecord: +def classify_pixels(fpi: Path, model: SemanticSegmentationModel, where_output: Path, **kwargs) -> WorkflowRunRecord: """ - Generic workflow where a model processes an input image into an output image + Run a semantic segmentation model to compute a binary mask from an input image + :param fpi: Path object that references input image file - :param model: model object + :param model: semantic segmentation model instance :param where_output: Path object that references output image directory :param kwargs: variable-length keyword arguments :return: record object @@ -43,7 +44,7 @@ def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) img = generate_file_accessor(fpi).get_one_channel_data(ch) ti.click('file_input') - outdata, _ = model.infer(img) + outdata = model.label_pixel_class(img) ti.click('inference') outpath = where_output / (model.model_id + '_' + fpi.stem + '.tif') diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 33cc08595f00e9176e0207e793b7d65b8439b9df..3d07931ee087dba7eb56f416f232b16b89056161 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -56,7 +56,25 @@ monozstackmask = { ilastik_classifiers = { 'px': root / 'ilastik' / 'demo_px.ilp', 'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp', - 'seg_to_obj': root / 'ilastik' / 'new_auto_obj.ilp', + 'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp', +} + +roiset_test_data = { + 'multichannel_zstack': { + 'path': root / 'zmask-test-stack-chlorophyl.tif', + 'w': 512, + 'h': 512, + 'c': 5, + 'z': 7, + 'mask_path': root / 'zmask-test-stack-mask.tif', + }, + 'pipeline_params': { + 'segmentation_channel': 0, + 'patches_channel': 4, + 'pxmap_channel': 0, + 'pxmap_threshold': 0.6, + }, + 'pixel_classifier': root / 'zmask' / 'AF405-bodies_boundaries.ilp', } output_path = root / 'testing_output' diff --git a/model_server/extensions/chaeo/accessors.py b/model_server/extensions/chaeo/accessors.py deleted file mode 100644 index 0e6cfdc07bba6479df9b738afa880612f0e244e8..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/accessors.py +++ /dev/null @@ -1,129 +0,0 @@ -from pathlib import Path - -import numpy as np - -from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor - -class MonoPatchStack(InMemoryDataAccessor): - - def __init__(self, data): - """ - A sequence of n monochrome images of the same size - :param data: either np.ndarray of dimensions YXn, or a list of np.ndarrays of size YX - """ - - if isinstance(data, np.ndarray): # interpret as YXZ - assert data.ndim == 3 - self._data = np.expand_dims(data, 2) - elif isinstance(data, list): # list of YX patches - if len(data) == 0: - self._data = np.ndarray([0, 0, 0, 0], dtype='uin9') - elif len(data) == 1: - self._data = np.expand_dims( - np.array( - data[0].squeeze() - ), - (2, 3) - ) - else: - nda = np.array(data).squeeze() - assert nda.ndim == 3 - self._data = np.expand_dims( - np.moveaxis( - nda, - [1, 2, 0], - [0, 1, 2]), - 2 - ) - else: - raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}') - - def make_tczyx(self): - assert self.chroma == 1 - tyx = np.moveaxis( - self.data[:, :, 0, :], # YX(C)Z - [2, 0, 1], - [0, 1, 2] - ) - return np.expand_dims(tyx, (1, 2)) - - @property - def count(self): - return self.nz - - def iat(self, i): - return self.data[:, :, 0, i] - - def iat_yxcz(self, i): - return np.expand_dims(self.iat(i), (2, 3)) - - def get_list(self): - n = self.nz - return [self.data[:, :, 0, zi] for zi in range(0, n)] - - -class MonoPatchStackFromFile(MonoPatchStack): - def __init__(self, fpath): - if not Path(fpath).exists(): - raise FileNotFoundError(f'Could not find {fpath}') - self.file_acc = generate_file_accessor(fpath) - super().__init__(self.file_acc.data[:, :, 0, :]) - - @property - def fpath(self): - return self.file_acc.fpath - -class Multichannel3dPatchStack(InMemoryDataAccessor): - - def __init__(self, data): - """ - A sequence of n (generally) color 3D images of the same size - :param data: a list of np.ndarrays of size YXCZ - """ - - if isinstance(data, list): # list of YXCZ patches - nda = np.array(data) - assert nda.ndim == 5 - # self._data = np.moveaxis( # pos-YXCZ - # nda, - # [0, 1, 2, 0, 3], - # [0, 1, 2, 3] - # ) - self._data = nda - else: - raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}') - - def iat(self, i): - return self.data[i, :, :, :, :] - - def iat_yxcz(self, i): - return self.iat(i) - - @property - def count(self): - return self.shape_dict['P'] - - @property - def data(self): - """ - Return data as 5d with axes in order of pos, Y, X, C, Z - :return: np.ndarray - """ - return self._data - - @property - def shape(self): - return self._data.shape - - @property - def shape_dict(self): - return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) - -class Error(Exception): - pass - -class InvalidDataForPatchStackError(Error): - pass - -class FileNotFoundError(Error): - pass diff --git a/model_server/extensions/chaeo/annotators.py b/model_server/extensions/chaeo/annotators.py deleted file mode 100644 index 050123a496291668794c13553c23a31b3b559def..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/annotators.py +++ /dev/null @@ -1,66 +0,0 @@ -import numpy as np -from PIL import Image, ImageDraw, ImageFont - -from model_server.base.process import rescale - -def draw_boxes_on_2d_image(yx_img, boxes, **kwargs): - pilimg = Image.fromarray(np.copy(yx_img)) # drawing modifies array in-place - draw = ImageDraw.Draw(pilimg) - font_size = kwargs.get('font_size', 18) - linewidth = kwargs.get('linewidth', 4) - - draw.font = ImageFont.truetype(font="arial.ttf", size=font_size) - - for box in boxes: - y0 = box['info'].y0 - y1 = box['info'].y1 - x0 = box['info'].x0 - x1 = box['info'].x1 - xm = round((x0 + x1) / 2) - - la = box['info'].label - zi = box['info'].zi - - draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth) - - if kwargs.get('add_label') is True: - draw.text((xm, y0), f'{la:04d}', fill='white', anchor='mb') - - return pilimg - - -def draw_boxes_on_3d_image(yxcz_img, boxes, draw_full_depth=False, **kwargs): - assert len(yxcz_img.shape) == 4 - nz = yxcz_img.shape[3] - assert yxcz_img.shape[2] == 1 - - annotated = np.zeros(yxcz_img.shape, dtype=yxcz_img.dtype) - - for zi in range(0, nz): - if draw_full_depth: - zi_boxes = boxes - else: - zi_boxes = [bb for bb in boxes if bb['info'].zi == zi] - annotated[:, :, 0, zi] = draw_boxes_on_2d_image(yxcz_img[:, :, 0, zi], zi_boxes, **kwargs) - - if clip := kwargs.get('rescale_clip'): - assert clip >= 0.0 and clip <= 1.0 - annotated = rescale(annotated, clip=clip) - - return annotated - -def draw_box_on_patch(patch, bbox, linewidth=1): - assert len(patch.shape) == 2 - ((x0, y0), (x1, y1)) = bbox - pilimg = Image.fromarray(patch) # drawing modifies array in-place - draw = ImageDraw.Draw(pilimg) - draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth) - return np.array(pilimg) - -def draw_contours_on_patch(patch, contours, linewidth=1): - assert len(patch.shape) == 2 - pilimg = Image.fromarray(patch) # drawing modifies array in-place - draw = ImageDraw.Draw(pilimg) - for co in contours: - draw.line([(p[1], p[0]) for p in co], width=linewidth, joint='curve') - return np.array(pilimg) \ No newline at end of file diff --git a/model_server/extensions/chaeo/batch_jobs/coloring_book.py b/model_server/extensions/chaeo/batch_jobs/coloring_book.py index a3a9c5268a12cac204897d9ecd5dba4d8a648ad6..e65dd1d27345852d2687715de5d91cfd90f770b9 100644 --- a/model_server/extensions/chaeo/batch_jobs/coloring_book.py +++ b/model_server/extensions/chaeo/batch_jobs/coloring_book.py @@ -7,7 +7,7 @@ from skimage.measure import label, regionprops_table import tifffile -from model_server.extensions.chaeo.accessors import MonoPatchStack +from model_server.base.accessors import PatchStack from model_server.extensions.ilastik.models import IlastikPixelClassifierModel from model_server.base.accessors import write_accessor_data_to_file, InMemoryDataAccessor @@ -19,7 +19,7 @@ if __name__ == '__main__': min_area = 400 tf = tifffile.imread(root / '20231008-162336-z04-TL.tif') - instack = MonoPatchStack(np.moveaxis(tf, 0, -1)) + instack = PatchStack(np.moveaxis(tf, 0, -1)) px_mod = IlastikPixelClassifierModel(params={'project_file': px_ilp}) pxmaps = [] diff --git a/model_server/extensions/chaeo/batch_jobs/20231028_Porto_PA.py b/model_server/extensions/chaeo/batch_jobs/int_test_20231028_Porto_PA.py similarity index 58% rename from model_server/extensions/chaeo/batch_jobs/20231028_Porto_PA.py rename to model_server/extensions/chaeo/batch_jobs/int_test_20231028_Porto_PA.py index 4ae729d10601f0020cc4e85fdbff8ca2ffa9e785..6007943bc394d0d17a751621e2273ef0260de3e6 100644 --- a/model_server/extensions/chaeo/batch_jobs/20231028_Porto_PA.py +++ b/model_server/extensions/chaeo/batch_jobs/int_test_20231028_Porto_PA.py @@ -2,7 +2,8 @@ from pathlib import Path from model_server.base.util import autonumber_new_directory, get_matching_files, loop_workflow from model_server.extensions.chaeo.ecotaxa import write_ecotaxa_tsv_chunked_subdirectories -from model_server.extensions.chaeo.workflows import export_patches_from_multichannel_zstack +from extensions.chaeo import ZMaskExportParams +from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack from model_server.extensions.ilastik.models import IlastikPixelClassifierModel @@ -17,6 +18,23 @@ if __name__ == '__main__': px_ilp = Path('c:/Users/rhodes/projects/proj0011-plankton-seg/exp0017/pxAF405_dim8bit.ilp').__str__() + export_params = { + 'pixel_probabilities': True, + # 'patches_3d': {}, + 'patches_2d_for_training': { + # 'draw_bounding_box': False, + }, + 'patches_2d_for_annotation': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': (1, None, None), + 'rgb_overlay_weights': (0.2, 1.0, 1.0), + }, + 'annotated_z_stack': { + 'draw_label': True + } + } + ZMaskExportParams(**export_params) + params = { 'pxmap_threshold': 0.25, 'pxmap_foreground_channel': 0, @@ -26,17 +44,19 @@ if __name__ == '__main__': 'zmask_type': 'boxes', 'zmask_filters': {'area': (1e3, 1e8)}, 'zmask_expand_box_by': (128, 3), - 'export_pixel_probabilities': True, - 'export_2d_patches_for_training': True, - 'draw_bounding_box_on_2d_patch': True, - 'export_2d_patches_for_annotation': True, - 'export_3d_patches': False, - 'export_annotated_zstack': True, - 'export_patch_masks': True, 'zmask_clip': 0.01, - 'rgb_overlay_channels': (1, None, None), - 'rgb_overlay_weights': (0.2, 1.0, 1.0), - 'draw_label_on_zstack': True, + # 'export_pixel_probabilities': True, + # 'export_2d_patches_for_training': True, + # 'draw_bounding_box_on_2d_patch': True, + # 'export_2d_patches_for_annotation': True, + # 'export_3d_patches': False, + # 'export_annotated_zstack': True, + # 'export_patch_masks': True, + + # 'rgb_overlay_channels': (1, None, None), + # 'rgb_overlay_weights': (0.2, 1.0, 1.0), + # 'draw_label_on_zstack': True, + 'exports': ZMaskExportParams(**export_params), } input_files = get_matching_files(where_czi, 'czi', coord_filter={}) @@ -44,7 +64,7 @@ if __name__ == '__main__': loop_workflow( input_files, where_output, - export_patches_from_multichannel_zstack, + infer_object_map_from_zstack, [IlastikPixelClassifierModel(params={'project_file': Path(px_ilp)})], params, catch_and_continue=False, diff --git a/model_server/extensions/chaeo/conf/testing.py b/model_server/extensions/chaeo/conf/testing.py deleted file mode 100644 index 00389e08d8adbe01a3bd14c19f76342a9c00ccea..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/conf/testing.py +++ /dev/null @@ -1,22 +0,0 @@ -from pathlib import Path - -root = Path.home() / 'model_server' - -multichannel_zstack = { - 'path': root / 'testing' / 'zmask-test-stack-chlorophyl.tif', - 'w': 512, - 'h': 512, - 'c': 5, - 'z': 7, -} - -pixel_classifier = { - 'path': root / 'testing' / 'zmask' / 'AF405-bodies_boundaries.ilp' -} - -pipeline_params = { - 'threshold': 0.6, -} - -output_path = root / 'testing' / 'output' / 'chaeo' -output_path.mkdir(parents=True, exist_ok=True) \ No newline at end of file diff --git a/model_server/extensions/chaeo/examples/batch_obj_cla.py b/model_server/extensions/chaeo/examples/batch_obj_cla.py index 7471034f7091efec74c2cbf6230df385960d2ec9..eef5e70343061d6a47d09aac236395df141d1c17 100644 --- a/model_server/extensions/chaeo/examples/batch_obj_cla.py +++ b/model_server/extensions/chaeo/examples/batch_obj_cla.py @@ -2,7 +2,7 @@ from pathlib import Path from model_server.conf.testing import output_path from model_server.base.util import autonumber_new_directory, get_matching_files, loop_workflow -from model_server.extensions.chaeo.models import PatchStackObjectClassifier +from extensions.ilastik.models import PatchStackObjectClassifier from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack from model_server.extensions.ilastik.models import IlastikPixelClassifierModel diff --git a/model_server/extensions/chaeo/examples/export_patch_focus_metrics.py b/model_server/extensions/chaeo/examples/export_patch_focus_metrics.py deleted file mode 100644 index 7e1279948ea90b10463f19d50b0d8412e605d12c..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/examples/export_patch_focus_metrics.py +++ /dev/null @@ -1,157 +0,0 @@ -from pathlib import Path -import re -from time import localtime, strftime -from typing import Dict - -import pandas as pd - -from model_server.extensions.ilastik.models import IlastikPixelClassifierModel -from model_server.extensions.chaeo.products import export_3d_patches_with_focus_metrics, export_patches_from_zstack -from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask -from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file -from model_server.base.workflows import Timer - - -def export_patch_focus_metrics_from_multichannel_zstack( - input_zstack_path: str, - ilastik_project_file: str, - pxmap_threshold: float, - pixel_class: int, - zmask_channel: int, - patches_channel: int, - where_output: str, - mask_type: str = 'boxes', - zmask_filters: Dict = None, - zmask_expand_box_by: int = None, - annotate_focus_metric=None, - **kwargs, -) -> Dict: - - ti = Timer() - stack = generate_file_accessor(Path(input_zstack_path)) - fstem = Path(input_zstack_path).stem - ti.click('file_input') - assert stack.nz > 1, 'Expecting z-stack' - - # MIP and classify pixels - mip = InMemoryDataAccessor( - stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True) - ) - px_model = IlastikPixelClassifierModel( - params={'project_file': Path(ilastik_project_file)} - ) - pxmap, _ = px_model.infer(mip) - ti.click('infer_pixel_probability') - - obmask = InMemoryDataAccessor( - pxmap.data > pxmap_threshold - ) - ti.click('threshold_pixel_mask') - - # make zmask - zmask, zmask_meta, df, interm = build_zmask_from_object_mask( - obmask.get_one_channel_data(pixel_class), - stack.get_one_channel_data(zmask_channel), - mask_type=mask_type, - filters=zmask_filters, - expand_box_by=zmask_expand_box_by, - ) - zmask_acc = InMemoryDataAccessor(zmask) - ti.click('generate_zmasks') - - files = export_3d_patches_with_focus_metrics( - Path(where_output) / '3d_patches', - stack.get_one_channel_data(patches_channel), - zmask_meta, - prefix=fstem, - rescale_clip=0.0, - make_3d=True, - annotate_focus_metric=annotate_focus_metric, - ) - ti.click('export_3d_patches') - - files = export_patches_from_zstack( - Path(where_output) / '2d_patches', - stack.get_one_channel_data(patches_channel), - zmask_meta, - prefix=fstem, - draw_bounding_box=True, - rescale_clip=0.0, - # focus_metric=lambda x: np.max(sobel(x)), - focus_metric='max_sobel', - make_3d=False, - ) - ti.click('export_2d_patches') - - return { - 'pixel_model_id': px_model.model_id, - 'input_filepath': input_zstack_path, - 'number_of_objects': len(zmask_meta), - 'success': True, - 'timer_results': ti.events, - 'dataframe': df, - 'interm': interm, - } - -if __name__ == '__main__': - where_czi = Path( - 'c:/Users/rhodes/projects/proj0004-marine-photoactivation/data/exp0038/AutoMic/20230906-163415/Selection' - ) - - where_output_root = Path( - 'c:/Users/rhodes/projects/proj0011-plankton-seg/exp0009/output' - ) - yyyymmdd = strftime('%Y%m%d', localtime()) - idx = 0 - while Path(where_output_root / f'batch-output-{yyyymmdd}-{idx:04d}').exists(): - idx += 1 - where_output = Path( - where_output_root / f'batch-output-{yyyymmdd}-{idx:04d}' - ) - - csv_args = {'mode': 'w', 'header': True} # when creating file - px_ilp = Path.home() / 'model_server' / 'ilastik' / 'AF405-bodies_boundaries.ilp' - - for ff in where_czi.iterdir(): - if ff.stem != 'Selection--W0000--P0009-T0001': - continue - - pattern = 'Selection--W([\d]+)--P([\d]+)-T([\d]+)' - ma = re.match(pattern, ff.stem) - - print(ff) - if not ff.suffix.upper() == '.CZI': - continue - if int(ma.groups()[1]) > 10: # skip second half of set - continue - - export_kwargs = { - 'input_zstack_path': (where_czi / ff).__str__(), - 'ilastik_project_file': px_ilp.__str__(), - 'pxmap_threshold': 0.25, - 'pixel_class': 0, - 'zmask_channel': 0, - 'patches_channel': 4, - 'where_output': where_output.__str__(), - 'mask_type': 'boxes', - 'zmask_filters': {'area': (1e3, 1e8)}, - 'zmask_expand_box_by': (128, 3), - 'annotate_focus_metric': 'max_sobel' - } - - result = export_patch_focus_metrics_from_multichannel_zstack(**export_kwargs) - - # parse and record results - df = result['dataframe'] - df['filename'] = ff.name - df.to_csv(where_output / 'df_objects.csv', **csv_args) - pd.DataFrame(result['timer_results'], index=[0]).to_csv(where_output / 'timer_results.csv', **csv_args) - pd.json_normalize(export_kwargs).to_csv(where_output / 'workflow_params.csv', **csv_args) - csv_args = {'mode': 'a', 'header': False} # append to CSV from here on - - # export intermediate data if flagged - for k in result['interm'].keys(): - write_accessor_data_to_file( - where_output / k / (ff.stem + '.tif'), - InMemoryDataAccessor(result['interm'][k]) - ) \ No newline at end of file diff --git a/model_server/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py b/model_server/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py index ae30674583dee43e234688ffc985da0049e939fd..04e0c20e31d2da78ab11bec16908333e56a7c74f 100644 --- a/model_server/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py +++ b/model_server/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py @@ -3,8 +3,9 @@ import numpy as np import pandas as pd import skimage -from model_server.extensions.chaeo.accessors import MonoPatchStackFromFile -from model_server.extensions.chaeo.models import generate_ilastik_object_classifier, PatchStackObjectClassifier +from model_server.base.accessors import make_patch_stack_from_file +from model_server.extensions.chaeo.models import generate_ilastik_object_classifier +from extensions.ilastik.models import PatchStackObjectClassifier from model_server.base.accessors import GenericImageDataAccessor, write_accessor_data_to_file @@ -68,9 +69,9 @@ if __name__ == '__main__': classifier_file = generate_ilastik_object_classifier( template_ilp, root / 'new_auto_obj.ilp', - MonoPatchStackFromFile(root / 'zstack_train_raw.tif'), - MonoPatchStackFromFile(root / 'zstack_train_mask.tif'), - MonoPatchStackFromFile(root / 'zstack_train_label.tif'), + make_patch_stack_from_file(root / 'zstack_train_raw.tif'), + make_patch_stack_from_file(root / 'zstack_train_mask.tif'), + make_patch_stack_from_file(root / 'zstack_train_label.tif'), label_names, allow_multiple_objects=False ) @@ -80,17 +81,17 @@ if __name__ == '__main__': infer_and_compare( classifier, 'train', - MonoPatchStackFromFile(root / 'zstack_train_raw.tif'), - MonoPatchStackFromFile(root / 'zstack_train_mask.tif'), - MonoPatchStackFromFile(root / 'zstack_train_label.tif') + make_patch_stack_from_file(root / 'zstack_train_raw.tif'), + make_patch_stack_from_file(root / 'zstack_train_mask.tif'), + make_patch_stack_from_file(root / 'zstack_train_label.tif') ) # run test set infer_and_compare( classifier, 'test', - MonoPatchStackFromFile(root / 'zstack_test_raw.tif'), - MonoPatchStackFromFile(root / 'zstack_test_mask.tif'), - MonoPatchStackFromFile(root / 'zstack_test_label.tif'), + make_patch_stack_from_file(root / 'zstack_test_raw.tif'), + make_patch_stack_from_file(root / 'zstack_test_mask.tif'), + make_patch_stack_from_file(root / 'zstack_test_label.tif'), ) diff --git a/model_server/extensions/chaeo/models.py b/model_server/extensions/chaeo/models.py index 7f8c08bdcd490f8d7803583c3fe4b87c6e2d221d..865537e67ba79738a82e5b7e63472210026657f2 100644 --- a/model_server/extensions/chaeo/models.py +++ b/model_server/extensions/chaeo/models.py @@ -4,55 +4,16 @@ import shutil import h5py import numpy as np import skimage -import vigra -from model_server.extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile -from model_server.extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel - - -class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): - """ - Wrap ilastik object classification for inputs comprising raw image and binary segmentation masks, both represented - as time-series images where each frame contains only one object. - """ - - def infer(self, input_acc: MonoPatchStack, segmentation_acc: MonoPatchStack) -> (np.ndarray, dict): - assert segmentation_acc.is_mask() - assert input_acc.chroma == 1 - - tagged_input_data = vigra.taggedView(input_acc.make_tczyx(), 'tczyx') - tagged_seg_data = vigra.taggedView(segmentation_acc.make_tczyx(), 'tczyx') - - dsi = [ - { - 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), - } - ] - - obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - - assert len(obmaps) == 1, 'ilastik generated more than one object map' - - # for some reason ilastik scrambles these axes to Z(1)YX(1) - assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1) - yxz = np.moveaxis( - obmaps[0][:, 0, :, :, 0], - [1, 2, 0], - [0, 1, 2] - ) - - assert yxz.shape[0:2] == input_acc.hw - assert yxz.shape[2] == input_acc.nz - return MonoPatchStack(data=yxz), {'success': True} +from model_server.base.accessors import PatchStack def generate_ilastik_object_classifier( template_ilp: Path, target_ilp: Path, - raw_stack: MonoPatchStackFromFile, - mask_stack: MonoPatchStackFromFile, - label_stack: MonoPatchStackFromFile, + raw_stack: PatchStack, + mask_stack: PatchStack, + label_stack: PatchStack, label_names: list, lane: int = 0, allow_multiple_objects=True, diff --git a/model_server/extensions/chaeo/process.py b/model_server/extensions/chaeo/process.py deleted file mode 100644 index e41578eee35f7849e7d714326c59ca68e34042b0..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/process.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np -import skimage - -from model_server.base.process import is_mask - -def mask_largest_object( - img: np.ndarray, - max_allowed: int = 10, - verbose: bool = True -) -> np.ndarray: - """ - Where more than one connected component is found in an image, return the largest object by area - :param img: (np.ndarray) containing object labels or binary mask - :param max_allowed: raise an error if more than this number of objects is found - :param verbose: print a message each time more than one object is found - :return: np.ndarray of same size as img - """ - if is_mask(img): # assign object labels - ob_id = skimage.measure.label(img) - else: # assume img is contains object labels - ob_id = img - - # import skimage - # from pathlib import Path - # where = Path('c:/Users/rhodes/projects/proj0011-plankton-seg/tmp') - # skimage.io.imsave(where / 'raw.png', img) - num_obj = len(np.unique(ob_id)) - 1 - if num_obj > max_allowed: - raise TooManyObjectError(f'Found {num_obj} objects in frame') - if num_obj > 1: - if verbose: - print(f'Found {num_obj} nonzero unique values in object map; keeping the one with the largest area') - # pr = regionprops_table(ob_id, properties=['label', 'area']) - val, cts = np.unique(ob_id, return_counts=True) - mask = ob_id == val[1 + cts[1:].argmax()] - # idx_max_area = pr['area'].argmax() - # mask = ob_id == pr['label'][idx_max_area] - return mask * img - else: - return img - - -class Error(Exception): - pass - - -class TooManyObjectError(Exception): - pass diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py deleted file mode 100644 index b1b1dc1f0fd92af2d8d4789457788c46ec513370..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/products.py +++ /dev/null @@ -1,380 +0,0 @@ -from math import floor, sqrt -from pathlib import Path - -import numpy as np -import pandas as pd -from scipy.stats import moment -from skimage.filters import sobel -from skimage.io import imsave -from skimage.measure import find_contours, shannon_entropy -from tifffile import imwrite - -from model_server.extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack -from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch -from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from model_server.base.process import pad, rescale, resample_to_8bit - -def _make_rgb(zs): - h, w, c, nz = zs.shape - assert c <= 3 - outdata = np.zeros((h, w, 3, nz), dtype=zs.dtype) - outdata[:, :, 0:c, :] = zs[:, :, :, :] - return outdata - -def _focus_metrics(): - return { - 'max_intensity': lambda x: np.max(x), - 'stdev': lambda x: np.std(x), - 'max_sobel': lambda x: np.max(sobel(x)), - 'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)), - 'entropy': lambda x: shannon_entropy(x), - 'moment': lambda x: moment(x.flatten(), moment=2), - } - -def _write_patch_to_file(where, fname, yxcz): - ext = fname.split('.')[-1].upper() - where.mkdir(parents=True, exist_ok=True) - - if ext == 'PNG': - assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}' - assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs' - assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs' - if yxcz.shape[2] == 1: - outdata = yxcz[:, :, 0, 0] - elif yxcz.shape[2] == 2: # add a blank blue channel - outdata = _make_rgb(yxcz) - else: # preserve RGB order - outdata = yxcz[:, :, :, 0] - imsave(where / fname, outdata, check_contrast=False) - return True - - elif ext in ['TIF', 'TIFF']: - zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3]) - imwrite(where / fname, zcyx, imagej=True) - return True - - else: - raise Exception(f'Unsupported file extension: {ext}') - -def get_patch_masks_from_zmask_meta( - stack: GenericImageDataAccessor, - zmask_meta: list, - pad_to: int = 256, -) -> MonoPatchStack: - patches = [] - for mi in zmask_meta: - sl = mi['slice'] - - rbb = mi['relative_bounding_box'] - x0 = rbb['x0'] - y0 = rbb['y0'] - x1 = rbb['x1'] - y1 = rbb['y1'] - - sp_sl = np.s_[y0: y1, x0: x1, :, :] - - h, w = stack.data[sl].shape[0:2] - patch = np.zeros((h, w, 1, 1), dtype='uint8') - patch[sp_sl][:, :, 0, 0] = mi['mask'] * 255 - - if pad_to: - patch = pad(patch, pad_to) - - patches.append(patch) - return MonoPatchStack(patches) - -def export_patch_masks_from_zstack( - where: Path, - stack: GenericImageDataAccessor, - zmask_meta: list, - pad_to: int = 256, - prefix='mask', - **kwargs -): - patches_acc = get_patch_masks_from_zmask_meta( - stack, - zmask_meta, - pad_to=pad_to, - **kwargs - ) - assert len(zmask_meta) == patches_acc.count - - exported = [] - for i in range(0, len(zmask_meta)): - mi = zmask_meta[i] - obj = mi['info'] - patch = patches_acc.iat_yxcz(i) - ext = 'png' - fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' - _write_patch_to_file(where, fname, patch) - exported.append(fname) - return exported - -def get_patches_from_zmask_meta( - stack: GenericImageDataAccessor, - zmask_meta: list, - rescale_clip: float = 0.0, - pad_to: int = 256, - make_3d: bool = False, - focus_metric: str = None, - **kwargs -) -> MonoPatchStack: - patches = [] - - for mi in zmask_meta: - - sl = mi['slice'] - rbb = mi['relative_bounding_box'] - idx = mi['df_index'] - - x0 = rbb['x0'] - y0 = rbb['y0'] - x1 = rbb['x1'] - y1 = rbb['y1'] - - sp_sl = np.s_[y0: y1, x0: x1, :, :] - - patch3d = stack.data[sl] - ph, pw, pc, pz = patch3d.shape - subpatch = patch3d[sp_sl] - - # make a 3d patch - if make_3d: - patch = patch3d - - # make a 2d patch, find optimal z-position determined by focus_metric function - elif focus_metric is not None: - foc = _focus_metrics()[focus_metric] - - patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype) - - for ci in range(0, pc): - me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)] - zif = np.argmax(me) - patch[:, :, ci, 0] = patch3d[:, :, ci, zif] - - # make a 2d patch from middle of z-stack - else: - zim = floor(pz / 2) - patch = patch3d[:, :, :, [zim]] - - assert len(patch.shape) == 4 - assert patch.shape[2] == stack.chroma - - if rescale_clip is not None: - patch = rescale(patch, rescale_clip) - - if kwargs.get('draw_bounding_box') is True: - bci = kwargs.get('bounding_box_channel', 0) - assert bci < 3 - if bci > 0: - patch = _make_rgb(patch) - for zi in range(0, patch.shape[3]): - patch[:, :, bci, zi] = draw_box_on_patch( - patch[:, :, bci, zi], - ((x0, y0), (x1, y1)), - linewidth=kwargs.get('bounding_box_linewidth', 1) - ) - - if kwargs.get('draw_mask'): - mci = kwargs.get('mask_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[sp_sl[0:2]] = mi['mask'] - for zi in range(0, patch.shape[3]): - patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] - - if kwargs.get('draw_contour'): - mci = kwargs.get('contour_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[sp_sl[0:2]] = mi['mask'] - - for zi in range(0, patch.shape[3]): - patch[:, :, mci, zi] = draw_contours_on_patch( - patch[:, :, mci, zi], - find_contours(mask) - ) - - if pad_to: - patch = pad(patch, pad_to) - - patches.append(patch) - - if not make_3d and pc == 1: - return MonoPatchStack(patches) - else: - return Multichannel3dPatchStack(patches) - -def export_patches_from_zstack( - where: Path, - stack: GenericImageDataAccessor, - zmask_meta: list, - rescale_clip: float = 0.0, - pad_to: int = 256, - make_3d: bool = False, - prefix='patch', - focus_metric: str = None, - **kwargs -): - patches_acc = get_patches_from_zmask_meta( - stack, - zmask_meta, - rescale_clip=rescale_clip, - pad_to=pad_to, - make_3d=make_3d, - focus_metric=focus_metric, - **kwargs - ) - assert len(zmask_meta) == patches_acc.count - - exported = [] - for i in range(0, len(zmask_meta)): - mi = zmask_meta[i] - patch = patches_acc.iat_yxcz(i) - obj = mi['info'] - idx = mi['df_index'] - ext = 'tif' if make_3d else 'png' - fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' - - if patch.dtype is np.dtype('uint16'): - _write_patch_to_file(where, fname, resample_to_8bit(patch)) - else: - _write_patch_to_file(where, fname, patch) - - exported.append({ - 'df_index': idx, - 'patch_filename': fname, - }) - return exported - -def export_3d_patches_with_focus_metrics( - where: Path, - stack: GenericImageDataAccessor, - zmask_meta: list, - rescale_clip: float = 0.0, - pad_to: int = 256, - prefix='patch', - **kwargs -): - """ - Export 3D patches as multi-level z-stacks, along with CSV of various focus methods for each z-position - - :param kwargs: - annotate_focus_metric: name focus metric to use when drawing bounding box at optimal focus z-position - :return: - list of exported files - """ - assert stack.chroma == 1, 'Expecting monochromatic image data' - assert stack.nz > 1, 'Expecting z-stack' - - def get_zstack_focus_metrics(zs): - nz = zs.shape[3] - me = _focus_metrics() - dd = {} - for zi in range(0, nz): - spf = zs[:, :, :, zi] - dd[zi] = {k: me[k](spf) for k in me.keys()} - return dd - - exported = [] - patch_meta = [] - for mi in zmask_meta: - obj = mi['info'] - sl = mi['slice'] - rbb = mi['relative_bounding_box'] - idx = mi['df_index'] - - patch = stack.data[sl] - - assert len(patch.shape) == 4 - assert patch.shape[2] == stack.chroma - - if rescale_clip is not None: - patch = rescale(patch, rescale_clip) - - # unpack relative bounding box and define subset of patch data - x0 = rbb['x0'] - y0 = rbb['y0'] - x1 = rbb['x1'] - y1 = rbb['y1'] - sp_sl = np.s_[y0: y1, x0: x1, :, :] - subpatch = patch[sp_sl] - - # compute focus metrics for all z-levels - me_dict = get_zstack_focus_metrics(subpatch) - patch_meta.append({'label': obj.label, 'zi': obj.zi, 'metrics': me_dict}) - me_df = pd.DataFrame(me_dict).T - - # drawing bounding box only on focused slice - ak = kwargs.get('annotate_focus_metric') - if ak and ak in me_df.columns: - zi_foc = me_df.idxmax().to_dict()[ak] - patch[:, :, 0, zi_foc] = draw_box_on_patch( - patch[:, :, 0, zi_foc], - ((x0, y0), (x1, y1)), - ) - - if pad_to: - patch = pad(patch, pad_to) - - fstem = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}' - _write_patch_to_file(where, fstem + '.tif', resample_to_8bit(patch)) - me_df.to_csv(where / (fstem + '.csv')) - exported.append({ - 'df_index': idx, - 'patch_filename': fstem + '.tif', - 'focus_metrics_filename': fstem + '.csv', - }) - - return exported - -def export_multichannel_patches_from_zstack( - where: Path, - stack: GenericImageDataAccessor, - zmask_meta: list, - ch_rgb_overlay: tuple = None, - overlay_gain: tuple = (1.0, 1.0, 1.0), - ch_white: int = None, - **kwargs -): - """ - Export RGB patches where each patch is assignable to a channel of the input stack - :param ch_rgb_overlay: tuple of integers (R, G, B) that assign a stack channel index to an RGB channel - :param overlay_gain: optional, tuple of float (R, G, B) multipliers that can be used to balance relative brightness - :param ch_white: int, index of stack channel that becomes grayscale signal in export patches - """ - def _safe_add(a, g, b): - assert a.dtype == b.dtype - assert a.shape == b.shape - assert g >= 0.0 - - return np.clip( - a.astype('uint32') + g * b.astype('uint32'), - 0, - np.iinfo(a.dtype).max - ).astype(a.dtype) - - idata = stack.data - if ch_white: - assert ch_white < stack.chroma - mdata = idata[:, :, [ch_white, ch_white, ch_white], :] - else: - mdata = idata - - if ch_rgb_overlay: - assert len(ch_rgb_overlay) == 3 - assert len(overlay_gain) == 3 - for ii, ci in enumerate(ch_rgb_overlay): - if ci is None: - continue - assert isinstance(ci, int) - assert ci < stack.chroma - mdata[:, :, ii, :] = _safe_add( - mdata[:, :, ii, :], - overlay_gain[ii], - idata[:, :, ci, :] - ) - - mstack = InMemoryDataAccessor(mdata) - return export_patches_from_zstack( - where, mstack, zmask_meta, **kwargs - ) \ No newline at end of file diff --git a/model_server/extensions/chaeo/tests/test_accessors.py b/model_server/extensions/chaeo/tests/test_accessors.py deleted file mode 100644 index d7986879407e8114ec218be125e71033acf34f09..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/tests/test_accessors.py +++ /dev/null @@ -1,74 +0,0 @@ -import unittest - -import numpy as np - -from model_server.conf.testing import monozstackmask -from model_server.extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile, Multichannel3dPatchStack - - - -class TestCziImageFileAccess(unittest.TestCase): - def setUp(self) -> None: - pass - - def test_make_patch_stack_from_3d_array(self): - w = 256 - h = 512 - n = 4 - acc = MonoPatchStack(np.random.rand(h, w, n)) - self.assertEqual(acc.count, n) - self.assertEqual(acc.hw, (h, w)) - self.assertEqual(acc.make_tczyx().shape, (n, 1, 1, h, w)) - - def test_make_patch_stack_from_list(self): - w = 256 - h = 512 - n = 4 - acc = MonoPatchStack([np.random.rand(h, w) for _ in range(0, n)]) - self.assertEqual(acc.count, n) - self.assertEqual(acc.hw, (h, w)) - self.assertEqual(acc.make_tczyx().shape, (n, 1, 1, h, w)) - - def test_make_patch_stack_from_file(self): - h = monozstackmask['h'] - w = monozstackmask['w'] - c = monozstackmask['c'] - n = monozstackmask['z'] - acc = MonoPatchStackFromFile(monozstackmask['path']) - self.assertEqual(acc.hw, (h, w)) - self.assertEqual(acc.count, n) - self.assertEqual(acc.make_tczyx().shape, (n, c, 1, h, w)) - self.assertEqual(acc.fpath, monozstackmask['path']) - - def test_raises_filenotfound(self): - from model_server.extensions.chaeo.accessors import FileNotFoundError - with self.assertRaises(FileNotFoundError): - acc = MonoPatchStackFromFile('c:/fake/file/name.tif') - - def test_patch_as_yxcz_array(self): - w = 256 - h = 512 - n = 4 - acc = MonoPatchStack([np.random.rand(h, w) for _ in range(0, 4)]) - self.assertEqual(acc.iat_yxcz(0).shape, (h, w, 1, 1)) - - def test_make_3d_patch_stack_from_list(self): - w = 256 - h = 512 - c = 1 - nz = 5 - n = 4 - acc = Multichannel3dPatchStack([np.random.rand(h, w, c, nz) for _ in range(0, n)]) - self.assertEqual(acc.count, n) - self.assertEqual(acc.hw, (h, w)) - self.assertEqual(acc.chroma, c) - self.assertEqual(acc.iat(0).shape, (h, w, c, nz)) - - def test_3d_patch_as_yxcz_array(self): - w = 256 - h = 512 - nz = 5 - c = 1 - n = 4 - acc = Multichannel3dPatchStack([np.random.rand(h, w, c, nz) for _ in range(0, n)]) - self.assertEqual(acc.iat_yxcz(0).shape, (h, w, c, nz)) \ No newline at end of file diff --git a/model_server/extensions/chaeo/tests/test_process.py b/model_server/extensions/chaeo/tests/test_process.py deleted file mode 100644 index 79e6883c5f7099869b9b9901474c5c4cb833d158..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/tests/test_process.py +++ /dev/null @@ -1,32 +0,0 @@ -import unittest - -import numpy as np - -from model_server.extensions.chaeo.process import mask_largest_object - -class TestMaskLargestObject(unittest.TestCase): - def test_mask_largest_touching_object(self): - arr = np.zeros([5, 5], dtype='uint8') - arr[0:3, 0:3] = 2 - arr[3:, 2:] = 4 - masked = mask_largest_object(arr) - self.assertTrue(np.all(np.unique(masked) == [0, 2])) - self.assertTrue(np.all(masked[4:5, 0:2] == 0)) - self.assertTrue(np.all(masked[0:3, 3:5] == 0)) - - def test_no_change(self): - arr = np.zeros([5, 5], dtype='uint8') - arr[0:3, 0:3] = 2 - masked = mask_largest_object(arr) - self.assertTrue(np.all(masked == arr)) - - def test_mask_multiple_objects_in_binary_maks(self): - arr = np.zeros([5, 5], dtype='uint8') - arr[0:3, 0:3] = 255 - arr[4, 2:5] = 255 - masked = mask_largest_object(arr) - print(np.unique(masked)) - self.assertTrue(np.all(np.unique(masked) == [0, 255])) - self.assertTrue(np.all(masked[:, 3:5] == 0)) - self.assertTrue(np.all(masked[3:5, :] == 0)) - diff --git a/model_server/extensions/chaeo/tests/test_roiset_workflow.py b/model_server/extensions/chaeo/tests/test_roiset_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..53c830b830d3519d9a17ff798a84c691f6724ea0 --- /dev/null +++ b/model_server/extensions/chaeo/tests/test_roiset_workflow.py @@ -0,0 +1,67 @@ +import unittest + +from model_server.base.models import DummyInstanceSegmentationModel +from model_server.base.roiset import RoiSetMetaParams, RoiSetExportParams +from model_server.conf.testing import output_path, roiset_test_data +from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack +from model_server.extensions.ilastik.models import IlastikPixelClassifierModel + +from tests.test_roiset import BaseTestRoiSetMonoProducts + +class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): + + def test_object_map_workflow(self): + pp = roiset_test_data['pipeline_params'] + + models = { + 'pixel_classifier': { + 'model': IlastikPixelClassifierModel(params={'project_file': roiset_test_data['pixel_classifier']}), + 'params': { + 'px_class': 0, + 'px_prob_threshold': 0.6, + } + }, + 'object_classifier': { + 'name': 'dummy', + 'model': DummyInstanceSegmentationModel(), + } + } + + roi_params = RoiSetMetaParams(**{ + 'mask_type': 'boxes', + 'filters': { + 'area': {'min': 1e3, 'max': 1e8} + }, + 'expand_box_by': [128, 2] + }) + + export_params = RoiSetExportParams(**{ + 'pixel_probabilities': True, + 'patches_3d': {}, + 'annotated_patches_2d': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': [3, None, None], + 'rgb_overlay_weights': [0.2, 1.0, 1.0], + 'pad_to': 512, + }, + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + }, + 'patch_masks': { + 'pad_to': 256, + }, + 'annotated_zstacks': {}, + 'object_classes': True, + 'dataframe': True, + }) + + infer_object_map_from_zstack( + roiset_test_data['multichannel_zstack']['path'], + output_path / 'roiset' / 'workflow', + models, + segmentation_channel=pp['segmentation_channel'], + patches_channel=pp['patches_channel'], + export_params=export_params, + roi_params=roi_params, + ) \ No newline at end of file diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py deleted file mode 100644 index 5a8d62ab0da97eaccb7d05083c78b3df9d9a61d1..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ /dev/null @@ -1,191 +0,0 @@ -import unittest - -import numpy as np - -from model_server.conf.testing import output_path - -from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params -from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack -from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask -from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file -from model_server.extensions.ilastik.models import IlastikPixelClassifierModel - -class TestZStackDerivedDataProducts(unittest.TestCase): - - def setUp(self) -> None: - # need test data incl obj map - self.stack = generate_file_accessor(multichannel_zstack['path']) - - pxmodel = IlastikPixelClassifierModel( - {'project_file': pixel_classifier['path']}, - ) - mip = InMemoryDataAccessor(self.stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True)) - self.pxmap, result = pxmodel.infer(mip) - - # write_accessor_data_to_file(output_path / 'pxmap.tif', self.pxmap) - self.obmap = InMemoryDataAccessor(self.pxmap.data > pipeline_params['threshold']) - # write_accessor_data_to_file(output_path / 'obmap.tif', self.obmap) - - def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs): - zmask, meta, df, interm = build_zmask_from_object_mask( - self.obmap.get_one_channel_data(0), - self.stack.get_one_channel_data(0), - mask_type=mask_type, - **kwargs, - ) - zmask_acc = InMemoryDataAccessor(zmask) - self.assertTrue(zmask_acc.is_mask()) - - # assert dimensionality of zmask - self.assertGreater(zmask_acc.shape_dict['Z'], 1) - self.assertEqual(zmask_acc.shape_dict['C'], 1) - write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc) - - # mask values are not just all True or all False - self.assertTrue(np.any(zmask)) - self.assertFalse(np.all(zmask)) - - # assert non-trivial meta info in boxes - self.assertGreater(len(meta), 1) - sh = meta[1]['mask'].shape - ar = meta[1]['info'].area - self.assertGreaterEqual(sh[0] * sh[1], ar) - - # assert dimensionality of intermediate data products - self.assertEqual(interm['label_map'].shape, zmask.shape[0:2]) - self.assertEqual(interm['argmax'].shape, zmask.shape[0:2]) - - return zmask, meta - - def test_zmask_works_on_non_zstacks(self, **kwargs): - acc_zstack_slice = InMemoryDataAccessor(self.stack.data[:, :, 0, 0]) - self.assertEqual(acc_zstack_slice.nz, 1) - zmask, meta, df, interm = build_zmask_from_object_mask( - self.obmap.get_one_channel_data(0), - acc_zstack_slice, - mask_type='boxes', - **kwargs, - ) - zmask_acc = InMemoryDataAccessor(zmask) - self.assertTrue(zmask_acc.is_mask()) - - def test_zmask_makes_correct_contours(self): - return self.test_zmask_makes_correct_boxes(mask_type='contours') - - def test_zmask_makes_correct_boxes_with_filters(self): - return self.test_zmask_makes_correct_boxes(filters={'area': (1e3, 1e4)}) - - def test_zmask_makes_correct_expanded_boxes(self): - return self.test_zmask_makes_correct_boxes(expand_box_by=(64, 2)) - - def test_make_2d_patches_from_zmask(self): - zmask, meta = self.test_zmask_makes_correct_boxes( - filters={'area': (1e3, 1e4)}, - expand_box_by=(64, 2) - ) - files = export_patches_from_zstack( - output_path / '2d_patches', - self.stack.get_one_channel_data(channel=1), - meta, - draw_bounding_box=True, - ) - self.assertGreaterEqual(len(files), 1) - - def test_make_3d_patches_from_zmask(self): - zmask, meta = self.test_zmask_makes_correct_boxes( - filters={'area': (1e3, 1e4)}, - expand_box_by=(64, 2), - ) - files = export_patches_from_zstack( - output_path / '3d_patches', - self.stack.get_one_channel_data(4), - meta, - make_3d=True) - self.assertGreaterEqual(len(files), 1) - - def test_flatten_image(self): - zmask, meta, df, interm = build_zmask_from_object_mask( - self.obmap.get_one_channel_data(0), - self.stack.get_one_channel_data(4), - mask_type='boxes', - ) - - from model_server.extensions.chaeo.zmask import project_stack_from_focal_points - - dff = df[df['keeper']] - - img = project_stack_from_focal_points( - dff['centroid-0'].to_numpy(), - dff['centroid-1'].to_numpy(), - dff['zi'].to_numpy(), - self.stack, - degree=4, - ) - - self.assertEqual(img.shape[0:2], self.stack.shape[0:2]) - - write_accessor_data_to_file( - output_path / 'flattened.tif', - InMemoryDataAccessor(img) - ) - - def test_make_multichannel_2d_patches_from_zmask(self): - zmask, meta = self.test_zmask_makes_correct_boxes( - filters={'area': (1e3, 1e4)}, - expand_box_by=(128, 2) - ) - files = export_multichannel_patches_from_zstack( - output_path / '2d_patches_chlorophyl_bbox_overlay', - InMemoryDataAccessor(self.stack.data), - meta, - ch_white=4, - draw_bounding_box=True, - bounding_box_channel=1, - ) - self.assertGreaterEqual(len(files), 1) - - def test_make_multichannel_2d_patches_with_mask_overlay(self): - zmask, meta = self.test_zmask_makes_correct_boxes( - filters={'area': (1e3, 1e4)}, - expand_box_by=(128, 2) - ) - files = export_multichannel_patches_from_zstack( - output_path / '2d_patches_chlorophyl_mask_overlay', - InMemoryDataAccessor(self.stack.data), - meta, - ch_white=4, - ch_rgb_overlay=(3, None, None), - draw_mask=True, - mask_channel=0, - overlay_gain=(0.1, 1.0, 1.0) - ) - self.assertGreaterEqual(len(files), 1) - - def test_make_multichannel_2d_patches_with_contour_overlay(self): - zmask, meta = self.test_zmask_makes_correct_boxes( - filters={'area': (1e3, 1e4)}, - expand_box_by=(128, 2) - ) - files = export_multichannel_patches_from_zstack( - output_path / '2d_patches_chlorophyl_contour_overlay', - InMemoryDataAccessor(self.stack.data), - meta, - ch_white=4, - ch_rgb_overlay=(3, None, None), - draw_contour=True, - contour_channel=1, - overlay_gain=(0.1, 1.0, 1.0) - ) - self.assertGreaterEqual(len(files), 1) - - def test_make_binary_masks_from_zmask(self): - zmask, meta = self.test_zmask_makes_correct_boxes( - filters={'area': (1e3, 1e4)}, - expand_box_by=(128, 2) - ) - files = export_patch_masks_from_zstack( - output_path / '2d_mask_patches', - InMemoryDataAccessor(self.stack.data), - meta, - ) - self.assertGreaterEqual(len(files), 1) \ No newline at end of file diff --git a/model_server/extensions/chaeo/workflows.py b/model_server/extensions/chaeo/workflows.py index c7995acd2e56d13fe62cf4e0d9b90c2312456b81..c651f95ecdf68f4b8f3aa431e6d4e512a2467aa4 100644 --- a/model_server/extensions/chaeo/workflows.py +++ b/model_server/extensions/chaeo/workflows.py @@ -1,312 +1,71 @@ from pathlib import Path from typing import Dict, List -from uuid import uuid4 import numpy as np import pandas as pd + from skimage.measure import label from skimage.morphology import dilation from sklearn.model_selection import train_test_split -from model_server.extensions.chaeo.annotators import draw_boxes_on_3d_image -from model_server.extensions.chaeo.models import PatchStackObjectClassifier -from model_server.extensions.chaeo.process import mask_largest_object -from model_server.extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta -from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask, project_stack_from_focal_points -from model_server.extensions.ilastik.models import IlastikPixelClassifierModel +from base.roiset import RoiSetMetaParams, RoiSetExportParams +from base.process import mask_largest_object +from base.roiset import _get_label_ids, RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file -from model_server import Model -from model_server.base.process import rescale +from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel from model_server.base.workflows import Timer -def get_zmask_meta( - input_file_path: str, - ilastik_pixel_classifier: IlastikPixelClassifierModel, - segmentation_channel: int, - pxmap_threshold: float, - pxmap_foreground_channel: int = 0, - zmask_zindex: int = None, - zmask_clip: int = None, - zmask_filters: Dict = None, - zmask_type: str = 'boxes', - pxmap_use_all_channels: bool = False, - **kwargs, -) -> tuple: - ti = Timer() - stack = generate_file_accessor(Path(input_file_path)) - fstem = Path(input_file_path).stem - ti.click('file_input') - - # MIP if no zmask z-index is given, then classify pixels - if isinstance(zmask_zindex, int): - assert 0 < zmask_zindex < stack.nz - if pxmap_use_all_channels: - zmask_data = stack.data[:, :, :, zmask_zindex] - else: - zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex] - else: - if pxmap_use_all_channels: - zmask_data = stack.data.max(axis=-1, keepdims=True) - else: - zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True) - if zmask_clip: - zmask_data = rescale(zmask_data, zmask_clip) - mip = InMemoryDataAccessor( - zmask_data, - ) - pxmap, _ = ilastik_pixel_classifier.infer(mip) - ti.click('infer_pixel_probability') - - obmask = InMemoryDataAccessor( - pxmap.data > pxmap_threshold - ) - ti.click('threshold_pixel_mask') - - # make zmask - zmask, zmask_meta, df, interm = build_zmask_from_object_mask( - obmask.get_one_channel_data(pxmap_foreground_channel), - stack.get_one_channel_data(segmentation_channel), - mask_type=zmask_type, - filters=zmask_filters, - expand_box_by=kwargs['zmask_expand_box_by'], - ) - ti.click('generate_zmasks') - - # record pixel scale - px_scale = stack.pixel_scale_in_micrometers.get('X') - df['pixel_scale_in_micrometers'] = float(px_scale) if px_scale is not None else None - - return ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm - - -# TODO: unpack and validate inputs -def export_patches_from_multichannel_zstack( - input_file_path: str, - output_folder_path: str, - models: List[Model], - pxmap_threshold: float, - pxmap_foreground_channel: int, - segmentation_channel: int, # -1 to use all channels - patches_channel: int, - zmask_zindex: int = None, # None for MIP, - zmask_clip: int = None, - zmask_type: str = 'boxes', - zmask_filters: Dict = None, - zmask_expand_box_by: int = None, - export_pixel_probabilities=True, - export_2d_patches_for_training=True, - export_2d_patches_for_annotation=True, - draw_bounding_box_on_2d_patch=True, - draw_contour_on_2d_patch=False, - draw_mask_on_2d_patch=False, - export_3d_patches=True, - export_annotated_zstack=True, - draw_label_on_zstack=False, - export_patch_masks=True, - rgb_overlay_channels=(None, None, None), - rgb_overlay_weights=(1.0, 1.0, 1.0), - **kwargs, -) -> Dict: - pixel_classifier = models[0] - - ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta( - input_file_path, - pixel_classifier, - segmentation_channel, - pxmap_threshold, - pxmap_foreground_channel=pxmap_foreground_channel, - zmask_zindex=zmask_zindex, - zmask_clip=zmask_clip, - zmask_expand_box_by=zmask_expand_box_by, - zmask_filters=zmask_filters, - zmask_type=zmask_type, - **kwargs, - ) - - if export_pixel_probabilities: - write_accessor_data_to_file( - Path(output_folder_path) / 'pixel_probabilities' / (fstem + '.tif'), - pxmap - ) - ti.click('export_pixel_probability') - - if export_3d_patches and len(zmask_meta) > 0: - files = export_patches_from_zstack( - Path(output_folder_path) / '3d_patches', - stack.get_one_channel_data(patches_channel), - zmask_meta, - prefix=fstem, - draw_bounding_box=False, - rescale_clip=0.001, - make_3d=True, - ) - ti.click('export_3d_patches') - - if export_2d_patches_for_annotation and len(zmask_meta) > 0: - files = export_multichannel_patches_from_zstack( - Path(output_folder_path) / '2d_patches_annotation', - stack, - zmask_meta, - prefix=fstem, - rescale_clip=0.001, - make_3d=False, - focus_metric='max_sobel', - ch_white=patches_channel, - ch_rgb_overlay=rgb_overlay_channels, - draw_bounding_box=draw_bounding_box_on_2d_patch, - bounding_box_channel=1, - bounding_box_linewidth=2, - draw_contour=draw_contour_on_2d_patch, - draw_mask=draw_mask_on_2d_patch, - overlay_gain=rgb_overlay_weights, - ) - df_patches = pd.DataFrame(files) - ti.click('export_2d_patches') - # associate 2d patches, dropping labeled objects that were not exported as patches - df = pd.merge(df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') - # prepopulate patch UUID - df['patch_id'] = df.apply(lambda _: uuid4(), axis=1) - - if export_2d_patches_for_training and len(zmask_meta) > 0: - files = export_multichannel_patches_from_zstack( - Path(output_folder_path) / '2d_patches_training', - stack.get_one_channel_data(patches_channel), - zmask_meta, - prefix=fstem, - rescale_clip=0.001, - make_3d=False, - focus_metric='max_sobel', - ) - ti.click('export_2d_patches') - - if export_patch_masks and len(zmask_meta) > 0: - files = export_patch_masks_from_zstack( - Path(output_folder_path) / 'patch_masks', - stack.get_one_channel_data(patches_channel), - zmask_meta, - prefix=fstem, - ) - - if export_annotated_zstack: - annotated = InMemoryDataAccessor( - draw_boxes_on_3d_image( - stack.get_one_channel_data(patches_channel).data, - zmask_meta, - add_label=draw_label_on_zstack, - ) - ) - write_accessor_data_to_file( - Path(output_folder_path) / 'annotated_zstacks' / (fstem + '.tif'), - annotated - ) - ti.click('export_annotated_zstack') - - # generate multichannel projection from label centroids - dff = df[df['keeper']] - if len(zmask_meta) > 0: - interm['projected'] = project_stack_from_focal_points( - dff['centroid-0'].to_numpy(), - dff['centroid-1'].to_numpy(), - dff['zi'].to_numpy(), - stack, - degree=4, - ) - else: # else just return MIP - interm['projected'] = stack.data.max(axis=-1) - - return { - 'pixel_model_id': pixel_classifier.model_id, - 'input_filepath': input_file_path, - 'number_of_objects': len(zmask_meta), - 'pixeL_scale_in_micrometers': stack.pixel_scale_in_micrometers, - 'success': True, - 'timer_results': ti.events, - 'dataframe': df[df['keeper'] == True], - 'interm': interm, - } def infer_object_map_from_zstack( input_file_path: str, output_folder_path: str, models: List[Model], - pxmap_threshold: float, - pxmap_foreground_channel: int, segmentation_channel: int, patches_channel: int, zmask_zindex: int = None, # None for MIP, - zmask_clip: int = None, - zmask_type: str = 'boxes', - zmask_filters: Dict = None, - # zmask_expand_box_by: int = None, - **kwargs, + roi_params: RoiSetMetaParams = RoiSetMetaParams(), + export_params: RoiSetExportParams = RoiSetExportParams(), ) -> Dict: assert len(models) == 2 - pixel_classifier = models[0] - assert isinstance(pixel_classifier, IlastikPixelClassifierModel) - object_classifier = models[1] - assert isinstance(object_classifier, PatchStackObjectClassifier) - - ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta( - input_file_path, - pixel_classifier, - segmentation_channel, - pxmap_threshold, - pxmap_foreground_channel=pxmap_foreground_channel, - zmask_zindex=zmask_zindex, - zmask_clip=zmask_clip, - # zmask_expand_box_by=zmask_expand_box_by, - zmask_filters=zmask_filters, - zmask_type=zmask_type, - **kwargs - ) + assert isinstance(models['pixel_classifier']['model'], SemanticSegmentationModel) + assert isinstance(models['object_classifier']['model'], InstanceSegmentationModel) - # extract patches to accessor - patches_acc = get_patches_from_zmask_meta( - stack.get_one_channel_data(patches_channel), - zmask_meta, - rescale_clip=zmask_clip, - make_3d=False, - focus_metric='max_sobel', - **kwargs - ) - - # extract masks - patch_masks_acc = get_patch_masks_from_zmask_meta( - stack, - zmask_meta, - **kwargs - ) + ti = Timer() + stack = generate_file_accessor(Path(input_file_path)) + fstem = Path(input_file_path).stem + ti.click('file_input') - # send patches and mask stacks to object classifier - result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc) + # MIP if no zmask z-index is given, then classify pixels + if isinstance(zmask_zindex, int): + assert 0 < zmask_zindex < stack.nz + zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex] + else: + zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True) + mip = InMemoryDataAccessor(zmask_data) - labels_map = interm['label_map'] - output_map = np.zeros(labels_map.shape, dtype=labels_map.dtype) - assert labels_map.shape == interm['label_map'].shape - assert labels_map.dtype == interm['label_map'].dtype + mip_mask = models['pixel_classifier']['model'].label_pixel_class(mip, **models['pixel_classifier']['params']) + ti.click('classify_pixels') - # assign labels to object map: - meta = [] - for ii in range(0, len(zmask_meta)): - object_id = zmask_meta[ii]['info'].label - result_patch = mask_largest_object(result_acc.iat(ii)) - object_class = np.unique(result_patch)[1] - output_map[labels_map == object_id] = object_class - meta.append({'object_id': ii, 'object_class': object_id}) + # make zmask + rois = RoiSet(stack, _get_label_ids(mip_mask), params=roi_params) + ti.click('generate_zmasks') - output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif')) - write_accessor_data_to_file( - output_path, - InMemoryDataAccessor(output_map) + rois.classify_by( + models['object_classifier']['name'], + patches_channel, + models['object_classifier']['model'] ) - ti.click('export_object_classes') + ti.click('classify_objects') + + rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params) + ti.click('export_roi_products') return { 'timer_results': ti.events, - 'dataframe': pd.DataFrame(meta), + 'dataframe': rois.get_df(), 'interm': {}, - 'output_path': output_path.__str__(), + 'output_path': output_folder_path, } @@ -414,6 +173,4 @@ def transfer_ecotaxa_labels_to_patch_stacks( for k in zstacks.keys(): write_accessor_data_to_file(Path(where_output) / f'zstack_{k}.tif', InMemoryDataAccessor(zstacks[k])) - pd.DataFrame(stack_meta).to_csv(Path(where_output) / f'{dfk}_stack.csv', index=False) - - + pd.DataFrame(stack_meta).to_csv(Path(where_output) / f'{dfk}_stack.csv', index=False) \ No newline at end of file diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py deleted file mode 100644 index 1f74b5632d96eeb80f4def92f3b81d111b7060df..0000000000000000000000000000000000000000 --- a/model_server/extensions/chaeo/zmask.py +++ /dev/null @@ -1,198 +0,0 @@ -import numpy as np -import pandas as pd - -from skimage.measure import find_contours, label, regionprops_table -from sklearn.preprocessing import PolynomialFeatures -from sklearn.linear_model import LinearRegression - -from model_server.base.accessors import GenericImageDataAccessor - -def build_zmask_from_object_mask( - obmask: GenericImageDataAccessor, - zstack: GenericImageDataAccessor, - filters=None, - mask_type='contours', - expand_box_by=(0, 0), -): - """ - Given a 2D mask of objects, build a 3D mask, where each object's z-position is determined by the index of - maximum intensity in z. Return this zmask and a list of each object's meta information. - :param obmask: GenericImageDataAccessor monochrome 2D inary mask of objects - :param zstack: GenericImageDataAccessor monochrome zstack of same Y, X dimension as obmask - :param filters: dictionary of form {attribute: (min, max)}; valid attributes are 'area' and 'solidity' - :param mask_type: if 'boxes', zmask is True in each object's complete bounding box; otherwise 'contours' - :param expand_box_by: (xy, z) expands bounding box by (xy, z) pixels except where this hits a boundary - :return: tuple (zmask, meta) - np.ndarray: - boolean mask of same size as stack - List containing one Dict per object, with keys: - info: object's properties from skimage.measure.regionprops_table, including bounding box (y0, y1, x0, x1) - slice: named slice (np.s_) of (optionally) expanded bounding box - relative_bounding_box: bounding box (y0, y1, x0, x1) in relative frame of (optionally) expanded bounding box - contour: object's contour returned by skimage.measure.find_contours - mask: mask of object in relative frame of (optionally) expanded bounding box - pd.DataFrame: objects, including bounding, box information after filtering - Dict of intermediate image products: - label_map: np.ndarray (h x w) where each unique object has an integer label - argmax: np.ndarray (h x w x 1 x 1) z-index of highest intensity in zstack - """ - - # validate inputs - assert zstack.chroma == 1 - assert mask_type in ('contours', 'boxes'), mask_type - assert obmask.is_mask() - assert obmask.chroma == 1 - assert obmask.nz == 1 - assert zstack.hw == obmask.hw - - # assign object labels and build object query - lamap = label(obmask.data[:, :, 0, 0]).astype('uint16') - query_str = 'label > 0' # always true - if filters is not None: - for k in filters.keys(): - assert k in ('area', 'solidity') - vmin, vmax = filters[k] - assert vmin >= 0 - query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' - - # build dataframe of objects, assign z index to each object - argmax = zstack.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') - df = ( - pd.DataFrame( - regionprops_table( - lamap, - intensity_image=argmax, - properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid') - ) - ) - .rename( - columns={ - 'bbox-0': 'y0', - 'bbox-1': 'x0', - 'bbox-2': 'y1', - 'bbox-3': 'x1', - } - ) - ) - df['zi'] = df['intensity_mean'].round().astype('int') - df['keeper'] = False - df.loc[df.query(query_str).index, 'keeper'] = True - - # make an object map where label is replaced by focus position in stack and background is -1 - lut = np.zeros(lamap.max() + 1) - 1 - lut[df.label] = df.zi - - # convert bounding boxes to numpy slice objects - ebxy, ebz = expand_box_by - h, w, c, nz = zstack.shape - - meta = [] - for ob in df[df['keeper']].itertuples(name='LabeledObject'): - y0 = max(ob.y0 - ebxy, 0) - y1 = min(ob.y1 + ebxy, h) - x0 = max(ob.x0 - ebxy, 0) - x1 = min(ob.x1 + ebxy, w) - z0 = max(ob.zi - ebz, 0) - z1 = min(ob.zi + ebz, nz) - - # relative bounding box positions - rbb = { - 'y0': ob.y0 - y0, - 'y1': ob.y1 - y0, - 'x0': ob.x0 - x0, - 'x1': ob.x1 - x0, - } - - sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1] - - # compute contours - obmask = (lamap == ob.label) - contour = find_contours(obmask) - mask = obmask[ob.y0: ob.y1, ob.x0: ob.x1] - - assert rbb['x1'] <= (x1 - x0) - assert rbb['y1'] <= (y1 - y0) - - meta.append({ - 'df_index': ob.Index, - 'info': ob, - 'slice': sl, - 'relative_bounding_box': rbb, - 'contour': contour, - 'mask': mask - }) - - # build mask z-stack - zi_st = np.zeros(zstack.shape, dtype='bool') - if mask_type == 'contours': - zi_map = (lut[lamap] + 1.0).astype('int') - idxs = np.array(zi_map) - 1 - np.put_along_axis( - zi_st, - np.expand_dims(idxs, (2, 3)), - 1, - axis=3 - ) - - # change background level from to 0 in final frame - zi_st[:, :, :, -1][lamap == 0] = 0 - - elif mask_type == 'boxes': - for bb in meta: - sl = bb['slice'] - zi_st[sl] = 1 - - # return intermediate image arrays - interm = { - 'label_map': lamap, - 'argmax': argmax, - } - - return zi_st, meta, df, interm - - -def project_stack_from_focal_points( - xx: np.ndarray, - yy: np.ndarray, - zz: np.ndarray, - stack: GenericImageDataAccessor, - degree: int = 2, -) -> np.ndarray: - """ - Given a set of 3D points, project a multichannel z-stack based on a surface fit of the provided points - :param xx: vector of point x-coordinates - :param yy: vector of point y-coordinates - :param zz: vector of point z-coordinates - :param stack: z-stack to project - :param degree: order of polynomial to fit - :return: multichannel 2d projected image array - """ - assert xx.shape == yy.shape - assert xx.shape == zz.shape - - poly = PolynomialFeatures(degree=degree) - X = np.stack([xx, yy]).T - features = poly.fit_transform(X, zz) - model = LinearRegression(fit_intercept=False) - model.fit(features, zz) - - xy_indices = np.indices(stack.hw).reshape(2, -1).T - xy_features = np.dot( - poly.fit_transform(xy_indices, zz), - model.coef_ - ) - zi_image = xy_features.reshape( - stack.hw - ).round().clip( - 0, (stack.nz - 1) - ).astype('uint16') - - return np.take_along_axis( - stack.data, - np.repeat( - np.expand_dims(zi_image, (2, 3)), - stack.chroma, - axis=2 - ), - axis=3 - ) diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 6b55c4cdd14e11335b479ea5f8c6551c0c7c92e1..91493a2bed5fc0f676206a3cf1b75a3ff4f7c00b 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -5,11 +5,12 @@ import numpy as np import vigra import model_server.extensions.ilastik.conf +from base.accessors import PatchStack from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from model_server.base.models import ImageToImageModel, ParameterExpectedError +from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel -class IlastikImageToImageModel(ImageToImageModel): +class IlastikModel(Model): def __init__(self, params, autoload=True): self.project_file = Path(params['project_file']) @@ -27,7 +28,6 @@ class IlastikImageToImageModel(ImageToImageModel): self.shell = None super().__init__(autoload, params) - def load(self): from ilastik import app from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo @@ -51,8 +51,9 @@ class IlastikImageToImageModel(ImageToImageModel): return True -class IlastikPixelClassifierModel(IlastikImageToImageModel): +class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' + operations = ['segment', ] @staticmethod def get_workflow(): @@ -77,28 +78,42 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel): ) return InMemoryDataAccessor(data=yxcz), {'success': True} -class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel): - model_id = 'ilastik_object_classification_from_pixel_predictions' + def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs): + pxmap, _ = self.infer(img) + mask = pxmap.data[:, :, px_class, :] > px_prob_threshold + return InMemoryDataAccessor(mask) + + +class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): + model_id = 'ilastik_object_classification_from_segmentation' @staticmethod def get_workflow(): - from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction - return ObjectClassificationWorkflowPrediction + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary + return ObjectClassificationWorkflowBinary - def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): + def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') + assert segmentation_img.is_mask() + if segmentation_img.dtype == 'bool': + seg = 255 * segmentation_img.data.astype('uint8') + tagged_seg_data = vigra.taggedView( + 255 * segmentation_img.data.astype('uint8'), + 'yxcz' + ) + else: + tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') dsi = [ { 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), + 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), } ] obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert (len(obmaps) == 1, 'ilastik generated more than one object map') + assert len(obmaps) == 1, 'ilastik generated more than one object map' yxcz = np.moveaxis( obmaps[0], @@ -107,30 +122,34 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel) ) return InMemoryDataAccessor(data=yxcz), {'success': True} + def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): + super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) + obmap, _ = self.infer(img, mask) + return obmap -class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel): - model_id = 'ilastik_object_classification_from_segmentation' + +class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel): + model_id = 'ilastik_object_classification_from_pixel_predictions' @staticmethod def get_workflow(): - from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary - return ObjectClassificationWorkflowBinary + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction + return ObjectClassificationWorkflowPrediction - def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - assert segmentation_img.is_mask() + def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') dsi = [ { 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), + 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), } ] obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert (len(obmaps) == 1, 'ilastik generated more than one object map') + assert len(obmaps) == 1, 'ilastik generated more than one object map' yxcz = np.moveaxis( obmaps[0], @@ -138,3 +157,63 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel): [0, 1, 2, 3] ) return InMemoryDataAccessor(data=yxcz), {'success': True} + + + def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs): + """ + Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is + assigned a class. + :param img: input image + :param pxmap: map of pixel probabilities + :param kwargs: + pixel_classification_channel: channel of pxmap used to segment objects + pixel_classification_thresold: threshold of pxmap used to segment objects + :return: + """ + if not img.shape == pxmap.shape: + raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') + # TODO: check that pxmap is in-range + pxch = kwargs.get('pixel_classification_channel', 0) + pxtr = kwargs('pixel_classification_threshold', 0.5) + mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) + # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) + obmap, _ = self.infer(img, mask) + return obmap + + +class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): + """ + Wrap ilastik object classification for inputs comprising single-object series of raw images and binary + segmentation masks. + """ + + def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): + assert segmentation_acc.is_mask() + if not input_acc.chroma == 1: + raise InvalidInputImageError('Object classifier expects only monochrome patches') + if not input_acc.nz == 1: + raise InvalidInputImageError('Object classifier expects only 2d patches') + + tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') + tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') + + dsi = [ + { + 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), + 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), + } + ] + + obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] + + assert len(obmaps) == 1, 'ilastik generated more than one object map' + + # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C + assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) + + return PatchStack(data=pyxcz), {'success': True} \ No newline at end of file diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 5b38aa12ef68e6054e357ed4ca7ba37b0316f035..b7bee14e3d669822bfb5d0138b3c800df9c4869e 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -14,7 +14,7 @@ router = APIRouter( session = Session() -def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: +def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplicate=True) -> dict: """ Load an ilastik model of a given class and project filename. :param model_class: @@ -35,7 +35,7 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: ) return {'model_id': result} -@router.put('/px/load/') +@router.put('/seg/load/') def load_px_model(project_file: str, duplicate: bool = True) -> dict: return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate) diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index d9e8c429a5f65805ca8bb2edacee035db5ff4b43..89fe2b98d1f8ea25702e8a221c6178e8fe3a645b 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -4,10 +4,11 @@ import unittest import numpy as np -from model_server.conf.testing import czifile, ilastik_classifiers, output_path -from model_server.base.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data +from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm -from model_server.base.workflows import infer_image_to_image +from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams +from model_server.base.workflows import classify_pixels from tests.test_api import TestServerBaseClass class TestIlastikPixelClassification(unittest.TestCase): @@ -35,7 +36,7 @@ class TestIlastikPixelClassification(unittest.TestCase): input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1)) with self.assertRaises(AttributeError): - pxmap, _ = model.infer(input_img) + mask = model.label_pixel_class(input_img) def test_run_pixel_classifier_on_random_data(self): @@ -47,8 +48,8 @@ class TestIlastikPixelClassification(unittest.TestCase): input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1)) - pxmap, _ = model.infer(input_img) - self.assertEqual(pxmap.shape, (h, w, 2, 1)) + mask = model.label_pixel_class(input_img) + self.assertEqual(mask.shape, (h, w, 1, 1)) def test_run_pixel_classifier(self): @@ -66,28 +67,29 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(mono_image.shape_dict['C'], 1) self.assertEqual(mono_image.shape_dict['Z'], 1) - pxmap, _ = model.infer(mono_image) + mask = model.label_pixel_class(mono_image) - self.assertEqual(pxmap.shape[0:2], cf.shape[0:2]) - self.assertEqual(pxmap.shape_dict['C'], 2) - self.assertEqual(pxmap.shape_dict['Z'], 1) + self.assertTrue(mask.is_mask()) + self.assertEqual(mask.shape[0:2], cf.shape[0:2]) + self.assertEqual(mask.shape_dict['C'], 1) + self.assertEqual(mask.shape_dict['Z'], 1) self.assertTrue( write_accessor_data_to_file( output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', - pxmap + mask ) ) self.mono_image = mono_image - self.pxmap = pxmap + self.mask = mask - def test_run_object_classifier(self): + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( {'project_file': ilastik_classifiers['pxmap_to_obj']} ) - objmap, _ = model.infer(self.mono_image, self.pxmap) + objmap, _ = model.infer(self.mono_image, self.mask) self.assertTrue( write_accessor_data_to_file( @@ -97,8 +99,24 @@ class TestIlastikPixelClassification(unittest.TestCase): ) self.assertEqual(objmap.data.max(), 3) + def test_run_object_classifier_from_segmentation(self): + self.test_run_pixel_classifier() + fp = czifile['path'] + model = ilm.IlastikObjectClassifierFromSegmentationModel( + {'project_file': ilastik_classifiers['seg_to_obj']} + ) + objmap = model.label_instance_class(self.mono_image, self.mask) + + self.assertTrue( + write_accessor_data_to_file( + output_path / f'obmap_from_seg_{fp.stem}.tif', + objmap, + ) + ) + self.assertEqual(objmap.data.max(), 3) + def test_ilastik_pixel_classification_as_workflow(self): - result = infer_image_to_image( + result = classify_pixels( czifile['path'], ilm.IlastikPixelClassifierModel( {'project_file': ilastik_classifiers['px']} @@ -113,7 +131,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={'project_file': 'improper.ilp'}, ) self.assertEqual(resp_load.status_code, 404) @@ -121,7 +139,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_load_ilastik_pixel_model(self): resp_load = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={'project_file': str(ilastik_classifiers['px'])}, ) self.assertEqual(resp_load.status_code, 200, resp_load.json()) @@ -137,7 +155,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_1st = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_1st), 1, resp_list_1st) resp_load_2nd = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': str(ilastik_classifiers['px']), 'duplicate': True, @@ -146,7 +164,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_2nd = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) resp_load_3rd = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': str(ilastik_classifiers['px']), 'duplicate': False, @@ -172,14 +190,14 @@ class TestIlastikOverApi(TestServerBaseClass): # load models with these paths resp1 = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': ilp_win, 'duplicate': False, }, ) resp2 = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': ilp_posx, 'duplicate': False, @@ -226,7 +244,7 @@ class TestIlastikOverApi(TestServerBaseClass): model_id = self.test_load_ilastik_pixel_model() resp_infer = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': czifile['filename'], @@ -251,4 +269,32 @@ class TestIlastikOverApi(TestServerBaseClass): ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) - # TODO: test IlastikObjectClassifierFromSegmentationModel when a test model is complete \ No newline at end of file +class TestIlastikObjectClassification(unittest.TestCase): + def setUp(self): + stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + stack_ch_pa = stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel']) + seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path']) + + self.roiset = RoiSet( + stack_ch_pa, + _get_label_ids(seg_mask), + params=RoiSetMetaParams( + mask_type='boxes', + filters={'area': {'min': 1e3, 'max': 1e4}}, + expand_box_by=(64, 2) + ) + ) + + self.object_classifier = ilm.PatchStackObjectClassifier( + params={'project_file': ilastik_classifiers['seg_to_obj']} + ) + + def test_classify_patches(self): + raw_patches = self.roiset.get_raw_patches() + patch_masks = self.roiset.get_patch_masks() + res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks) + self.assertEqual(res_patches.count, self.roiset.count) + for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch + unique = np.unique(res_patches.iat(pi).data) + self.assertEqual(len(unique), 2) + self.assertEqual(unique[0], 0) diff --git a/model_server/extensions/ilastik/validators.py b/model_server/extensions/ilastik/validators.py deleted file mode 100644 index 28d92c15b79bf923b1cb4098bfc688e47d7b3896..0000000000000000000000000000000000000000 --- a/model_server/extensions/ilastik/validators.py +++ /dev/null @@ -1,5 +0,0 @@ -from model_server.base.accessors import GenericImageDataAccessor - -def is_object_map(img: GenericImageDataAccessor): - # TODO: implement - pass \ No newline at end of file diff --git a/tests/test_accessors.py b/tests/test_accessors.py new file mode 100644 index 0000000000000000000000000000000000000000..d622a368dae6e89a24dabf1044391092f48b88b5 --- /dev/null +++ b/tests/test_accessors.py @@ -0,0 +1,188 @@ +import unittest + +import numpy as np + +from base.accessors import PatchStack, make_patch_stack_from_file, FileNotFoundError +from conf.testing import monozstackmask + +from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask +from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor + +class TestCziImageFileAccess(unittest.TestCase): + + def setUp(self) -> None: + pass + + def test_tiffile_is_correct_shape(self): + tf = generate_file_accessor(tifffile['path']) + + self.assertIsInstance(tf, TifSingleSeriesFileAccessor) + self.assertEqual(tf.shape_dict['Y'], tifffile['h']) + self.assertEqual(tf.shape_dict['X'], tifffile['w']) + self.assertEqual(tf.chroma, tifffile['c']) + self.assertTrue(tf.is_3d()) + self.assertEqual(len(tf.data.shape), 4) + self.assertEqual(tf.shape[0], tifffile['h']) + self.assertEqual(tf.shape[1], tifffile['w']) + + def test_czifile_is_correct_shape(self): + cf = CziImageFileAccessor(czifile['path']) + self.assertEqual(cf.shape_dict['Y'], czifile['h']) + self.assertEqual(cf.shape_dict['X'], czifile['w']) + self.assertEqual(cf.chroma, czifile['c']) + self.assertFalse(cf.is_3d()) + self.assertEqual(len(cf.data.shape), 4) + self.assertEqual(cf.shape[0], czifile['h']) + self.assertEqual(cf.shape[1], czifile['w']) + + def test_get_single_channel_from_zstack(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + c = 3 + cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) + sc = cf.get_one_channel_data(c) + self.assertEqual(sc.shape, (h, w, 1, nz)) + + def test_write_single_channel_tif(self): + ch = 4 + cf = CziImageFileAccessor(czifile['path']) + mono = cf.get_one_channel_data(ch) + self.assertTrue( + write_accessor_data_to_file( + output_path / f'{cf.fpath.stem}_ch{ch}.tif', + mono + ) + ) + self.assertEqual(cf.data.shape[0:2], mono.data.shape[0:2]) + self.assertEqual(cf.data.shape[3], mono.data.shape[2]) + + def test_conform_data_shorter_than_xycz(self): + h = 256 + w = 512 + data = np.random.rand(h, w, 1) + acc = InMemoryDataAccessor(data) + self.assertEqual( + InMemoryDataAccessor.conform_data(data).shape, + (256, 512, 1, 1) + ) + self.assertEqual( + acc.shape_dict, + {'Y': 256, 'X': 512, 'C': 1, 'Z': 1} + ) + + def test_conform_data_longer_than_xycz(self): + data = np.random.rand(256, 512, 12, 8, 3) + with self.assertRaises(DataShapeError): + acc = InMemoryDataAccessor(data) + + + def test_write_multichannel_image_preserve_axes(self): + h = 256 + w = 512 + c = 3 + nz = 10 + + yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8') + acc = InMemoryDataAccessor(yxcz) + fp = output_path / f'rand3d.tif' + self.assertTrue( + write_accessor_data_to_file(fp, acc) + ) + # need to sort out x,y flipping since np convention yxcz flips axes in 3d tif + self.assertEqual(acc.shape_dict['X'], w, acc.shape_dict) + self.assertEqual(acc.shape_dict['Y'], h, acc.shape_dict) + + # re-open file and check axes order + from tifffile import TiffFile + fh = TiffFile(fp) + self.assertEqual(len(fh.series), 1) + se = fh.series[0] + fh_shape_dict = {se.axes[i]: se.shape[i] for i in range(0, len(se.shape))} + self.assertEqual(fh_shape_dict, acc.shape_dict, 'Axes are not preserved in TIF output') + + def test_read_png(self, pngfile=rgbpngfile): + acc = PngFileAccessor(pngfile['path']) + self.assertEqual(acc.hw, (pngfile['h'], pngfile['w'])) + self.assertEqual(acc.chroma, pngfile['c']) + self.assertEqual(acc.nz, 1) + + def test_read_mono_png(self): + return self.test_read_png(pngfile=monopngfile) + + def test_read_zstack_mono_mask(self): + acc = generate_file_accessor(monozstackmask['path']) + self.assertTrue(acc.is_mask()) + + def test_read_in_pixel_scale_from_czi(self): + cf = CziImageFileAccessor(czifile['path']) + pxs = cf.pixel_scale_in_micrometers + self.assertAlmostEqual(pxs['X'], czifile['um_per_pixel'], places=3) + + +class TestPatchStackAccessor(unittest.TestCase): + def setUp(self) -> None: + pass + + def test_make_patch_stack_from_3d_array(self): + w = 256 + h = 512 + n = 4 + acc = PatchStack(np.random.rand(n, h, w, 1, 1)) + self.assertEqual(acc.count, n) + self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) + + def test_make_patch_stack_from_list(self): + w = 256 + h = 512 + n = 4 + acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)]) + self.assertEqual(acc.count, n) + self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) + return acc + + + def test_make_patch_stack_from_file(self): + h = monozstackmask['h'] + w = monozstackmask['w'] + c = monozstackmask['c'] + n = monozstackmask['z'] + + acc = make_patch_stack_from_file(monozstackmask['path']) + self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.count, n) + self.assertEqual(acc.pyxcz.shape, (n, h, w, c, 1)) + + def test_raises_filenotfound(self): + with self.assertRaises(FileNotFoundError): + acc = make_patch_stack_from_file('c:/fake/file/name.tif') + + def test_make_3d_patch_stack_from_nonuniform_list(self): + w = 256 + h = 512 + c = 1 + nz = 5 + n = 4 + + patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)] + patches.append(np.random.rand(h, 2 * w, c, nz)) + acc = PatchStack(patches) + self.assertEqual(acc.count, n + 1) + self.assertEqual(acc.hw, (h, 2 * w)) + self.assertEqual(acc.chroma, c) + self.assertEqual(acc.iat(0).shape, (h, 2 * w, c, nz)) + self.assertEqual(acc.iat_yxcz(0).shape, (h, 2 * w, c, nz)) + + def test_pczyx(self): + w = 256 + h = 512 + n = 4 + nz = 15 + nc = 2 + acc = PatchStack(np.random.rand(n, h, w, nc, nz)) + self.assertEqual(acc.count, n) + self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w)) + self.assertEqual(acc.hw, (h, w)) diff --git a/tests/test_api.py b/tests/test_api.py index 46bfbe235e24615081aad39ebca2b67079a32de4..ffcb96e7841c84498e96db1f1c3fde7f8a78fde3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -70,7 +70,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') + self.assertEqual(rj[model_id]['class'], 'DummySemanticSegmentationModel') return model_id def test_respond_with_error_when_invalid_filepath_requested(self): @@ -89,7 +89,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_i2i_inference_errors_when_model_not_found(self): model_id = 'not_a_real_model' resp = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': 'not_a_real_file.name' @@ -101,7 +101,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): model_id = self.test_load_dummy_model() self.copy_input_file_to_server() resp_infer = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': czifile['filename'], diff --git a/tests/test_model.py b/tests/test_model.py index 1938e80c26f32e0ffccd00a8dd6f8b028d692dd2..3f9f28727c250e1bc399a98c81e9c0206d55bb3d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,49 +1,56 @@ import unittest from model_server.conf.testing import czifile from model_server.base.accessors import CziImageFileAccessor -from model_server.base.models import DummyImageToImageModel, CouldNotLoadModelError +from model_server.base.models import DummySemanticSegmentationModel, DummyInstanceSegmentationModel, CouldNotLoadModelError class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) def test_instantiate_model(self): - model = DummyImageToImageModel(params=None) + model = DummySemanticSegmentationModel(params=None) self.assertTrue(model.loaded) def test_instantiate_model_with_nondefault_kwarg(self): - model = DummyImageToImageModel(autoload=False) + model = DummySemanticSegmentationModel(autoload=False) self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') def test_raise_error_if_cannot_load_model(self): - class UnloadableDummyImageToImageModel(DummyImageToImageModel): + class UnloadableDummyImageToImageModel(DummySemanticSegmentationModel): def load(self): return False with self.assertRaises(CouldNotLoadModelError): mi = UnloadableDummyImageToImageModel() - def test_czifile_is_correct_shape(self): - model = DummyImageToImageModel() - img, _ = model.infer(self.cf) + def test_dummy_pixel_segmentation(self): + model = DummySemanticSegmentationModel() + img = self.cf.get_one_channel_data(0) + mask = model.label_pixel_class(img) w = czifile['w'] h = czifile['h'] self.assertEqual( - img.shape, + mask.shape, (h, w, 1, 1), 'Inferred image is not the expected shape' ) self.assertEqual( - img.data[int(w/2), int(h/2)], + mask.data[int(w/2), int(h/2)], 255, 'Middle pixel is not white as expected' ) self.assertEqual( - img.data[0, 0], + mask.data[0, 0], 0, 'First pixel is not black as expected' - ) \ No newline at end of file + ) + return img, mask + + def test_dummy_instance_segmentation(self): + img, mask = self.test_dummy_pixel_segmentation() + model = DummyInstanceSegmentationModel() + obmap = model.label_instance_class(img, mask) diff --git a/tests/test_process.py b/tests/test_process.py index 6908e9d7c1b779065d4de226b7a203979570fb9a..454151bc4f84156348624318a43837b5e9a370c8 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -2,6 +2,7 @@ import unittest import numpy as np +from base.process import mask_largest_object from model_server.base.process import pad class TestProcessingUtilityMethods(unittest.TestCase): @@ -27,4 +28,31 @@ class TestProcessingUtilityMethods(unittest.TestCase): nc = self.data4d.shape[2] nz = self.data4d.shape[3] padded = pad(self.data4d, 256) - self.assertEqual(padded.shape, (256, 256, nc, nz)) \ No newline at end of file + self.assertEqual(padded.shape, (256, 256, nc, nz)) + + +class TestMaskLargestObject(unittest.TestCase): + def test_mask_largest_touching_object(self): + arr = np.zeros([5, 5], dtype='uint8') + arr[0:3, 0:3] = 2 + arr[3:, 2:] = 4 + masked = mask_largest_object(arr) + self.assertTrue(np.all(np.unique(masked) == [0, 2])) + self.assertTrue(np.all(masked[4:5, 0:2] == 0)) + self.assertTrue(np.all(masked[0:3, 3:5] == 0)) + + def test_no_change(self): + arr = np.zeros([5, 5], dtype='uint8') + arr[0:3, 0:3] = 2 + masked = mask_largest_object(arr) + self.assertTrue(np.all(masked == arr)) + + def test_mask_multiple_objects_in_binary_maks(self): + arr = np.zeros([5, 5], dtype='uint8') + arr[0:3, 0:3] = 255 + arr[4, 2:5] = 255 + masked = mask_largest_object(arr) + print(np.unique(masked)) + self.assertTrue(np.all(np.unique(masked) == [0, 255])) + self.assertTrue(np.all(masked[:, 3:5] == 0)) + self.assertTrue(np.all(masked[3:5, :] == 0)) diff --git a/tests/test_roiset.py b/tests/test_roiset.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5a8b4608ec176594fa9763ccd0a9a8f0493726 --- /dev/null +++ b/tests/test_roiset.py @@ -0,0 +1,233 @@ +import unittest + +import numpy as np +from pathlib import Path + +from model_server.conf.testing import output_path, roiset_test_data + +from model_server.base.roiset import RoiSetMetaParams +from model_server.base.roiset import _get_label_ids, RoiSet +from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.models import DummyInstanceSegmentationModel + +class BaseTestRoiSetMonoProducts(object): + + def setUp(self) -> None: + # set up test raw data and segmentation from file + self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel']) + self.seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path']) + + +class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): + + def _make_roi_set(self, mask_type='boxes', **kwargs): + id_map = _get_label_ids(self.seg_mask) + roiset = RoiSet( + self.stack_ch_pa, + id_map, + params=RoiSetMetaParams( + mask_type=mask_type, + filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), + expand_box_by=(64, 2) + ) + ) + return roiset + + def test_roi_mask_shape(self, **kwargs): + roiset = self._make_roi_set(**kwargs) + zmask = roiset.get_zmask() + zmask_acc = InMemoryDataAccessor(zmask) + self.assertTrue(zmask_acc.is_mask()) + + # assert dimensionality of zmask + self.assertGreater(zmask_acc.shape_dict['Z'], 1) + self.assertEqual(zmask_acc.shape_dict['C'], 1) + write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc) + + # mask values are not just all True or all False + self.assertTrue(np.any(zmask)) + self.assertFalse(np.all(zmask)) + + # assert non-trivial meta info in boxes + self.assertGreater(roiset.count, 1) + sh = roiset.get_df().iloc[1]['mask'].shape + ar = roiset.get_df().iloc[1]['area'] + self.assertGreaterEqual(sh[0] * sh[1], ar) + + def test_roiset_from_non_zstacks(self, **kwargs): + acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) + self.assertEqual(acc_zstack_slice.nz, 1) + id_map = _get_label_ids(self.seg_mask) + + roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes')) + zmask = roiset.get_zmask() + + zmask_acc = InMemoryDataAccessor(zmask) + self.assertTrue(zmask_acc.is_mask()) + + def test_slices_are_valid(self): + roiset = self._make_roi_set() + for s in roiset.get_slices(): + ebb = roiset.acc_raw.data[s] + self.assertEqual(len(ebb.shape), 4) + self.assertTrue(np.all([si >= 1 for si in ebb.shape])) + + def test_rel_slices_are_valid(self): + roiset = self._make_roi_set() + for roi in roiset: + ebb = roiset.acc_raw.data[roi.slice] + self.assertEqual(len(ebb.shape), 4) + self.assertTrue(np.all([si >= 1 for si in ebb.shape])) + rbb = ebb[roi.relative_slice] + self.assertEqual(len(rbb.shape), 4) + self.assertTrue(np.all([si >= 1 for si in rbb.shape])) + + def test_make_2d_patches(self): + roiset = self._make_roi_set() + files = roiset.export_patches( + output_path / '2d_patches', + draw_bounding_box=True, + ) + self.assertGreaterEqual(len(files), 1) + + def test_make_3d_patches(self): + roiset = self._make_roi_set() + files = roiset.export_patches( + output_path / '3d_patches', + make_3d=True) + self.assertGreaterEqual(len(files), 1) + + def test_export_annotated_zstack(self): + roiset = self._make_roi_set() + file = roiset.export_annotated_zstack( + output_path / 'annotated_zstack', + ) + result = generate_file_accessor(Path(file['location']) / file['filename']) + self.assertEqual(result.shape, roiset.acc_raw.shape) + + def test_flatten_image(self): + id_map = _get_label_ids(self.seg_mask) + + roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes')) + df = roiset.get_df() + + from base.roiset import project_stack_from_focal_points + + img = project_stack_from_focal_points( + df['centroid-0'].to_numpy(), + df['centroid-1'].to_numpy(), + df['zi'].to_numpy(), + self.stack, + degree=4, + ) + + self.assertEqual(img.shape[0:2], self.stack.shape[0:2]) + + write_accessor_data_to_file( + output_path / 'flattened.tif', + InMemoryDataAccessor(img) + ) + + def test_make_binary_masks(self): + roiset = self._make_roi_set() + files = roiset.export_patch_masks(output_path / '2d_mask_patches', ) + self.assertGreaterEqual(len(files), 1) + + def test_classify_by(self): + roiset = self._make_roi_set() + roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel()) + self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) + self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) + + def test_raw_patches_are_correct_shape(self): + roiset = self._make_roi_set() + patches = roiset.get_raw_patches() + np, h, w, nc, nz = patches.shape + self.assertEqual(np, roiset.count) + self.assertEqual(nc, roiset.acc_raw.chroma) + + def test_patch_masks_are_correct_shape(self): + roiset = self._make_roi_set() + patch_masks = roiset.get_patch_masks() + np, h, w, nc, nz = patch_masks.shape + self.assertEqual(np, roiset.count) + self.assertEqual(nc, 1) + + +class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): + + def setUp(self) -> None: + super().setUp() + id_map = _get_label_ids(self.seg_mask) + self.roiset = RoiSet( + self.stack, + id_map, + params=RoiSetMetaParams( + expand_box_by=(128, 2), + mask_type='boxes', + filters={'area': {'min': 1e3, 'max': 1e4}}, + ) + ) + + def test_multichannel_to_mono_2d_patches(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'mono_2d_patches', + white_channel=3, + draw_bounding_box=True, + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 1) + + def test_multichannnel_to_mono_2d_patches_rgb_bbox(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox', + white_channel=3, + draw_bounding_box=True, + bounding_box_channel=1, + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 3) + + def test_multichannnel_to_rgb_2d_patches_bbox(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'rgb_2d_patches_bbox', + white_channel=4, + rgb_overlay_channels=(3, None, None), + draw_mask=True, + mask_channel=0, + rgb_overlay_weights=(0.1, 1.0, 1.0) + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 3) + + def test_multichannnel_to_rgb_2d_patches_contour(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'rgb_2d_patches_contour', + rgb_overlay_channels=(3, None, None), + draw_contour=True, + contour_channel=1, + rgb_overlay_weights=(0.1, 1.0, 1.0) + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 3) + self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black + + def test_multichannel_to_multichannel_tif_patches(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'multichannel_tif_patches', + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 5) + + def test_multichannel_annotated_zstack(self): + file = self.roiset.export_annotated_zstack( + output_path / 'multichannel' / 'annotated_zstack', + 'test_multichannel_annotated_zstack', + ) + result = generate_file_accessor(Path(file['location']) / file['filename']) + self.assertEqual(result.chroma, self.stack.chroma) + self.assertEqual(result.nz, self.stack.nz) + + + diff --git a/tests/test_session.py b/tests/test_session.py index dd143d5f48e6eec0057fca012a7e8e555139e016..9679aad61c699de8f07ff0a31e2cc4750cb6f9e3 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,6 +1,6 @@ import pathlib import unittest -from model_server.base.models import DummyImageToImageModel +from model_server.base.models import DummySemanticSegmentationModel from model_server.base.session import Session class TestGetSessionObject(unittest.TestCase): @@ -63,7 +63,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_model(self): sesh = Session() - MC = DummyImageToImageModel + MC = DummySemanticSegmentationModel success = sesh.load_model(MC) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() @@ -77,7 +77,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_second_instance_of_same_model(self): sesh = Session() - MC = DummyImageToImageModel + MC = DummySemanticSegmentationModel sesh.load_model(MC) sesh.load_model(MC) self.assertIn(MC.__name__ + '_00', sesh.models.keys()) @@ -86,7 +86,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_model_with_params(self): sesh = Session() - MC = DummyImageToImageModel + MC = DummySemanticSegmentationModel p1 = {'p1': 'abc'} success = sesh.load_model(MC, params=p1) self.assertTrue(success) @@ -103,7 +103,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_finds_existing_model_with_different_path_formats(self): sesh = Session() - MC = DummyImageToImageModel + MC = DummySemanticSegmentationModel p1 = {'path': 'c:\\windows\\dummy.pa'} p2 = {'path': 'c:/windows/dummy.pa'} mid = sesh.load_model(MC, params=p1) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 3c7c02a61d6d73fdfa2167c8d973af35c069a366..6e9603ea7418ba367b482a00b8dd8c9aafa8820d 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,16 +1,16 @@ import unittest from model_server.conf.testing import czifile, output_path -from model_server.base.models import DummyImageToImageModel -from model_server.base.workflows import infer_image_to_image +from model_server.base.models import DummySemanticSegmentationModel +from model_server.base.workflows import classify_pixels class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.model = DummyImageToImageModel() + self.model = DummySemanticSegmentationModel() def test_single_session_instance(self): - result = infer_image_to_image(czifile['path'], self.model, output_path, channel=2) + result = classify_pixels(czifile['path'], self.model, output_path, channel=2) self.assertTrue(result.success) import tifffile