diff --git a/.gitignore b/.gitignore index 1d9bf09126a5b8a512737b7416bb734d064d9c91..36a2805f7974aed48458a8127918b60d0e03f951 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,2 @@ -/.idea/ +*/.idea/* *__pycache__* -/dist/ -/model_server_package_rhodes.egg-info/ diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 23b3b8e2422b35fbea3584ed1bd6b59fd57e7eff..4b91eb5cda12dc79c3b1bd300c5630d1370d2c91 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -3,11 +3,12 @@ import os from pathlib import Path import numpy as np -from skimage.io import imread +from skimage.io import imread, imsave import czifile import tifffile +from base.process import make_rgb from model_server.base.process import is_mask class GenericImageDataAccessor(ABC): @@ -37,12 +38,9 @@ class GenericImageDataAccessor(ABC): def is_mask(self): return is_mask(self._data) - def get_one_channel_data (self, channel: int, mip: bool = False): + def get_one_channel_data (self, channel: int): c = int(channel) - if mip: - return InMemoryDataAccessor(self.data[:, :, c:(c+1), :].max(axis=-1)) - else: - return InMemoryDataAccessor(self.data[:, :, c:(c+1), :]) + return InMemoryDataAccessor(self.data[:, :, c:(c+1), :]) @property def pixel_scale_in_micrometers(self): @@ -103,7 +101,7 @@ class TifSingleSeriesFileAccessor(GenericImageFileAccessor): tf = tifffile.TiffFile(fpath) self.tf = tf except Exception: - raise FileAccessorError(f'Unable to access data in {fpath}') + FileAccessorError(f'Unable to access data in {fpath}') if len(tf.series) != 1: raise DataShapeError(f'Expect only one series in {fpath}') @@ -186,29 +184,57 @@ class CziImageFileAccessor(GenericImageFileAccessor): return sc -def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor, mkdir=True) -> bool: +def write_accessor_data_to_file(fpath: Path, acc: GenericImageDataAccessor, mkdir=True) -> bool: + """ + Export an image accessor to file. + :param fpath: complete path including filename and extension + :param acc: image accessor to be written + :param mkdir: create any needed subdirectories in fpath if True + :return: True + """ + if 'P' in acc.shape_dict.keys(): + raise FileWriteError(f'Can only write single-position accessor to file') + ext = fpath.suffix.upper() + if mkdir: fpath.parent.mkdir(parents=True, exist_ok=True) - try: + + if ext == '.PNG': + if acc.dtype != 'uint8': + raise FileWriteError(f'Invalid data type {acc.dtype}') + if acc.chroma == 1: + data = acc.data[:, :, 0, 0] + elif acc.chroma == 2: # add a blank blue channel + data = make_rgb(acc.data) + else: # preserve RGB order + data = acc.data[:, :, :, 0] + imsave(fpath, data, check_contrast=False) + return True + + elif ext in ['.TIF', '.TIFF']: zcyx= np.moveaxis( - accessor.data, # yxcz + acc.data, # yxcz [3, 2, 0, 1], [0, 1, 2, 3] ) - if accessor.is_mask(): - if accessor.dtype == 'bool': + if acc.is_mask(): + if acc.dtype == 'bool': data = (zcyx * 255).astype('uint8') else: data = zcyx.astype('uint8') tifffile.imwrite(fpath, data, imagej=True) else: tifffile.imwrite(fpath, zcyx, imagej=True) - except: - raise FileWriteError(f'Unable to write data to file') + else: + raise FileWriteError(f'Unable to write data to file of extension {ext}') return True def generate_file_accessor(fpath): + """ + Given an image file path, return an image accessor, assuming the file is a supported format and represents + a single position array, which may be single or multi-channel, single plane or z-stack. + """ if str(fpath).upper().endswith('.TIF') or str(fpath).upper().endswith('.TIFF'): return TifSingleSeriesFileAccessor(fpath) elif str(fpath).upper().endswith('.CZI'): @@ -218,6 +244,86 @@ def generate_file_accessor(fpath): else: raise FileAccessorError(f'Could not match a file accessor with {fpath}') + +class PatchStack(InMemoryDataAccessor): + + def __init__(self, data): + """ + A sequence of n (generally) color 3D images of the same size + :param data: either a list of np.ndarrays of size YXCZ, or np.ndarray of size PYXCZ + """ + + if isinstance(data, list): # list of YXCZ patches + n = len(data) + yxcz_shape = np.array([e.shape for e in data]).max(axis=0) + nda = np.zeros( + (n, *yxcz_shape), dtype=data[0].dtype + ) + for i in range(0, len(data)): + s = tuple([slice(0, c) for c in data[i].shape]) + nda[i][s] = data[i] + + elif isinstance(data, np.ndarray) and len(data.shape) == 5: # interpret as PYXCZ + nda = data + else: + raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}') + + assert nda.ndim == 5 + self._data = nda + + def iat(self, i): + return InMemoryDataAccessor(self.data[i, :, :, :, :]) + + def iat_yxcz(self, i): + return self.iat(i) + + @property + def count(self): + return self.shape_dict['P'] + + @property + def shape_dict(self): + return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) + + def get_list(self): + n = self.nz + return [self.data[:, :, 0, zi] for zi in range(0, n)] + + @property + def pyxcz(self): + return self.data + + @property + def pczyx(self): + return np.moveaxis( + self.data, + [0, 3, 4, 1, 2], + [0, 1, 2, 3, 4] + ) + + @property + def shape(self): + return self.data.shape + + @property + def shape_dict(self): + return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) + + +def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position + if not Path(fpath).exists(): + raise FileNotFoundError(f'Could not find {fpath}') + + pyxc = np.moveaxis( + generate_file_accessor(fpath).data, # yxcz + [0, 1, 2, 3], + [1, 2, 3, 0] + ) + pyxcz = np.expand_dims(pyxc, axis=3) + return PatchStack(pyxcz) + + + class Error(Exception): pass @@ -237,4 +343,10 @@ class InvalidAxisKey(Error): pass class InvalidDataShape(Error): - pass \ No newline at end of file + pass + +class InvalidDataForPatchStackError(Error): + pass + + + diff --git a/model_server/base/annotators.py b/model_server/base/annotators.py new file mode 100644 index 0000000000000000000000000000000000000000..b2df418d46fc5801f39e34f4512863dee69e7cfb --- /dev/null +++ b/model_server/base/annotators.py @@ -0,0 +1,52 @@ +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from model_server.base.process import rescale + +def draw_boxes_on_3d_image(roiset, draw_full_depth=False, **kwargs): + _, _, chroma, nz = roiset.acc_raw.shape + font_size = kwargs.get('font_size', 18) + linewidth = kwargs.get('linewidth', 4) + + annotated = np.zeros(roiset.acc_raw.shape, dtype=roiset.acc_raw.dtype) + + for zi in range(0, nz): + if draw_full_depth: + subset = roiset.get_df() + else: + subset = roiset.get_df().query(f'zi == {zi}') + for c in range(0, chroma): + pilimg = Image.fromarray(roiset.acc_raw.data[:, :, c, zi]) + draw = ImageDraw.Draw(pilimg) + draw.font = ImageFont.truetype(font="arial.ttf", size=font_size) + + for roi in subset.itertuples('Roi'): + xm = round((roi.x0 + roi.x1) / 2) + draw.rectangle([(roi.x0, roi.y0), (roi.x1, roi.y1)], outline='white', width=linewidth) + + if kwargs.get('draw_label') is True: + draw.text((xm, roi.y0), f'{roi.label:04d}', fill='white', anchor='mb') + annotated[:, :, c, zi] = pilimg + + + if clip := kwargs.get('rescale_clip'): + assert clip >= 0.0 and clip <= 1.0 + annotated = rescale(annotated, clip=clip) + + return annotated + +def draw_box_on_patch(patch, bbox, linewidth=1): + assert len(patch.shape) == 2 + ((x0, y0), (x1, y1)) = bbox + pilimg = Image.fromarray(patch) # drawing modifies array in-place + draw = ImageDraw.Draw(pilimg) + draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth) + return np.array(pilimg) + +def draw_contours_on_patch(patch, contours, linewidth=1): + assert len(patch.shape) == 2 + pilimg = Image.fromarray(patch) # drawing modifies array in-place + draw = ImageDraw.Draw(pilimg) + for co in contours: + draw.line([(p[1], p[0]) for p in co], width=linewidth, joint='curve') + return np.array(pilimg) \ No newline at end of file diff --git a/model_server/base/api.py b/model_server/base/api.py index 613b05cb647d2b6eaddf129d98665008b4cc906e..2ab37ad3da614c3a54ec1c01250a60d5dbf4fec4 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,9 +1,10 @@ from fastapi import FastAPI, HTTPException -from model_server.base.models import DummyImageToImageModel +from model_server.base.models import DummySemanticSegmentationModel from model_server.base.session import Session, InvalidPathError from model_server.base.validators import validate_workflow_inputs -from model_server.base.workflows import infer_image_to_image +from model_server.base.workflows import classify_pixels +from model_server.extensions.ilastik.workflows import infer_px_then_ob_model app = FastAPI(debug=True) session = Session() @@ -66,13 +67,13 @@ def list_active_models(): @app.put('/models/dummy/load/') def load_dummy_model() -> dict: - return {'model_id': session.load_model(DummyImageToImageModel)} + return {'model_id': session.load_model(DummySemanticSegmentationModel)} -@app.put('/infer/from_image_file') +@app.put('/workflows/segment') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: inpath = session.paths['inbound_images'] / input_filename validate_workflow_inputs([model_id], [inpath]) - record = infer_image_to_image( + record = classify_pixels( inpath, session.models[model_id]['object'], session.paths['outbound_images'], diff --git a/model_server/base/czi_util.py b/model_server/base/czi_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8c5db8c27eeefea2e4bad32122d34f306bf16e --- /dev/null +++ b/model_server/base/czi_util.py @@ -0,0 +1,72 @@ +import csv +from pathlib import Path + +import czifile +import numpy as np +import pandas as pd + +from model_server.base.accessors import InMemoryDataAccessor + + +def dump_czi_subblock_table(czif: czifile.CziFile, where: Path): + csvfn = Path(where) / 'subblocks.csv' + with open(csvfn, 'w', newline='') as csvf: + he_shape = ['shape_' + a for a in list(czif.axes)] + he_start = ['start_' + a for a in list(czif.axes)] + wr = csv.DictWriter(csvf, he_shape + he_start) + wr.writeheader() + for sb in czif.subblock_directory: + shape = dict(zip(['shape_' + a for a in sb.axes], sb.shape)) + start = dict(zip(['start_' + a for a in sb.axes], sb.start)) + wr.writerow(shape | start) + print(f'Dumped CSV to {csvfn}') + + +def dump_czi_metadata(czif: czifile.CziFile, where: Path): + xmlfn = Path(where) / 'czi_meta.xml' + with open(xmlfn, 'w') as xmlf: + xmlf.write(czif.metadata()) + print(f'Dumped XML to {xmlfn}') + + +def get_accessor_from_multiposition_czi(cf: czifile.CziFile, pi: int): + # assumes different channels across different subblocks + + df = pd.DataFrame([dict(zip(sbd.axes, sbd.start)) for sbd in cf.subblock_directory]) + dfq = df[(df['S'] == pi)] + + c_arr = dfq['C'].sort_values() + nc = len(dfq) + + # assert that c_arr is sequential and 0-indexed + assert list(c_arr) == list(range(0, nc)) + + # assert all other dimensions in dfq are the same + assert all(dfq.drop(['C'], axis=1).nunique() == 1) + + # sbis = list(dfq.index) # subblock indices + sbd = cf.subblock_directory + df_shapes = pd.DataFrame([dict(zip(sbd[i].axes, sbd[i].shape)) for i in dfq.index]) + assert all(df_shapes.nunique() == 1) + + (h, w, nz) = tuple(df_shapes.loc[0, ['Y', 'X', 'Z']]) + yxcz = np.zeros((h, w, nc, nz), dtype=cf.dtype) + + # iterate over mono subblocks + for i in range(0, len(c_arr)): + sbi = c_arr[c_arr == i].index[0] + sb = list(cf.subblocks())[sbi] + data = sb.data() + sd = {ch: sb.shape[sb.axes.index(ch)] for ch in sb.axes} + # only non-unit dimensions are Y, X, C, and Z + assert len({k: v for k, v in sd.items() if v != 1 and k not in list('YXZ')}) == 0 + + yxz = np.moveaxis( + data, + [sb.axes.index(k) for k in list('YXZ')], + [0, 1, 2] + ).squeeze( + axis=tuple(range(3, len(sd))) + ) + yxcz[:, :, i, :] = yxz + return InMemoryDataAccessor(yxcz) \ No newline at end of file diff --git a/model_server/base/models.py b/model_server/base/models.py index 014957be8ada6bc059a9da7852e561a6ad4e63b0..8413f68d4420fe31add5e25818b4198bdc1a76eb 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -35,7 +35,12 @@ class Model(ABC): pass @abstractmethod - def infer(self, img: GenericImageDataAccessor) -> (object, dict): # return json describing inference result + def infer(self, *args) -> (object, dict): + """ + Abstract method that carries out the computationally intensive step of running data through a model + :param args: + :return: + """ pass def reload(self): @@ -51,7 +56,40 @@ class ImageToImageModel(Model): def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): pass -class DummyImageToImageModel(ImageToImageModel): + +class SemanticSegmentationModel(ImageToImageModel): + """ + Base model that exposes a method that returns a binary mask for a given input image and pixel class + """ + + @abstractmethod + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + """ + Given an image, return an image of the same shape where each pixel is assigned to one or more integer classes + """ + pass + + +class InstanceSegmentationModel(ImageToImageModel): + """ + Base model that exposes a method that returns an instance classification map for a given input image and mask + """ + + @abstractmethod + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Given an image and a mask of the same size, return a map where each connected object is assigned a class + """ + if not mask.is_mask(): + raise InvalidInputImageError('Expecting a binary mask') + if not img.shape == mask.shape: + raise InvalidInputImageError('Expect input image and mask to be the same shape') + + +class DummySemanticSegmentationModel(SemanticSegmentationModel): model_id = 'dummy_make_white_square' @@ -66,6 +104,34 @@ class DummyImageToImageModel(ImageToImageModel): result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 return InMemoryDataAccessor(data=result), {'success': True} + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + mask, _ = self.infer(img) + return mask + +class DummyInstanceSegmentationModel(InstanceSegmentationModel): + + model_id = 'dummy_pass_input_mask' + + def load(self): + return True + + def infer( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor + ) -> (GenericImageDataAccessor, dict): + return img.__class__( + (mask.data / mask.data.max()).astype('uint16') + ) + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Returns a trivial segmentation, i.e. the input mask with value 1 + """ + super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) + class Error(Exception): pass @@ -73,4 +139,7 @@ class CouldNotLoadModelError(Error): pass class ParameterExpectedError(Error): + pass + +class InvalidInputImageError(Error): pass \ No newline at end of file diff --git a/model_server/base/process.py b/model_server/base/process.py index 481282237964ef147821fe99f79f025cf768c948..992c9fb93140ab49859f0807b55a1f8e321cd80d 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -4,6 +4,7 @@ Image processing utility functions from math import ceil, floor import numpy as np +import skimage from skimage.exposure import rescale_intensity @@ -71,4 +72,57 @@ def rescale(nda, clip=0.0): clip_pct = (100.0 * clip, 100.0 * (1.0 - clip)) cmin, cmax = np.percentile(nda, clip_pct) rescaled = rescale_intensity(nda, in_range=(cmin, cmax)) - return rescaled \ No newline at end of file + return rescaled + + +def make_rgb(nda): + """ + Convert a YXCZ stack array to RGB, and error if more than three channels + :param nda: np.ndarray (YXCZ dimensions) + :return: np.ndarray of 3-channel stack + """ + h, w, c, nz = nda.shape + assert c <= 3 + outdata = np.zeros((h, w, 3, nz), dtype=nda.dtype) + outdata[:, :, 0:c, :] = nda[:, :, :, :] + return outdata + + +def mask_largest_object( + img: np.ndarray, + max_allowed: int = 10, + verbose: bool = True +) -> np.ndarray: + """ + Where more than one connected component is found in an image, return the largest object by area + :param img: (np.ndarray) containing object labels or binary mask + :param max_allowed: raise an error if more than this number of objects is found + :param verbose: print a message each time more than one object is found + :return: np.ndarray of same size as img + """ + if is_mask(img): # assign object labels + ob_id = skimage.measure.label(img) + else: # assume img is contains object labels + ob_id = img + + num_obj = len(np.unique(ob_id)) - 1 + if num_obj > max_allowed: + raise TooManyObjectError(f'Found {num_obj} objects in frame') + if num_obj > 1: + if verbose: + print(f'Found {num_obj} nonzero unique values in object map; keeping the one with the largest area') + val, cts = np.unique(ob_id, return_counts=True) + mask = ob_id == val[1 + cts[1:].argmax()] + return mask * img + else: + return img + + +class Error(Exception): + pass + + +class TooManyObjectError(Exception): + pass + + diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py new file mode 100644 index 0000000000000000000000000000000000000000..936b53682821f015f46e2dcb07c9ffcb6f2455db --- /dev/null +++ b/model_server/base/roiset.py @@ -0,0 +1,561 @@ +from math import sqrt, floor +from pathlib import Path +from typing import List, Union +from uuid import uuid4 + +import numpy as np +import pandas as pd +from pydantic import BaseModel +from scipy.stats import moment +from skimage.filters import sobel + +from skimage.measure import label, regionprops_table, shannon_entropy, find_contours +from sklearn.preprocessing import PolynomialFeatures +from sklearn.linear_model import LinearRegression + +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.models import InstanceSegmentationModel +from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb +from base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image +from base.accessors import PatchStack +from base.process import mask_largest_object + + +class PatchParams(BaseModel): + draw_bounding_box: bool = False + draw_contour: bool = False + draw_mask: bool = False + rescale_clip: float = 0.001 + focus_metric: str = 'max_sobel' + rgb_overlay_channels: List[Union[int, None]] = [None, None, None] + rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0] + pad_to: int = 256 + + +class AnnotatedZStackParams(BaseModel): + draw_label: bool = False + + +class RoiFilterRange(BaseModel): + min: float + max: float + + +class RoiFilter(BaseModel): + area: Union[RoiFilterRange, None] = None + solidity: Union[RoiFilterRange, None] = None + + +class RoiSetMetaParams(BaseModel): + filters: Union[RoiFilter, None] = None + expand_box_by: List[int] = [128, 0] + + +class RoiSetExportParams(BaseModel): + pixel_probabilities: bool = False + patches_3d: Union[PatchParams, None] = None + annotated_patches_2d: Union[PatchParams, None] = None + patches_2d: Union[PatchParams, None] = None + patch_masks: Union[PatchParams, None] = None + annotated_zstacks: Union[AnnotatedZStackParams, None] = None + object_classes: bool = False + dataframe: bool = False + + + + +def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: + return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')) + + +def _focus_metrics(): + return { + 'max_intensity': lambda x: np.max(x), + 'stdev': lambda x: np.std(x), + 'max_sobel': lambda x: np.max(sobel(x)), + 'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)), + 'entropy': lambda x: shannon_entropy(x), + 'moment': lambda x: moment(x.flatten(), moment=2), + } + + +def _safe_add(a, g, b): + assert a.dtype == b.dtype + assert a.shape == b.shape + assert g >= 0.0 + + return np.clip( + a.astype('uint32') + g * b.astype('uint32'), + 0, + np.iinfo(a.dtype).max + ).astype(a.dtype) + + +class RoiSet(object): + + def __init__( + self, + acc_raw: GenericImageDataAccessor, + acc_obj_ids: GenericImageDataAccessor, + params: RoiSetMetaParams = RoiSetMetaParams(), + ): + """ + A set of regions of interest, referenced by their positions and contours in the YXCZ space of stack acc_raw. + RoiSet contains their internal state, which may be exported as patches, maps, and other products by export methods. + :param acc_raw: accessor to a generally a multichannel z-stack + :param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity + labels its membership in a connected object + :param params: optional arguments that influence the definition and representation of ROIs + """ + assert acc_obj_ids.chroma == 1 + assert acc_obj_ids.nz == 1 + self.acc_obj_ids = acc_obj_ids + self.acc_raw = acc_raw + self.params = params + + self._df = self.filter_df( + self.make_df( + self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by + ), + params.filters, + ) + + self.count = len(self._df) + self.object_class_maps = {} # classification results + + def __iter__(self): + """Expose ROI meta information via the Pandas.DataFrame API""" + return self._df.itertuples(name='Roi') + + @staticmethod + def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: + """ + Build dataframe associate object IDs with summary stats + :param acc_raw: accessor to raw image data + :param acc_obj_ids: accessor to map of object IDs + :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) + :return: pd.DataFrame + """ + # build dataframe of objects, assign z index to each object + argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') + df = ( + pd.DataFrame( + regionprops_table( + acc_obj_ids.data[:, :, 0, 0], + intensity_image=argmax, + properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid') + ) + ) + .rename( + columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', } + ) + ) + df['zi'] = df['intensity_mean'].round().astype('int') + + # compute expanded bounding boxes + h, w, c, nz = acc_raw.shape + ebxy, ebz = expand_box_by + df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) + df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) + df['ebb_x0'] = (df.x0 - ebxy).apply(lambda x: max(x, 0)) + df['ebb_x1'] = (df.x1 + ebxy).apply(lambda x: min(x, w)) + df['ebb_z0'] = (df.zi - ebz).apply(lambda x: max(x, 0)) + df['ebb_z1'] = (df.zi + ebz).apply(lambda x: min(x, nz)) + df['ebb_h'] = df['ebb_y1'] - df['ebb_y0'] + df['ebb_w'] = df['ebb_x1'] - df['ebb_x0'] + df['ebb_nz'] = df['ebb_z1'] - df['ebb_z0'] + 1 + + # compute relative bounding boxes + df['rel_y0'] = df.y0 - df.ebb_y0 + df['rel_y1'] = df.y1 - df.ebb_y0 + df['rel_x0'] = df.x0 - df.ebb_x0 + df['rel_x1'] = df.x1 - df.ebb_x0 + + assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0'])) + assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0'])) + + df['slice'] = df.apply( + lambda r: + np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], + axis=1 + ) + df['relative_slice'] = df.apply( + lambda r: + np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], + axis=1 + ) + df['mask'] = df.apply( + lambda r: (acc_obj_ids.data == r.label)[r.y0: r.y1, r.x0: r.x1, 0, 0], + axis=1 + ) + return df + + + @staticmethod + def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: + query_str = 'label > 0' # always true + if filters is not None: # parse filters + for k, val in filters.dict(exclude_unset=True).items(): + assert k in ('area', 'solidity') + vmin = val['min'] + vmax = val['max'] + assert vmin >= 0 + query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' + return df.loc[df.query(query_str).index, :] + + def get_df(self) -> pd.DataFrame: + return self._df + + def get_slices(self) -> pd.Series: + return self.get_df()['slice'] + + def add_df_col(self, name, se: pd.Series) -> None: + self._df[name] = se + + def get_multichannel_projection(self): + if self.count: + projected = project_stack_from_focal_points( + self._df['centroid-0'].to_numpy(), + self._df['centroid-1'].to_numpy(), + self._df['zi'].to_numpy(), + self.acc_raw, + degree=4, + ) + else: # else just return MIP + projected = self.acc_raw.data.max(axis=-1) + return projected + + def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches + if channel: + patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) + else: + patches_df = self.get_patches(pad_to=pad_to) + patches = list(patches_df['patch']) + return PatchStack(patches) + + def export_annotated_zstack(self, where, prefix='zstack', **kwargs): + annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) + success = write_accessor_data_to_file(where / (prefix + '.tif'), annotated) + return {'location': where.__str__(), 'filename': prefix + '.tif'} + + def get_zmask(self, mask_type='boxes'): + """ + Return a mask of same dimensionality as raw data + + :param kwargs: variable-length keyword arguments + mask_type: if 'boxes', zmask is True in each object's complete bounding box; otherwise 'contours' + """ + + assert mask_type in ('contours', 'boxes') + zi_st = np.zeros(self.acc_raw.shape, dtype='bool') + lamap = self.acc_obj_ids.data + + # make an object map where label is replaced by focus position in stack and background is -1 + lut = np.zeros(lamap.max() + 1) - 1 + df = self.get_df() + lut[df.label] = df.zi + + if mask_type == 'contours': + zi_map = (lut[lamap] + 1.0).astype('int') + idxs = np.array(zi_map) - 1 + np.put_along_axis( + zi_st, + np.expand_dims(idxs, (2, 3)), + 1, + axis=3 + ) + + # change background level from to 0 in final frame + zi_st[:, :, :, -1][lamap == 0] = 0 + + elif mask_type == 'boxes': + for roi in self: + zi_st[roi.relative_slice] = 1 + + return zi_st + + + def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): + + # do this on a patch basis, i.e. only one object per frame + obmap_patches = object_classification_model.label_instance_class( + self.get_raw_patches(channel=channel), + self.get_patch_masks() + ) + + om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype) + + self._df['classify_by_' + name] = pd.Series(dtype='Int64') + + # assign labels to object map: + for i, roi in enumerate(self): + oc = np.unique( + mask_largest_object( + obmap_patches.iat(i).data + ) + )[1] + self._df.loc[roi.Index, 'classify_by_' + name] = oc + om[self.acc_obj_ids.data == roi.label] = oc + self.object_class_maps[name] = InMemoryDataAccessor(om) + + def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list: + patches_acc = self.get_patch_masks(pad_to=pad_to) + + exported = [] + for i, roi in enumerate(self): # assumes index of patches_acc is same as dataframe + patch = patches_acc.iat_yxcz(i) + ext = 'png' + fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' + write_accessor_data_to_file(where / fname, patch) + exported.append(fname) + return exported + + + def export_patches(self, where: Path, prefix='patch', **kwargs): + make_3d = kwargs.get('make_3d', False) + patches_df = self.get_patches(**kwargs) + + def _export_patch(roi): + patch = InMemoryDataAccessor(roi.patch) + ext = 'tif' if make_3d or patch.chroma > 3 else 'png' + fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' + + if patch.dtype is np.dtype('uint16'): + resampled = InMemoryDataAccessor(resample_to_8bit(patch.data)) + write_accessor_data_to_file(where / fname, resampled) + else: + write_accessor_data_to_file(where / fname, patch) + + exported.append({ + 'df_index': roi.Index, + 'patch_filename': fname, + 'location': where.__str__(), + }) + + exported = [] + for roi in patches_df.itertuples(): # just used for label info + _export_patch(roi) + + return exported + + def get_patch_masks(self, pad_to: int = 256) -> PatchStack: + patches = [] + for roi in self: + patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') + patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 + + if pad_to: + patch = pad(patch, pad_to) + + patches.append(patch) + return PatchStack(patches) + + + def get_patches( + self, + rescale_clip: float = 0.0, + pad_to: int = 256, + make_3d: bool = False, + focus_metric: str = None, + rgb_overlay_channels: list = None, + rgb_overlay_weights: list = [1.0, 1.0, 1.0], + white_channel: int = None, + **kwargs + ) -> pd.DataFrame: + + # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data + raw = self.acc_raw + if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)): + assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None]) + assert len(rgb_overlay_channels) == 3 + assert len(rgb_overlay_weights) == 3 + + if white_channel: + assert white_channel < raw.chroma + stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] + else: + stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype) + + for ii, ci in enumerate(rgb_overlay_channels): + if ci is None: + continue + assert isinstance(ci, int) + assert ci < raw.chroma + stack[:, :, ii, :] = _safe_add( + stack[:, :, ii, :], # either black or grayscale channel + rgb_overlay_weights[ii], + raw.data[:, :, ci, :] + ) + else: + if white_channel: # interpret as just a single channel + assert white_channel < raw.chroma + annotate_rgb = False + for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']: + ca = kwargs.get(k) + if ca is None: + continue + assert (ca < raw.chroma) + if ca != white_channel: + annotate_rgb = True + break + if annotate_rgb: # make RGB patches anyway to include annotation color + stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] + else: # make monochrome patches + stack = raw.data[:, :, [white_channel], :] + else: + stack = raw.data + + def _make_patch(roi): + patch3d = stack[roi.slice] + ph, pw, pc, pz = patch3d.shape + subpatch = patch3d[roi.relative_slice] + + # make a 3d patch + if make_3d: + patch = patch3d + + # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately + elif focus_metric is not None: + foc = _focus_metrics()[focus_metric] + + patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype) + + for ci in range(0, pc): + me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)] + zif = np.argmax(me) + patch[:, :, ci, 0] = patch3d[:, :, ci, zif] + + # make a 2d patch from middle of z-stack + else: + zim = floor(pz / 2) + patch = patch3d[:, :, :, [zim]] + + assert len(patch.shape) == 4 + + if rescale_clip is not None: + patch = rescale(patch, rescale_clip) + + if kwargs.get('draw_bounding_box') is True: + bci = kwargs.get('bounding_box_channel', 0) + assert bci < 3 + if bci > 0: + patch = make_rgb(patch) + for zi in range(0, patch.shape[3]): + patch[:, :, bci, zi] = draw_box_on_patch( + patch[:, :, bci, zi], + ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)), + linewidth=kwargs.get('bounding_box_linewidth', 1) + ) + + if kwargs.get('draw_mask'): + mci = kwargs.get('mask_channel', 0) + mask = np.zeros(patch.shape[0:2], dtype=bool) + mask[roi.relative_slice[0:2]] = roi.mask + for zi in range(0, patch.shape[3]): + patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] + + if kwargs.get('draw_contour'): + mci = kwargs.get('contour_channel', 0) + mask = np.zeros(patch.shape[0:2], dtype=bool) + mask[roi.relative_slice[0:2]] = roi.mask + + for zi in range(0, patch.shape[3]): + patch[:, :, mci, zi] = draw_contours_on_patch( + patch[:, :, mci, zi], + find_contours(mask) + ) + + if pad_to: + patch = pad(patch, pad_to) + return patch + + dfe = self._df + dfe['patch'] = self._df.apply(lambda r: _make_patch(r), axis=1) + return dfe + + def run_exports(self, where, channel, prefix, params: RoiSetExportParams): + if not self.count: + return + raw_ch = self.acc_raw.get_one_channel_data(channel) + for k in params.dict().keys(): + subdir = where / k + pr = prefix + kp = params.dict()[k] + if kp is None: + continue + if k == 'patches_3d': + files = self.export_patches( + subdir, white_channel=channel, prefix=pr, make_3d=True, **kp + ) + if k == 'annotated_patches_2d': + files = self.export_patches( + subdir, prefix=pr, make_3d=False, white_channel=channel, + bounding_box_channel=1, bounding_box_linewidth=2, **kp, + ) + if k == 'patches_2d': + files = self.export_patches( + subdir, white_channel=channel, prefix=pr, make_3d=False, **kp + ) + df_patches = pd.DataFrame(files) + self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') + self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1) + if k == 'patch_masks': + self.export_patch_masks(subdir, prefix=pr, **kp) + if k == 'annotated_zstacks': + self.export_annotated_zstack(subdir, prefix=pr, **kp) + if k == 'object_classes': + for k, acc in self.object_class_maps.items(): + write_accessor_data_to_file(subdir / k / (pr + '.tif'), acc) + if k == 'dataframe': + dfpa = subdir / (pr + '.csv') + dfpa.parent.mkdir(parents=True, exist_ok=True) + self._df.to_csv(dfpa, index=False) + + +def project_stack_from_focal_points( + xx: np.ndarray, + yy: np.ndarray, + zz: np.ndarray, + stack: GenericImageDataAccessor, + degree: int = 2, +) -> np.ndarray: + """ + Given a set of 3D points, project a multichannel z-stack based on a surface fit of the provided points + :param xx: vector of point x-coordinates + :param yy: vector of point y-coordinates + :param zz: vector of point z-coordinates + :param stack: z-stack to project + :param degree: order of polynomial to fit + :return: multichannel 2d projected image array + """ + assert xx.shape == yy.shape + assert xx.shape == zz.shape + + poly = PolynomialFeatures(degree=degree) + X = np.stack([xx, yy]).T + features = poly.fit_transform(X, zz) + model = LinearRegression(fit_intercept=False) + model.fit(features, zz) + + xy_indices = np.indices(stack.hw).reshape(2, -1).T + xy_features = np.dot( + poly.fit_transform(xy_indices, zz), + model.coef_ + ) + zi_image = xy_features.reshape( + stack.hw + ).round().clip( + 0, (stack.nz - 1) + ).astype('uint16') + + return np.take_along_axis( + stack.data, + np.repeat( + np.expand_dims(zi_image, (2, 3)), + stack.chroma, + axis=2 + ), + axis=3 + ) + + diff --git a/model_server/base/util.py b/model_server/base/util.py index 68f28f1ace54664651ff52eaa2d8439d687d14a5..112118832acb0d15caed1ef29118c3bcc43ea7df 100644 --- a/model_server/base/util.py +++ b/model_server/base/util.py @@ -1,3 +1,4 @@ +from math import ceil from pathlib import Path import re from time import localtime, strftime @@ -78,10 +79,12 @@ def loop_workflow( output_folder_path: str, workflow_func: callable, models: List[Model], - params: dict, + # params: dict, export_batch_csvs: bool = True, write_intermediate_products: bool = True, catch_and_continue: bool = True, + chunk_size: int = None, + **params, ): """ Iteratively call the specified workflow function on each of a list of input files @@ -92,7 +95,23 @@ def loop_workflow( :param export_batch_csvs: if True, write any tabular data returned by workflow_func to CSV files :param write_intermediate_products: if True, write any intermediate image products to TIF files :param catch_and_continue: if True, catch exceptions returned by workflow_func and keep iterating + :param chunk_size: create subdirectories with specified number of input files, or all in top-level directory if None """ + if chunk_size and chunk_size < len(files): + for ci in range(0, ceil(len(files) / ceil(chunk_size))): + loop_workflow( + files[ci * chunk_size: (ci + 1) * chunk_size], + (Path(output_folder_path) / f'part-{ci:04d}').__str__(), + workflow_func, + models, + params, + export_batch_csvs=export_batch_csvs, + write_intermediate_products=write_intermediate_products, + catch_and_continue=catch_and_continue, + chunk_size=None + ) + return True + failures = [] for ii, ff in enumerate(files): export_kwargs = { diff --git a/model_server/base/workflows.py b/model_server/base/workflows.py index e3eedb68d9ed2574180e47c1a16e65da5b50372e..f3719433b727e770285a53dd3f2fb55c7517dbf7 100644 --- a/model_server/base/workflows.py +++ b/model_server/base/workflows.py @@ -6,7 +6,7 @@ from time import perf_counter from typing import Dict from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file -from model_server.base.models import Model +from model_server.base.models import SemanticSegmentationModel from pydantic import BaseModel @@ -29,11 +29,12 @@ class WorkflowRunRecord(BaseModel): timer_results: Dict[str, float] -def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) -> WorkflowRunRecord: +def classify_pixels(fpi: Path, model: SemanticSegmentationModel, where_output: Path, **kwargs) -> WorkflowRunRecord: """ - Generic workflow where a model processes an input image into an output image + Run a semantic segmentation model to compute a binary mask from an input image + :param fpi: Path object that references input image file - :param model: model object + :param model: semantic segmentation model instance :param where_output: Path object that references output image directory :param kwargs: variable-length keyword arguments :return: record object @@ -43,7 +44,7 @@ def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) img = generate_file_accessor(fpi).get_one_channel_data(ch) ti.click('file_input') - outdata, _ = model.infer(img) + outdata = model.label_pixel_class(img) ti.click('inference') outpath = where_output / (model.model_id + '_' + fpi.stem + '.tif') diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 5b49dbbd907f96ffe66761bd89f9580b8340e366..3d07931ee087dba7eb56f416f232b16b89056161 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -53,20 +53,28 @@ monozstackmask = { 'z': 85 } -filename = 'yxc_test.tif' -yxctif = { - 'filename': filename, - 'path': root / filename, - 'w': 256, - 'h': 256, - 'c': 4, - 'z': 1 -} - ilastik_classifiers = { 'px': root / 'ilastik' / 'demo_px.ilp', 'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp', - 'seg_to_obj': root / 'ilastik' / 'new_auto_obj.ilp', + 'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp', +} + +roiset_test_data = { + 'multichannel_zstack': { + 'path': root / 'zmask-test-stack-chlorophyl.tif', + 'w': 512, + 'h': 512, + 'c': 5, + 'z': 7, + 'mask_path': root / 'zmask-test-stack-mask.tif', + }, + 'pipeline_params': { + 'segmentation_channel': 0, + 'patches_channel': 4, + 'pxmap_channel': 0, + 'pxmap_threshold': 0.6, + }, + 'pixel_classifier': root / 'zmask' / 'AF405-bodies_boundaries.ilp', } output_path = root / 'testing_output' diff --git a/model_server/extensions/ilastik/conf.py b/model_server/extensions/ilastik/conf.py index 9447c4b147b3ec62c580c270e98541120c2ffebf..835e8a8078b308675f568df7a55c06a4663232dc 100644 --- a/model_server/extensions/ilastik/conf.py +++ b/model_server/extensions/ilastik/conf.py @@ -1,5 +1,5 @@ from pathlib import Path paths = { - 'project_files': Path.home() / 'base' / 'ilastik' + 'project_files': Path.home() / 'model_server' / 'ilastik' } \ No newline at end of file diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 39d31d3fa7920ecb33ffbf37263b019231748597..91493a2bed5fc0f676206a3cf1b75a3ff4f7c00b 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -4,12 +4,13 @@ from pathlib import Path import numpy as np import vigra -import model_server +import model_server.extensions.ilastik.conf +from base.accessors import PatchStack from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from model_server.base.models import ImageToImageModel, ParameterExpectedError +from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel -class IlastikImageToImageModel(ImageToImageModel): +class IlastikModel(Model): def __init__(self, params, autoload=True): self.project_file = Path(params['project_file']) @@ -17,8 +18,7 @@ class IlastikImageToImageModel(ImageToImageModel): if self.project_file.is_absolute(): pap = self.project_file else: - from model_server.extensions.ilastik.conf import paths as ilastik_paths - pap = ilastik_paths['project_files'] / self.project_file + pap = model_server.extensions.ilastik.conf.paths['project_files'] / self.project_file self.project_file_abspath = pap if not pap.exists(): raise FileNotFoundError(f'Project file does not exist: {pap}') @@ -28,7 +28,6 @@ class IlastikImageToImageModel(ImageToImageModel): self.shell = None super().__init__(autoload, params) - def load(self): from ilastik import app from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo @@ -52,8 +51,9 @@ class IlastikImageToImageModel(ImageToImageModel): return True -class IlastikPixelClassifierModel(IlastikImageToImageModel): +class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' + operations = ['segment', ] @staticmethod def get_workflow(): @@ -78,28 +78,42 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel): ) return InMemoryDataAccessor(data=yxcz), {'success': True} -class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel): - model_id = 'ilastik_object_classification_from_pixel_predictions' + def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs): + pxmap, _ = self.infer(img) + mask = pxmap.data[:, :, px_class, :] > px_prob_threshold + return InMemoryDataAccessor(mask) + + +class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): + model_id = 'ilastik_object_classification_from_segmentation' @staticmethod def get_workflow(): - from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction - return ObjectClassificationWorkflowPrediction + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary + return ObjectClassificationWorkflowBinary - def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): + def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') + assert segmentation_img.is_mask() + if segmentation_img.dtype == 'bool': + seg = 255 * segmentation_img.data.astype('uint8') + tagged_seg_data = vigra.taggedView( + 255 * segmentation_img.data.astype('uint8'), + 'yxcz' + ) + else: + tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') dsi = [ { 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), + 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), } ] obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert (len(obmaps) == 1, 'ilastik generated more than one object map') + assert len(obmaps) == 1, 'ilastik generated more than one object map' yxcz = np.moveaxis( obmaps[0], @@ -108,30 +122,34 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel) ) return InMemoryDataAccessor(data=yxcz), {'success': True} + def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): + super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) + obmap, _ = self.infer(img, mask) + return obmap -class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel): - model_id = 'ilastik_object_classification_from_segmentation' + +class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel): + model_id = 'ilastik_object_classification_from_pixel_predictions' @staticmethod def get_workflow(): - from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary - return ObjectClassificationWorkflowBinary + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction + return ObjectClassificationWorkflowPrediction - def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - assert segmentation_img.is_mask() + def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') dsi = [ { 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), + 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), } ] obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert (len(obmaps) == 1, 'ilastik generated more than one object map') + assert len(obmaps) == 1, 'ilastik generated more than one object map' yxcz = np.moveaxis( obmaps[0], @@ -139,3 +157,63 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel): [0, 1, 2, 3] ) return InMemoryDataAccessor(data=yxcz), {'success': True} + + + def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs): + """ + Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is + assigned a class. + :param img: input image + :param pxmap: map of pixel probabilities + :param kwargs: + pixel_classification_channel: channel of pxmap used to segment objects + pixel_classification_thresold: threshold of pxmap used to segment objects + :return: + """ + if not img.shape == pxmap.shape: + raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') + # TODO: check that pxmap is in-range + pxch = kwargs.get('pixel_classification_channel', 0) + pxtr = kwargs('pixel_classification_threshold', 0.5) + mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) + # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) + obmap, _ = self.infer(img, mask) + return obmap + + +class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): + """ + Wrap ilastik object classification for inputs comprising single-object series of raw images and binary + segmentation masks. + """ + + def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): + assert segmentation_acc.is_mask() + if not input_acc.chroma == 1: + raise InvalidInputImageError('Object classifier expects only monochrome patches') + if not input_acc.nz == 1: + raise InvalidInputImageError('Object classifier expects only 2d patches') + + tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') + tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') + + dsi = [ + { + 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), + 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), + } + ] + + obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] + + assert len(obmaps) == 1, 'ilastik generated more than one object map' + + # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C + assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) + + return PatchStack(data=pyxcz), {'success': True} \ No newline at end of file diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 910844040399192fdbb6f27cdc80d8c4447336c0..b7bee14e3d669822bfb5d0138b3c800df9c4869e 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -14,7 +14,7 @@ router = APIRouter( session = Session() -def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: +def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplicate=True) -> dict: """ Load an ilastik model of a given class and project filename. :param model_class: @@ -35,7 +35,7 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: ) return {'model_id': result} -@router.put('/px/load/') +@router.put('/seg/load/') def load_px_model(project_file: str, duplicate: bool = True) -> dict: return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate) @@ -48,7 +48,7 @@ def load_seg_to_obj_model(project_file: str, duplicate: bool = True) -> dict: return load_ilastik_model(ilm.IlastikObjectClassifierFromSegmentationModel, project_file, duplicate=duplicate) @router.put('/pixel_then_object_classification/infer') -def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None, mip: bool = False) -> dict: +def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict: inpath = session.paths['inbound_images'] / input_filename validate_workflow_inputs([px_model_id, ob_model_id], [inpath]) try: @@ -57,8 +57,7 @@ def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: st session.models[px_model_id]['object'], session.models[ob_model_id]['object'], session.paths['outbound_images'], - channel=channel, - mip=mip, + channel=channel ) except AssertionError: raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index d9e8c429a5f65805ca8bb2edacee035db5ff4b43..89fe2b98d1f8ea25702e8a221c6178e8fe3a645b 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -4,10 +4,11 @@ import unittest import numpy as np -from model_server.conf.testing import czifile, ilastik_classifiers, output_path -from model_server.base.accessors import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data +from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm -from model_server.base.workflows import infer_image_to_image +from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams +from model_server.base.workflows import classify_pixels from tests.test_api import TestServerBaseClass class TestIlastikPixelClassification(unittest.TestCase): @@ -35,7 +36,7 @@ class TestIlastikPixelClassification(unittest.TestCase): input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1)) with self.assertRaises(AttributeError): - pxmap, _ = model.infer(input_img) + mask = model.label_pixel_class(input_img) def test_run_pixel_classifier_on_random_data(self): @@ -47,8 +48,8 @@ class TestIlastikPixelClassification(unittest.TestCase): input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1)) - pxmap, _ = model.infer(input_img) - self.assertEqual(pxmap.shape, (h, w, 2, 1)) + mask = model.label_pixel_class(input_img) + self.assertEqual(mask.shape, (h, w, 1, 1)) def test_run_pixel_classifier(self): @@ -66,28 +67,29 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(mono_image.shape_dict['C'], 1) self.assertEqual(mono_image.shape_dict['Z'], 1) - pxmap, _ = model.infer(mono_image) + mask = model.label_pixel_class(mono_image) - self.assertEqual(pxmap.shape[0:2], cf.shape[0:2]) - self.assertEqual(pxmap.shape_dict['C'], 2) - self.assertEqual(pxmap.shape_dict['Z'], 1) + self.assertTrue(mask.is_mask()) + self.assertEqual(mask.shape[0:2], cf.shape[0:2]) + self.assertEqual(mask.shape_dict['C'], 1) + self.assertEqual(mask.shape_dict['Z'], 1) self.assertTrue( write_accessor_data_to_file( output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', - pxmap + mask ) ) self.mono_image = mono_image - self.pxmap = pxmap + self.mask = mask - def test_run_object_classifier(self): + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( {'project_file': ilastik_classifiers['pxmap_to_obj']} ) - objmap, _ = model.infer(self.mono_image, self.pxmap) + objmap, _ = model.infer(self.mono_image, self.mask) self.assertTrue( write_accessor_data_to_file( @@ -97,8 +99,24 @@ class TestIlastikPixelClassification(unittest.TestCase): ) self.assertEqual(objmap.data.max(), 3) + def test_run_object_classifier_from_segmentation(self): + self.test_run_pixel_classifier() + fp = czifile['path'] + model = ilm.IlastikObjectClassifierFromSegmentationModel( + {'project_file': ilastik_classifiers['seg_to_obj']} + ) + objmap = model.label_instance_class(self.mono_image, self.mask) + + self.assertTrue( + write_accessor_data_to_file( + output_path / f'obmap_from_seg_{fp.stem}.tif', + objmap, + ) + ) + self.assertEqual(objmap.data.max(), 3) + def test_ilastik_pixel_classification_as_workflow(self): - result = infer_image_to_image( + result = classify_pixels( czifile['path'], ilm.IlastikPixelClassifierModel( {'project_file': ilastik_classifiers['px']} @@ -113,7 +131,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={'project_file': 'improper.ilp'}, ) self.assertEqual(resp_load.status_code, 404) @@ -121,7 +139,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_load_ilastik_pixel_model(self): resp_load = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={'project_file': str(ilastik_classifiers['px'])}, ) self.assertEqual(resp_load.status_code, 200, resp_load.json()) @@ -137,7 +155,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_1st = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_1st), 1, resp_list_1st) resp_load_2nd = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': str(ilastik_classifiers['px']), 'duplicate': True, @@ -146,7 +164,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_2nd = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) resp_load_3rd = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': str(ilastik_classifiers['px']), 'duplicate': False, @@ -172,14 +190,14 @@ class TestIlastikOverApi(TestServerBaseClass): # load models with these paths resp1 = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': ilp_win, 'duplicate': False, }, ) resp2 = requests.put( - self.uri + 'ilastik/px/load/', + self.uri + 'ilastik/seg/load/', params={ 'project_file': ilp_posx, 'duplicate': False, @@ -226,7 +244,7 @@ class TestIlastikOverApi(TestServerBaseClass): model_id = self.test_load_ilastik_pixel_model() resp_infer = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': czifile['filename'], @@ -251,4 +269,32 @@ class TestIlastikOverApi(TestServerBaseClass): ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) - # TODO: test IlastikObjectClassifierFromSegmentationModel when a test model is complete \ No newline at end of file +class TestIlastikObjectClassification(unittest.TestCase): + def setUp(self): + stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + stack_ch_pa = stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel']) + seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path']) + + self.roiset = RoiSet( + stack_ch_pa, + _get_label_ids(seg_mask), + params=RoiSetMetaParams( + mask_type='boxes', + filters={'area': {'min': 1e3, 'max': 1e4}}, + expand_box_by=(64, 2) + ) + ) + + self.object_classifier = ilm.PatchStackObjectClassifier( + params={'project_file': ilastik_classifiers['seg_to_obj']} + ) + + def test_classify_patches(self): + raw_patches = self.roiset.get_raw_patches() + patch_masks = self.roiset.get_patch_masks() + res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks) + self.assertEqual(res_patches.count, self.roiset.count) + for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch + unique = np.unique(res_patches.iat(pi).data) + self.assertEqual(len(unique), 2) + self.assertEqual(unique[0], 0) diff --git a/model_server/scripts/run_server.py b/model_server/scripts/run_server.py index 6b8ad33df5e59a4bfd117b3c508ed4f30f7ff939..b278252bfb9c82a54ea69b491e4893d4471cf510 100644 --- a/model_server/scripts/run_server.py +++ b/model_server/scripts/run_server.py @@ -1,8 +1,6 @@ import argparse from multiprocessing import Process import requests -from requests.adapters import HTTPAdapter -from urllib3 import Retry import uvicorn import webbrowser @@ -36,7 +34,7 @@ if __name__ == '__main__': print('CLI args:\n' + str(args)) server_process = Process( target=uvicorn.run, - args=('base.api:app',), + args=('model_server.api:app',), kwargs={ 'app_dir': '.', 'host': args.host, @@ -50,13 +48,7 @@ if __name__ == '__main__': server_process.start() try: - sesh = requests.Session() - retries = Retry( - total=5, - backoff_factor=0.1, - ) - sesh.mount('http://', HTTPAdapter(max_retries=retries)) - resp = sesh.get(url) + resp = requests.get(url) assert resp.status_code == 200 except Exception: print('Error starting server') diff --git a/readme.md b/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..c424e968a2c16149ff04cba2f45fbf7ccc8e20fc --- /dev/null +++ b/readme.md @@ -0,0 +1,17 @@ +# model_server + +# How to extend service +Add sub-package to extensions +Add models that inherit from model_server.Model +In workflows, implement pipelines with File I/O via accessors.GenericImageDataAccessor +(to decouple model logic from image data source) +Set extensions-specific folders, etc. in conf relative to overall package root (set by user) +As much as possible, set pipeline and model parameters with defaults and support overrides by optional API arguments; +this helps non-coding users control their jobs +Set up API endpoints in router, following as much as possible existing conventions with load, infer, etc. keyword + +decouple data access from processing + +control either via batch runners or API (serial) + +workflow: combines data access with processing via models, produces primary outputs \ No newline at end of file diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 41939777a0dc8924bc9d291eff186e3e7dd694d0..d622a368dae6e89a24dabf1044391092f48b88b5 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -2,7 +2,10 @@ import unittest import numpy as np -from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask, yxctif +from base.accessors import PatchStack, make_patch_stack_from_file, FileNotFoundError +from conf.testing import monozstackmask + +from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor class TestCziImageFileAccess(unittest.TestCase): @@ -42,16 +45,6 @@ class TestCziImageFileAccess(unittest.TestCase): sc = cf.get_one_channel_data(c) self.assertEqual(sc.shape, (h, w, 1, nz)) - def test_get_single_channel_mip_from_zstack(self): - w = 256 - h = 512 - nc = 4 - nz = 11 - c = 3 - cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) - sc = cf.get_one_channel_data(c, mip=True) - self.assertEqual(sc.shape, (h, w, 1, 1)) - def test_write_single_channel_tif(self): ch = 4 cf = CziImageFileAccessor(czifile['path']) @@ -122,11 +115,74 @@ class TestCziImageFileAccess(unittest.TestCase): acc = generate_file_accessor(monozstackmask['path']) self.assertTrue(acc.is_mask()) - def test_read_yxc_tif(self): - acc = generate_file_accessor(yxctif['path']) - self.assertEqual(acc.nz, 1) - def test_read_in_pixel_scale_from_czi(self): cf = CziImageFileAccessor(czifile['path']) pxs = cf.pixel_scale_in_micrometers - self.assertAlmostEqual(pxs['X'], czifile['um_per_pixel'], places=3) \ No newline at end of file + self.assertAlmostEqual(pxs['X'], czifile['um_per_pixel'], places=3) + + +class TestPatchStackAccessor(unittest.TestCase): + def setUp(self) -> None: + pass + + def test_make_patch_stack_from_3d_array(self): + w = 256 + h = 512 + n = 4 + acc = PatchStack(np.random.rand(n, h, w, 1, 1)) + self.assertEqual(acc.count, n) + self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) + + def test_make_patch_stack_from_list(self): + w = 256 + h = 512 + n = 4 + acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)]) + self.assertEqual(acc.count, n) + self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) + return acc + + + def test_make_patch_stack_from_file(self): + h = monozstackmask['h'] + w = monozstackmask['w'] + c = monozstackmask['c'] + n = monozstackmask['z'] + + acc = make_patch_stack_from_file(monozstackmask['path']) + self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.count, n) + self.assertEqual(acc.pyxcz.shape, (n, h, w, c, 1)) + + def test_raises_filenotfound(self): + with self.assertRaises(FileNotFoundError): + acc = make_patch_stack_from_file('c:/fake/file/name.tif') + + def test_make_3d_patch_stack_from_nonuniform_list(self): + w = 256 + h = 512 + c = 1 + nz = 5 + n = 4 + + patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)] + patches.append(np.random.rand(h, 2 * w, c, nz)) + acc = PatchStack(patches) + self.assertEqual(acc.count, n + 1) + self.assertEqual(acc.hw, (h, 2 * w)) + self.assertEqual(acc.chroma, c) + self.assertEqual(acc.iat(0).shape, (h, 2 * w, c, nz)) + self.assertEqual(acc.iat_yxcz(0).shape, (h, 2 * w, c, nz)) + + def test_pczyx(self): + w = 256 + h = 512 + n = 4 + nz = 15 + nc = 2 + acc = PatchStack(np.random.rand(n, h, w, nc, nz)) + self.assertEqual(acc.count, n) + self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w)) + self.assertEqual(acc.hw, (h, w)) diff --git a/tests/test_api.py b/tests/test_api.py index 46bfbe235e24615081aad39ebca2b67079a32de4..ffcb96e7841c84498e96db1f1c3fde7f8a78fde3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -70,7 +70,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') + self.assertEqual(rj[model_id]['class'], 'DummySemanticSegmentationModel') return model_id def test_respond_with_error_when_invalid_filepath_requested(self): @@ -89,7 +89,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_i2i_inference_errors_when_model_not_found(self): model_id = 'not_a_real_model' resp = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': 'not_a_real_file.name' @@ -101,7 +101,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): model_id = self.test_load_dummy_model() self.copy_input_file_to_server() resp_infer = requests.put( - self.uri + f'infer/from_image_file', + self.uri + f'workflows/segment', params={ 'model_id': model_id, 'input_filename': czifile['filename'], diff --git a/tests/test_model.py b/tests/test_model.py index 1938e80c26f32e0ffccd00a8dd6f8b028d692dd2..3f9f28727c250e1bc399a98c81e9c0206d55bb3d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,49 +1,56 @@ import unittest from model_server.conf.testing import czifile from model_server.base.accessors import CziImageFileAccessor -from model_server.base.models import DummyImageToImageModel, CouldNotLoadModelError +from model_server.base.models import DummySemanticSegmentationModel, DummyInstanceSegmentationModel, CouldNotLoadModelError class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) def test_instantiate_model(self): - model = DummyImageToImageModel(params=None) + model = DummySemanticSegmentationModel(params=None) self.assertTrue(model.loaded) def test_instantiate_model_with_nondefault_kwarg(self): - model = DummyImageToImageModel(autoload=False) + model = DummySemanticSegmentationModel(autoload=False) self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') def test_raise_error_if_cannot_load_model(self): - class UnloadableDummyImageToImageModel(DummyImageToImageModel): + class UnloadableDummyImageToImageModel(DummySemanticSegmentationModel): def load(self): return False with self.assertRaises(CouldNotLoadModelError): mi = UnloadableDummyImageToImageModel() - def test_czifile_is_correct_shape(self): - model = DummyImageToImageModel() - img, _ = model.infer(self.cf) + def test_dummy_pixel_segmentation(self): + model = DummySemanticSegmentationModel() + img = self.cf.get_one_channel_data(0) + mask = model.label_pixel_class(img) w = czifile['w'] h = czifile['h'] self.assertEqual( - img.shape, + mask.shape, (h, w, 1, 1), 'Inferred image is not the expected shape' ) self.assertEqual( - img.data[int(w/2), int(h/2)], + mask.data[int(w/2), int(h/2)], 255, 'Middle pixel is not white as expected' ) self.assertEqual( - img.data[0, 0], + mask.data[0, 0], 0, 'First pixel is not black as expected' - ) \ No newline at end of file + ) + return img, mask + + def test_dummy_instance_segmentation(self): + img, mask = self.test_dummy_pixel_segmentation() + model = DummyInstanceSegmentationModel() + obmap = model.label_instance_class(img, mask) diff --git a/tests/test_process.py b/tests/test_process.py index 6908e9d7c1b779065d4de226b7a203979570fb9a..454151bc4f84156348624318a43837b5e9a370c8 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -2,6 +2,7 @@ import unittest import numpy as np +from base.process import mask_largest_object from model_server.base.process import pad class TestProcessingUtilityMethods(unittest.TestCase): @@ -27,4 +28,31 @@ class TestProcessingUtilityMethods(unittest.TestCase): nc = self.data4d.shape[2] nz = self.data4d.shape[3] padded = pad(self.data4d, 256) - self.assertEqual(padded.shape, (256, 256, nc, nz)) \ No newline at end of file + self.assertEqual(padded.shape, (256, 256, nc, nz)) + + +class TestMaskLargestObject(unittest.TestCase): + def test_mask_largest_touching_object(self): + arr = np.zeros([5, 5], dtype='uint8') + arr[0:3, 0:3] = 2 + arr[3:, 2:] = 4 + masked = mask_largest_object(arr) + self.assertTrue(np.all(np.unique(masked) == [0, 2])) + self.assertTrue(np.all(masked[4:5, 0:2] == 0)) + self.assertTrue(np.all(masked[0:3, 3:5] == 0)) + + def test_no_change(self): + arr = np.zeros([5, 5], dtype='uint8') + arr[0:3, 0:3] = 2 + masked = mask_largest_object(arr) + self.assertTrue(np.all(masked == arr)) + + def test_mask_multiple_objects_in_binary_maks(self): + arr = np.zeros([5, 5], dtype='uint8') + arr[0:3, 0:3] = 255 + arr[4, 2:5] = 255 + masked = mask_largest_object(arr) + print(np.unique(masked)) + self.assertTrue(np.all(np.unique(masked) == [0, 255])) + self.assertTrue(np.all(masked[:, 3:5] == 0)) + self.assertTrue(np.all(masked[3:5, :] == 0)) diff --git a/tests/test_roiset.py b/tests/test_roiset.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5a8b4608ec176594fa9763ccd0a9a8f0493726 --- /dev/null +++ b/tests/test_roiset.py @@ -0,0 +1,233 @@ +import unittest + +import numpy as np +from pathlib import Path + +from model_server.conf.testing import output_path, roiset_test_data + +from model_server.base.roiset import RoiSetMetaParams +from model_server.base.roiset import _get_label_ids, RoiSet +from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.models import DummyInstanceSegmentationModel + +class BaseTestRoiSetMonoProducts(object): + + def setUp(self) -> None: + # set up test raw data and segmentation from file + self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel']) + self.seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path']) + + +class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): + + def _make_roi_set(self, mask_type='boxes', **kwargs): + id_map = _get_label_ids(self.seg_mask) + roiset = RoiSet( + self.stack_ch_pa, + id_map, + params=RoiSetMetaParams( + mask_type=mask_type, + filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), + expand_box_by=(64, 2) + ) + ) + return roiset + + def test_roi_mask_shape(self, **kwargs): + roiset = self._make_roi_set(**kwargs) + zmask = roiset.get_zmask() + zmask_acc = InMemoryDataAccessor(zmask) + self.assertTrue(zmask_acc.is_mask()) + + # assert dimensionality of zmask + self.assertGreater(zmask_acc.shape_dict['Z'], 1) + self.assertEqual(zmask_acc.shape_dict['C'], 1) + write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc) + + # mask values are not just all True or all False + self.assertTrue(np.any(zmask)) + self.assertFalse(np.all(zmask)) + + # assert non-trivial meta info in boxes + self.assertGreater(roiset.count, 1) + sh = roiset.get_df().iloc[1]['mask'].shape + ar = roiset.get_df().iloc[1]['area'] + self.assertGreaterEqual(sh[0] * sh[1], ar) + + def test_roiset_from_non_zstacks(self, **kwargs): + acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) + self.assertEqual(acc_zstack_slice.nz, 1) + id_map = _get_label_ids(self.seg_mask) + + roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes')) + zmask = roiset.get_zmask() + + zmask_acc = InMemoryDataAccessor(zmask) + self.assertTrue(zmask_acc.is_mask()) + + def test_slices_are_valid(self): + roiset = self._make_roi_set() + for s in roiset.get_slices(): + ebb = roiset.acc_raw.data[s] + self.assertEqual(len(ebb.shape), 4) + self.assertTrue(np.all([si >= 1 for si in ebb.shape])) + + def test_rel_slices_are_valid(self): + roiset = self._make_roi_set() + for roi in roiset: + ebb = roiset.acc_raw.data[roi.slice] + self.assertEqual(len(ebb.shape), 4) + self.assertTrue(np.all([si >= 1 for si in ebb.shape])) + rbb = ebb[roi.relative_slice] + self.assertEqual(len(rbb.shape), 4) + self.assertTrue(np.all([si >= 1 for si in rbb.shape])) + + def test_make_2d_patches(self): + roiset = self._make_roi_set() + files = roiset.export_patches( + output_path / '2d_patches', + draw_bounding_box=True, + ) + self.assertGreaterEqual(len(files), 1) + + def test_make_3d_patches(self): + roiset = self._make_roi_set() + files = roiset.export_patches( + output_path / '3d_patches', + make_3d=True) + self.assertGreaterEqual(len(files), 1) + + def test_export_annotated_zstack(self): + roiset = self._make_roi_set() + file = roiset.export_annotated_zstack( + output_path / 'annotated_zstack', + ) + result = generate_file_accessor(Path(file['location']) / file['filename']) + self.assertEqual(result.shape, roiset.acc_raw.shape) + + def test_flatten_image(self): + id_map = _get_label_ids(self.seg_mask) + + roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes')) + df = roiset.get_df() + + from base.roiset import project_stack_from_focal_points + + img = project_stack_from_focal_points( + df['centroid-0'].to_numpy(), + df['centroid-1'].to_numpy(), + df['zi'].to_numpy(), + self.stack, + degree=4, + ) + + self.assertEqual(img.shape[0:2], self.stack.shape[0:2]) + + write_accessor_data_to_file( + output_path / 'flattened.tif', + InMemoryDataAccessor(img) + ) + + def test_make_binary_masks(self): + roiset = self._make_roi_set() + files = roiset.export_patch_masks(output_path / '2d_mask_patches', ) + self.assertGreaterEqual(len(files), 1) + + def test_classify_by(self): + roiset = self._make_roi_set() + roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel()) + self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) + self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) + + def test_raw_patches_are_correct_shape(self): + roiset = self._make_roi_set() + patches = roiset.get_raw_patches() + np, h, w, nc, nz = patches.shape + self.assertEqual(np, roiset.count) + self.assertEqual(nc, roiset.acc_raw.chroma) + + def test_patch_masks_are_correct_shape(self): + roiset = self._make_roi_set() + patch_masks = roiset.get_patch_masks() + np, h, w, nc, nz = patch_masks.shape + self.assertEqual(np, roiset.count) + self.assertEqual(nc, 1) + + +class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): + + def setUp(self) -> None: + super().setUp() + id_map = _get_label_ids(self.seg_mask) + self.roiset = RoiSet( + self.stack, + id_map, + params=RoiSetMetaParams( + expand_box_by=(128, 2), + mask_type='boxes', + filters={'area': {'min': 1e3, 'max': 1e4}}, + ) + ) + + def test_multichannel_to_mono_2d_patches(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'mono_2d_patches', + white_channel=3, + draw_bounding_box=True, + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 1) + + def test_multichannnel_to_mono_2d_patches_rgb_bbox(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox', + white_channel=3, + draw_bounding_box=True, + bounding_box_channel=1, + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 3) + + def test_multichannnel_to_rgb_2d_patches_bbox(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'rgb_2d_patches_bbox', + white_channel=4, + rgb_overlay_channels=(3, None, None), + draw_mask=True, + mask_channel=0, + rgb_overlay_weights=(0.1, 1.0, 1.0) + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 3) + + def test_multichannnel_to_rgb_2d_patches_contour(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'rgb_2d_patches_contour', + rgb_overlay_channels=(3, None, None), + draw_contour=True, + contour_channel=1, + rgb_overlay_weights=(0.1, 1.0, 1.0) + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 3) + self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black + + def test_multichannel_to_multichannel_tif_patches(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'multichannel_tif_patches', + ) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + self.assertEqual(result.chroma, 5) + + def test_multichannel_annotated_zstack(self): + file = self.roiset.export_annotated_zstack( + output_path / 'multichannel' / 'annotated_zstack', + 'test_multichannel_annotated_zstack', + ) + result = generate_file_accessor(Path(file['location']) / file['filename']) + self.assertEqual(result.chroma, self.stack.chroma) + self.assertEqual(result.nz, self.stack.nz) + + + diff --git a/tests/test_session.py b/tests/test_session.py index dd143d5f48e6eec0057fca012a7e8e555139e016..9679aad61c699de8f07ff0a31e2cc4750cb6f9e3 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,6 +1,6 @@ import pathlib import unittest -from model_server.base.models import DummyImageToImageModel +from model_server.base.models import DummySemanticSegmentationModel from model_server.base.session import Session class TestGetSessionObject(unittest.TestCase): @@ -63,7 +63,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_model(self): sesh = Session() - MC = DummyImageToImageModel + MC = DummySemanticSegmentationModel success = sesh.load_model(MC) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() @@ -77,7 +77,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_second_instance_of_same_model(self): sesh = Session() - MC = DummyImageToImageModel + MC = DummySemanticSegmentationModel sesh.load_model(MC) sesh.load_model(MC) self.assertIn(MC.__name__ + '_00', sesh.models.keys()) @@ -86,7 +86,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_model_with_params(self): sesh = Session() - MC = DummyImageToImageModel + MC = DummySemanticSegmentationModel p1 = {'p1': 'abc'} success = sesh.load_model(MC, params=p1) self.assertTrue(success) @@ -103,7 +103,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_finds_existing_model_with_different_path_formats(self): sesh = Session() - MC = DummyImageToImageModel + MC = DummySemanticSegmentationModel p1 = {'path': 'c:\\windows\\dummy.pa'} p2 = {'path': 'c:/windows/dummy.pa'} mid = sesh.load_model(MC, params=p1) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 3c7c02a61d6d73fdfa2167c8d973af35c069a366..6e9603ea7418ba367b482a00b8dd8c9aafa8820d 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,16 +1,16 @@ import unittest from model_server.conf.testing import czifile, output_path -from model_server.base.models import DummyImageToImageModel -from model_server.base.workflows import infer_image_to_image +from model_server.base.models import DummySemanticSegmentationModel +from model_server.base.workflows import classify_pixels class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.model = DummyImageToImageModel() + self.model = DummySemanticSegmentationModel() def test_single_session_instance(self): - result = infer_image_to_image(czifile['path'], self.model, output_path, channel=2) + result = classify_pixels(czifile['path'], self.model, output_path, channel=2) self.assertTrue(result.success) import tifffile