diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index a9366156bb6a5fcd63267cd3fb6e319d13bfbe37..b250605d4df5de6b1cd18d3a2570a85ed532f319 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -3,6 +3,7 @@ import os from pathlib import Path import numpy as np +import pandas as pd from skimage.io import imread, imsave import czifile @@ -40,6 +41,9 @@ class GenericImageDataAccessor(ABC): def is_mask(self): return is_mask(self._data) + def can_mask(self, acc): + return self.is_mask() and self.shape == acc.get_mono(0).shape + def get_channels(self, channels: list, mip: bool = False): carr = [int(c) for c in channels] if mip: @@ -77,14 +81,14 @@ class GenericImageDataAccessor(ABC): return self.data.sum(axis=(0, 1, 2)) @property - def data_xy(self) -> np.ndarray: + def data_yx(self) -> np.ndarray: if not self.chroma == 1 and self.nz == 1: raise InvalidDataShape('Can only return XY array from accessors with a single channel and single z-level') else: return self.data[:, :, 0, 0] @property - def data_xyz(self) -> np.ndarray: + def data_yxz(self) -> np.ndarray: if not self.chroma == 1: raise InvalidDataShape('Can only return XYZ array from accessors with a single channel') else: @@ -96,6 +100,10 @@ class GenericImageDataAccessor(ABC): def unique(self): return np.unique(self.data, return_counts=True) + @property + def dtype_max(self): + return np.iinfo(self.dtype).max + @property def pixel_scale_in_micrometers(self): return {} @@ -175,6 +183,16 @@ class GenericImageDataAccessor(ABC): 'filepath': '', } + def append_channels(self, acc): + if self.dtype != acc.dtype: + raise DataTypeError(f'Cannot append data of type {acc.dtype} to an accessor with type {self.dtype}') + return self._derived_accessor( + np.concatenate( + (self.data, acc.data), + axis=self._ga('C') + ) + ) + class InMemoryDataAccessor(GenericImageDataAccessor): def __init__(self, data): self._data = self.conform_data(data) @@ -424,6 +442,28 @@ class PatchStack(InMemoryDataAccessor): def count(self): return self.shape_dict['P'] + @property + def data_yx(self) -> np.ndarray: + if not self.chroma == 1 and self.nz == 1: + raise InvalidDataShape('Can only return XY array from accessors with a single channel and single z-level') + else: + return self.data[:, :, :, 0, 0] + + @property + def data_yxz(self) -> np.ndarray: + if not self.chroma == 1: + raise InvalidDataShape('Can only return XYZ array from accessors with a single channel') + else: + return self.data[:, :, :, 0, :] + + @property + def smallest(self): + return np.array([p.shape for p in self.get_list()]).min(axis=0), + + @property + def largest(self): + return np.array([p.shape for p in self.get_list()]).max(axis=0) + def export_pyxcz(self, fpath: Path): tzcyx = np.moveaxis( self.pyxcz, # yxcz @@ -449,9 +489,15 @@ class PatchStack(InMemoryDataAccessor): def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) - def get_list(self): - n = self.nz - return [self.data[:, :, 0, zi] for zi in range(0, n)] + def get_channels(self, channels: list, mip: bool = False): + nda = super().get_channels(channels, mip).data + return self._derived_accessor([nda[i][self._slices[i]] for i in range(0, self.count)]) + + def get_list(self, crop=True): + if crop: + return [self._data[i][self._slices[i]] for i in range(0, self.count)] + else: + return [self._data[i] for i in range(0, self.count)] @property def pyxcz(self): @@ -473,6 +519,24 @@ class PatchStack(InMemoryDataAccessor): def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) + def get_object_df(self, mask) -> pd.DataFrame: + """ + Given a mask patch stack of the same size, return a DataFrame summarizing the area and intensity of objects, + assuming the each patch in the patch stack represents a single object. + :param mask of the same dimensions + """ + if not mask.can_mask(self): + raise DataShapeError(f'Patch stack object dataframe expects a mask of the same dimensions') + df = pd.DataFrame([ + { + 'label': i, + 'area': (mask.iat(i).data > 0).sum(), + 'intensity_sum': (self.iat(i).data * (mask.iat(i).data > 0)).sum() + } for i in range(0, self.count) + ]) + df['intensity_mean'] = df['intensity_sum'] / df['area'] + return df + def make_patch_stack_from_file(fpath): # interpret t-dimension as patch position if not Path(fpath).exists(): @@ -515,6 +579,9 @@ class FileNotFoundError(Error): class DataShapeError(Error): pass +class DataTypeError(Error): + pass + class FileWriteError(Error): pass diff --git a/model_server/base/api.py b/model_server/base/api.py index 608738a166768fb1ff8420a8cd0b4a58b170a3be..972c909e22ed28c6bdd115d96c10be8ffa3bf9e3 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,9 +1,11 @@ -from typing import Union +from pydantic import BaseModel, Field +from typing import List, Union from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor -from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError - +from .models import BinaryThresholdSegmentationModel +from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError +from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError app = FastAPI(debug=True) @@ -21,6 +23,21 @@ def read_root(): return {'success': True} +@app.get('/paths/session') +def get_top_session_path(): + return session.get_paths()['session'] + + +@app.get('/paths/inbound') +def get_inbound_path(): + return session.get_paths()['inbound_images'] + + +@app.get('/paths/outbound') +def get_outbound_path(): + return session.get_paths()['outbound_images'] + + @app.get('/paths') def list_session_paths(): return session.get_paths() @@ -32,6 +49,7 @@ def show_session_status(): 'status': 'running', 'models': session.describe_loaded_models(), 'paths': session.get_paths(), + 'rois': session.list_rois(), 'accessors': session.list_accessors(), } @@ -54,8 +72,8 @@ def watch_output_path(path: str): @app.get('/session/restart') -def restart_session(root: str = None) -> dict: - session.restart(root=root) +def restart_session() -> dict: + session.restart() return session.describe_loaded_models() @@ -69,6 +87,24 @@ def list_active_models(): return session.describe_loaded_models() +class BinaryThresholdSegmentationParams(BaseModel): + tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation') + + +@app.put('/models/seg/threshold/load/') +def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict: + result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p) + session.log_info(f'Loaded binary threshold segmentation model {result}') + return {'model_id': result} + + +@app.put('/models/classify/threshold/load') +def load_intensity_threshold_instance_segmentation_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict: + result = session.load_model(IntensityThresholdInstanceMaskSegmentationModel, key=model_id, params=p) + session.log_info(f'Loaded permissive instance segmentation model {result}') + return {'model_id': result} + + @app.get('/accessors') def list_accessors(): return session.list_accessors() @@ -110,4 +146,60 @@ def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) except AccessorIdError as e: raise HTTPException(404, f'Did not find accessor with ID {accessor_id}') except WriteAccessorError as e: - raise HTTPException(409, str(e)) \ No newline at end of file + raise HTTPException(409, str(e)) + +@app.get('/rois') +def list_rois(): + return session.list_rois() + +def _session_roiset(func, roiset_id): + try: + return func(roiset_id) + except RoiSetIdError as e: + raise HTTPException(404, f'Did not find RoiSet with ID {roiset_id}') + +@app.get('/rois/{roiset_id}') +def get_roiset(roiset_id: str): + return _session_roiset(session.get_roiset_info, roiset_id) + +class RoiSetSerializationRecord(BaseModel): + dataframe: str + tight_patch_masks: Union[List[str], None] + +@app.put('/rois/write/{roiset_id}') +def write_roiset_to_file(roiset_id: str, where: Union[str, None] = None) -> RoiSetSerializationRecord: + try: + return session.write_roiset(roiset_id, where) + except RoiSetIdError as e: + raise HTTPException(404, f'Did not find RoiSet with ID {roiset_id}') + except SerializeRoiSetError as e: + raise HTTPException(409, str(e)) + +@app.put('/rois/products/{roiset_id}') +def roiset_export_products( + roiset_id: str, + channel: Union[int, None] = None, + params: RoiSetExportParams = RoiSetExportParams() +) -> dict: + roiset = _session_roiset(session.get_roiset, roiset_id) + products = roiset.get_export_product_accessors(channel, params) + return {k: session.add_accessor(v, f'{roiset_id}_{k}') for k, v in products.items()} + +@app.put('/rois/obmap/{roiset_id}/{model_id}') +def roiset_get_object_map( + roiset_id: str, + model_id: str, + channel: Union[int, None] = None, +) -> str: + products = roiset_export_products( + roiset_id, + channel, + RoiSetExportParams(object_classes=True), + ) + for k, rid in products.items(): + if k == f'object_classes_{model_id}': + return rid + raise HTTPException( + 404, + f'Did not find object map from classification model {model_id} in RoiSet {roiset_id}' + ) \ No newline at end of file diff --git a/model_server/base/models.py b/model_server/base/models.py index eaded2c9f24c8793600019d86a1dcbb5cdab7c52..f1fccd37d99e237c10747a0a83c30db745c1fbf3 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -1,25 +1,21 @@ from abc import ABC, abstractmethod -from math import floor import numpy as np -from pydantic import BaseModel -from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack +from .accessors import GenericImageDataAccessor, PatchStack class Model(ABC): - def __init__(self, autoload=True, params: BaseModel = None): + def __init__(self, autoload: bool = True, info: dict = None): """ Abstract base class for an inference model that uses image data as an input. - :param autoload: automatically load model and dependencies into memory if True - :param params: (optional) BaseModel of model parameters e.g. configuration files required to load model + :param info: optional dictionary of JSON-serializable information to report to API """ self.autoload = autoload - if params: - self.params = params.dict() self.loaded = False + self._info = info if not autoload: return None if self.load(): @@ -28,6 +24,10 @@ class Model(ABC): raise CouldNotLoadModelError() return None + @property + def info(self): + return self._info + @abstractmethod def load(self): """ @@ -37,7 +37,7 @@ class Model(ABC): pass @abstractmethod - def infer(self, *args) -> (object, dict): + def infer(self, *args) -> object: """ Abstract method that carries out the computationally intensive step of running data through a model :param args: @@ -60,7 +60,7 @@ class ImageToImageModel(Model): """ @abstractmethod - def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): + def infer(self, img: GenericImageDataAccessor, *args, **kwargs) -> GenericImageDataAccessor: pass @@ -88,11 +88,38 @@ class SemanticSegmentationModel(ImageToImageModel): return PatchStack(data) -class InstanceSegmentationModel(ImageToImageModel): +class BinaryThresholdSegmentationModel(SemanticSegmentationModel): + def __init__(self, tr: float = 0.5, channel: int = 0): + """ + Model that labels all pixels as class 1 if the intensity in specified channel exceeds a threshold. + :param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range + :param channel: channel to use for thresholding + """ + self.tr = tr + self.channel = channel + self.loaded = self.load() + super().__init__(info={'tr': tr, 'channel': channel}) + + def infer(self, acc: GenericImageDataAccessor) -> GenericImageDataAccessor: + norm_tr = self.tr * acc.dtype_max + return acc.apply(lambda x: x > norm_tr) + + def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + return self.infer(acc, **kwargs) + + def load(self): + return True + + +class InstanceMaskSegmentationModel(ImageToImageModel): """ Base model that exposes a method that returns an instance classification map for a given input image and mask """ + @abstractmethod + def infer(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + pass + @abstractmethod def label_instance_class( self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs @@ -105,6 +132,8 @@ class InstanceSegmentationModel(ImageToImageModel): if img.hw != mask.hw or img.nz != mask.nz: raise InvalidInputImageError('Expect input image and mask to be the same shape') + return self.infer(img, mask) + def label_patch_stack(self, img: PatchStack, mask: PatchStack, allow_multiple=True, force_single=False, **kwargs): """ Call inference on all patches in a PatchStack at once @@ -129,21 +158,6 @@ class InstanceSegmentationModel(ImageToImageModel): return PatchStack(data) -class BinaryThresholdSegmentationModel(SemanticSegmentationModel): - - def __init__(self, tr: float = 0.5): - self.tr = tr - - def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): - return img.apply(lambda x: x > self.tr), {'success': True} - - def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: - return self.infer(img, **kwargs)[0] - - def load(self): - pass - - class Error(Exception): pass diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 037fb7fc015c5646152854f5a7fa4f2258d9808d..8fea5a836d7e66de30063a6149a89dbdf8abfae3 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -1,17 +1,16 @@ from typing import Dict, Union -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field from ..accessors import GenericImageDataAccessor from .router import router from .segment_zproj import segment_zproj_pipeline from .shared import call_pipeline from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams -from ..session import session from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord -from ..models import Model, InstanceSegmentationModel +from ..models import Model, InstanceMaskSegmentationModel class RoiSetObjectMapParams(PipelineParams): @@ -35,10 +34,6 @@ class RoiSetObjectMapParams(PipelineParams): None, description='Object classifier used to classify segmented objectss' ) - pixel_classifier_derived_model_id: Union[str, None] = Field( - None, - description='Pixel classifier used to derive channel(s) as additional inputs to object classification' - ) patches_channel: int = Field( description='Channel of input image used in patches sent to object classifier' ) @@ -55,35 +50,19 @@ class RoiSetObjectMapParams(PipelineParams): 'deproject_channel': None, }) export_params: RoiSetExportParams = RoiSetExportParams() - derived_channels_input_channel: Union[int, None] = Field( - None, - description='Channel of input image from which to compute derived channels; use all if empty' - ) - derived_channels_output_channels: Union[int, list] = Field( - None, - description='Derived channels to send to object classifier; use all if empty' - ) + export_label_interm: bool = False class RoiSetToObjectMapRecord(PipelineRecord): - roiset_table: dict + pass @router.put('/roiset_to_obmap/infer') def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: """ Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification. """ - record, rois = call_pipeline(roiset_object_map_pipeline, p) - - table = rois.get_serializable_dataframe() - - session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table) - ret = RoiSetToObjectMapRecord( - roiset_table=table.to_dict(), - **record.dict() - ) - return ret + return call_pipeline(roiset_object_map_pipeline, p) def roiset_object_map_pipeline( @@ -91,11 +70,11 @@ def roiset_object_map_pipeline( models: Dict[str, Model], **k ) -> (PipelineTrace, RoiSet): - d = PipelineTrace(accessors['accessor']) + d = PipelineTrace(accessors['']) d['mask'] = segment_zproj_pipeline( accessors, - {'model': models['pixel_classifier_segmentation_model']}, + {'': models['pixel_classifier_segmentation_']}, **k['segmentation'], ).last @@ -107,9 +86,9 @@ def roiset_object_map_pipeline( d[ki] = vi # optionally run an object classifier if specified - if obmod := models.get('object_classifier_model'): + if obmod := models.get('object_classifier_'): obmod_name = k['object_classifier_model_id'] - assert isinstance(obmod, InstanceSegmentationModel) + assert isinstance(obmod, InstanceMaskSegmentationModel) rois.classify_by( obmod_name, [k['patches_channel']], diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py index fac1835111872ab563b71d6795982e170d52e61f..2af13d77dafa9563265cca530db3af18e1dd3fa2 100644 --- a/model_server/base/pipelines/segment.py +++ b/model_server/base/pipelines/segment.py @@ -31,11 +31,11 @@ def segment_pipeline( models: Dict[str, Model], **k ) -> PipelineTrace: - d = PipelineTrace(accessors.get('accessor')) - model = models.get('model') + d = PipelineTrace(accessors.get('')) + model = models.get('') if not isinstance(model, SemanticSegmentationModel): - raise IncompatibleModelsError('Expecting a pixel classification model') + raise IncompatibleModelsError('Expecting a semantic segmentation model') if ch := k.get('channel') is not None: d['mono'] = d['input'].get_mono(ch) diff --git a/model_server/base/pipelines/segment_zproj.py b/model_server/base/pipelines/segment_zproj.py index 5f5ed9ba755433ef8445198b80885412e2161d98..a5f459f49aa30b4fcbaabd17d1d55471cd111be1 100644 --- a/model_server/base/pipelines/segment_zproj.py +++ b/model_server/base/pipelines/segment_zproj.py @@ -29,12 +29,12 @@ def segment_zproj_pipeline( models: Dict[str, Model], **k ) -> PipelineTrace: - d = PipelineTrace(accessors.get('accessor')) + d = PipelineTrace(accessors.get('')) if isinstance(k.get('zi'), int): assert 0 < k['zi'] < d.last.nz d['mip'] = d.last.get_zi(k['zi']) else: d['mip'] = d.last.get_mip() - return segment_pipeline({'accessor': d.last}, models, **k) + return segment_pipeline({'': d.last}, models, **k) diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py index df43cd3c6b55af8b99a2425be568c69786fb1ba9..3fbf1d5838cd05017240d11678876b76eca3ff03 100644 --- a/model_server/base/pipelines/shared.py +++ b/model_server/base/pipelines/shared.py @@ -6,7 +6,8 @@ from typing import List, Union from fastapi import HTTPException from pydantic import BaseModel, Field, root_validator -from ..accessors import GenericImageDataAccessor +from ..accessors import GenericImageDataAccessor, InMemoryDataAccessor +from ..roiset import RoiSet from ..session import session, AccessorIdError @@ -40,14 +41,16 @@ class PipelineRecord(BaseModel): interm_accessor_ids: Union[List[str], None] success: bool timer: dict + roiset_id: Union[str, None] = None def call_pipeline(func, p: PipelineParams) -> PipelineRecord: # match accessor IDs to loaded accessor objects accessors_in = {} + for k, v in p.dict().items(): if k.endswith('accessor_id'): - accessors_in[k.split('_id')[0]] = session.get_accessor(v, pop=True) + accessors_in[k.split('accessor_id')[0]] = session.get_accessor(v, pop=True) if len(accessors_in) == 0: raise NoAccessorsFoundError('Expecting as least one valid accessor to run pipeline') @@ -57,7 +60,7 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord: models = {} for k, v in p.dict().items(): if k.endswith('model_id') and v is not None: - models[k.split('_id')[0]] = session.models[v]['object'] + models[k.split('model_id')[0]] = session.models[v]['object'] # call the actual pipeline; expect a single PipelineTrace or a tuple where first element is PipelineTrace ret = func( @@ -66,11 +69,11 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord: **p.dict(), ) if isinstance(ret, PipelineTrace): - steps = ret - misc = None - elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace): - steps = ret[0] - misc = ret[1:] + trace = ret + roiset_id = None + elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace) and isinstance(ret[1], RoiSet): + trace = ret[0] + roiset_id = session.add_roiset(ret[1]) else: raise UnexpectedPipelineReturnError( f'{func.__name__} returned unexpected value of {type(ret)}' @@ -80,7 +83,7 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord: # map intermediate data accessors to accessor IDs if p.keep_interm: interm_ids = [] - acc_interm = steps.accessors(skip_first=True, skip_last=True).items() + acc_interm = trace.accessors(skip_first=True, skip_last=True).items() for i, item in enumerate(acc_interm): stk, acc = item interm_ids.append( @@ -94,7 +97,7 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord: # map final result to an accessor ID result_id = session.add_accessor( - steps.last, + trace.last, accessor_id=f'{p.accessor_id}_{func.__name__}_result' ) @@ -102,14 +105,11 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord: output_accessor_id=result_id, interm_accessor_ids=interm_ids, success=True, - timer=steps.times + timer=trace.times, + roiset_id=roiset_id, ) - # return miscellaneous objects if pipeline returns these - if misc: - return record, *misc - else: - return record + return record class PipelineTrace(OrderedDict): @@ -132,13 +132,33 @@ class PipelineTrace(OrderedDict): self['input'] = start_acc def __setitem__(self, key, value: GenericImageDataAccessor): - if self.enforce_accessors: - assert isinstance(value, GenericImageDataAccessor), f'Pipeline trace expects data accessor type' + if isinstance(value, GenericImageDataAccessor): + acc = value + else: + if self.enforce_accessors: + raise NoAccessorsFoundError(f'Pipeline trace expects data accessor type') + else: + acc = InMemoryDataAccessor(value) if not self.allow_overwrite and key in self.keys(): raise KeyAlreadyExists(f'key {key} already exists in pipeline trace') self.timer.__setitem__(key, self.tfunc() - self.last_time) self.last_time = self.tfunc() - return super().__setitem__(key, value) + return super().__setitem__(key, acc) + + def append(self, tr, skip_first=True): + new_tr = self.copy() + for k, v in tr.items(): + if skip_first and v == tr.first: + continue + dt = tr.timer[k] + if k == 'input': + k = 'appended_input' + if not self.allow_overwrite and k in self.keys(): + raise KeyAlreadyExists(f'Trying to append trace with key {k} that already exists') + new_tr.__setitem__(k, v) + new_tr.timer.__setitem__(k, dt) + new_tr.last_time = self.tfunc() + return new_tr @property def times(self): @@ -147,6 +167,14 @@ class PipelineTrace(OrderedDict): """ return {k: self.timer[k] for k in self.keys()} + @property + def first(self): + """ + Return first item + :return: + """ + return list(self.values())[0] + @property def last(self): """ diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index cfeec9eec6d4de1e2d56a358055140a58893b4a6..f7cc75f2420df7ddda4990758080162bcdd9b414 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -17,7 +17,7 @@ from skimage.measure import approximate_polygon, find_contours, label, points_in from skimage.morphology import binary_dilation, disk from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file -from .models import InstanceSegmentationModel +from .models import InstanceMaskSegmentationModel from .process import get_safe_contours, pad, rescale, resample_to_8bit, make_rgb from .annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image from .accessors import generate_file_accessor, PatchStack @@ -48,6 +48,8 @@ class RoiFilterRange(BaseModel): class RoiFilter(BaseModel): area: Union[RoiFilterRange, None] = None + diag: Union[RoiFilterRange, None] = None + min_hw: Union[RoiFilterRange, None] = None class RoiSetMetaParams(BaseModel): @@ -75,7 +77,7 @@ def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connec """ if allow_3d and connect_3d: nda_la = label( - acc_seg_mask.data_xyz, + acc_seg_mask.data_yxz, connectivity=3, ).astype('uint16') return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) @@ -84,7 +86,7 @@ def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connec la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') for zi in range(0, acc_seg_mask.nz): la_2d = label( - acc_seg_mask.data_xyz[:, :, zi], + acc_seg_mask.data_yxz[:, :, zi], connectivity=2, ).astype('uint16') la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla @@ -94,7 +96,7 @@ def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connec else: return InMemoryDataAccessor( label( - acc_seg_mask.get_mip().data_xy, + acc_seg_mask.get_mip().data_yx, connectivity=1, ).astype('uint16') ) @@ -115,7 +117,9 @@ def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: query_str = 'label > 0' # always true if filters is not None: # parse filters for k, val in filters.dict(exclude_unset=True).items(): - assert k in ('area') + assert k in ('area', 'diag', 'min_hw') + if val is None: + continue vmin = val['min'] vmax = val['max'] assert vmin >= 0 @@ -212,7 +216,7 @@ def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Dat return dfbb -def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_channel=None) -> pd.DataFrame: +def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_channel=None, filters=None) -> pd.DataFrame: """ Build dataframe that associate object IDs with summary stats; :param acc_raw: accessor to raw image data @@ -234,10 +238,10 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_chann ) acc_raw.get_mono(deproject_channel) - zi_map = acc_raw.get_mono(deproject_channel).get_z_argmax().data_xy.astype('uint16') + zi_map = acc_raw.get_mono(deproject_channel).get_z_argmax().data_yx.astype('uint16') assert len(zi_map.shape) == 2 df = pd.DataFrame(regionprops_table( - acc_obj_ids.data_xy, + acc_obj_ids.data_yx, intensity_image=zi_map, properties=('label', 'area', 'intensity_mean', 'bbox') )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) @@ -245,7 +249,7 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_chann else: # objects' z-coordinates come from arg of max count in object identities map df = pd.DataFrame(regionprops_table( - acc_obj_ids.data_xyz, + acc_obj_ids.data_yxz, properties=('label', 'area', 'bbox') )).rename(columns={ 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' @@ -258,17 +262,19 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_chann df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by) + df_fil = filter_df(df, filters) + def _make_binary_mask(r): acc = InMemoryDataAccessor(acc_obj_ids.data == r.label) - cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy + cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_yx return cropped - df['binary_mask'] = df.apply( + df_fil['binary_mask'] = df_fil.apply( _make_binary_mask, axis=1, result_type='reduce', ) - return df + return df_fil def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame: @@ -278,6 +284,9 @@ def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame: df['h'] = df['y1'] - df['y0'] df['w'] = df['x1'] - df['x0'] + df['diag'] = (df['w']**2 + df['h']**2).apply(sqrt) + df['min_hw'] = df[['w', 'h']].min(axis=1) + ebxy, ebz = expand_box_by df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) @@ -387,13 +396,11 @@ class RoiSet(object): """ assert acc_obj_ids.chroma == 1 - df = filter_df( - make_df_from_object_ids( - acc_raw, acc_obj_ids, - expand_box_by=params.expand_box_by, - deproject_channel=params.deproject_channel, - ), - params.filters, + df = make_df_from_object_ids( + acc_raw, acc_obj_ids, + expand_box_by=params.expand_box_by, + deproject_channel=params.deproject_channel, + filters=params.filters, ) return cls(acc_raw, df, params) @@ -508,6 +515,14 @@ class RoiSet(object): params=params, ) + @property + def info(self): + return { + 'raw_shape_dict': self.acc_raw.shape_dict, + 'count': self.count, + 'classify_by': self.classification_columns, + } + def get_df(self) -> pd.DataFrame: return self._df @@ -569,7 +584,7 @@ class RoiSet(object): def classify_by( self, name: str, channels: list[int], - object_classification_model: InstanceSegmentationModel, + object_classification_model: InstanceMaskSegmentationModel, ): """ Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing @@ -637,10 +652,11 @@ class RoiSet(object): return df_overlaps - def get_object_class_map(self, name: str) -> InMemoryDataAccessor: + def get_object_class_map(self, name: str, filter_by: Union[List, None] = None) -> InMemoryDataAccessor: """ For a given classification result, return a map where object IDs are replaced by each object's class :param name: name of the classification result, same as passed to RoiSet.classify_by() + :param filter_by: only include ROIs if the intersection of all specified classifications is True :return: accessor of object class map """ colname = ('classify_by_' + name) @@ -650,8 +666,11 @@ class RoiSet(object): def _label_object_class(roi): om[self.acc_obj_ids.data == roi.label] = roi[colname] - self._df.apply(_label_object_class, axis=1) - + if filter_by is None: + self._df.apply(_label_object_class, axis=1) + else: + pd_fil = self._df[[f'classify_by_{fb}' for fb in filter_by]] + self._df.loc[pd_fil.all(axis=1), :].apply(_label_object_class, axis=1) return InMemoryDataAccessor(om) def get_serializable_dataframe(self) -> pd.DataFrame: @@ -735,7 +754,7 @@ class RoiSet(object): if white_channel: assert white_channel < raw.chroma - mono = raw.get_mono(white_channel).data_xyz + mono = raw.get_mono(white_channel).data_yxz stack = np.stack([mono, mono, mono], axis=2) else: stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype) @@ -748,7 +767,7 @@ class RoiSet(object): stack[:, :, ii, :] = safe_add( stack[:, :, ii, :], # either black or grayscale channel rgb_overlay_weights[ii], - raw.get_mono(ci).data_xyz + raw.get_mono(ci).data_yxz ) else: if white_channel is not None: # interpret as just a single channel @@ -763,7 +782,7 @@ class RoiSet(object): annotate_rgb = True break if annotate_rgb: # make RGB patches anyway to include annotation color - mono = raw.get_mono(white_channel).data_xyz + mono = raw.get_mono(white_channel).data_yxz stack = np.stack([mono, mono, mono], axis=2) else: # make monochrome patches stack = raw.get_mono(white_channel).data @@ -954,11 +973,12 @@ class RoiSet(object): return interm - def serialize(self, where: Path, prefix='roiset') -> dict: + def serialize(self, where: Path, prefix='roiset', allow_overwrite=True) -> dict: """ Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks :param where: path of directory in which to write files :param prefix: (optional) prefix + :param allow_overwrite: freely overwrite CSV file of same name if True :return: nested dict of Path objects describing the locations of export products """ record = {} @@ -977,6 +997,9 @@ class RoiSet(object): record['tight_patch_masks'] = list(se_pa) csv_path = where / 'dataframe' / (prefix + '.csv') + if not allow_overwrite and csv_path.exists(): + raise SerializeRoiSetError(f'Cannot overwrite RoiSet file {csv_path.__str__()}') + csv_path.parent.mkdir(parents=True, exist_ok=True) self.export_dataframe(csv_path) @@ -1045,10 +1068,10 @@ class RoiSet(object): try: ma_acc = generate_file_accessor(pa_masks / fname) assert ma_acc.chroma == 1 and ma_acc.nz == 1 - mask_data = ma_acc.data_xy / np.iinfo(ma_acc.data.dtype).max + mask_data = ma_acc.data_yx / ma_acc.dtype_max return mask_data except Exception as e: - raise DeserializeRoiSet(e) + raise DeserializeRoiSetError(e) df['binary_mask'] = df.apply(_read_binary_mask, axis=1) id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) @@ -1069,6 +1092,7 @@ class RoiSet(object): class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams): derived_channels: bool = False + class RoiSetWithDerivedChannels(RoiSet): def __init__(self, *a, **k): @@ -1077,7 +1101,7 @@ class RoiSetWithDerivedChannels(RoiSet): def classify_by( self, name: str, channels: list[int], - object_classification_model: InstanceSegmentationModel, + object_classification_model: InstanceMaskSegmentationModel, derived_channel_functions: list[callable] = None ): """ @@ -1093,36 +1117,27 @@ class RoiSetWithDerivedChannels(RoiSet): :return: None """ - raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels + acc_in = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) if derived_channel_functions is not None: - mono_data = [raw_acc.get_mono(c).data for c in range(0, raw_acc.chroma)] for fcn in derived_channel_functions: - der = fcn(raw_acc) # returns patch stack - if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']: - raise DerivedChannelError( - f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}' - ) + der = fcn(acc_in) # returns patch stack self.accs_derived.append(der) # combine channels - data_derived = [acc.data for acc in self.accs_derived] - input_acc = PatchStack( - np.concatenate( - [*mono_data, *data_derived], - axis=raw_acc._ga('C') - ) - ) + acc_app = acc_in + for acc_der in self.accs_derived: + acc_app = acc_app.append_channels(acc_der) else: - input_acc = raw_acc + acc_app = acc_in # do this on a patch basis, i.e. only one object per frame obmap_patches = object_classification_model.label_patch_stack( - input_acc, + acc_app, self.get_patch_masks_acc(expanded=False, pad_to=None) ) - self._df['classify_by_' + name] = pd.Series(dtype='Int64') + self._df['classify_by_' + name] = pd.Series(dtype='Int16') for i, roi in enumerate(self): oc = np.unique( @@ -1153,29 +1168,100 @@ class RoiSetWithDerivedChannels(RoiSet): record[k].append(str(fp)) return record + +class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel): + def __init__(self, tr: float = 0.5): + """ + Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all + objects as class 1 if threshold = 0.0 + :param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range + :param channel: channel to use for thresholding + """ + self.tr = tr + self.loaded = self.load() + super().__init__(info={'tr': tr}) + + def load(self): + return True + + def infer( + self, + img: GenericImageDataAccessor, + mask: GenericImageDataAccessor, + allow_3d: bool = False, + connect_3d: bool = True, + ) -> GenericImageDataAccessor: + if img.chroma != 1: + raise ShapeMismatchError( + f'IntensityThresholdInstanceMaskSegmentationModel expects 1 channel but received {img.chroma}' + ) + if isinstance(img, PatchStack): # assume one object per patch + df = img.get_object_df(mask) + om = np.zeros(mask.shape, 'uint16') + def _label_patch_class(la): + om[la] = (mask.iat(la).data > 0) * 1 + df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_patch_class) + return PatchStack(om) + else: + labels = get_label_ids(mask) + df = pd.DataFrame(regionprops_table( + labels.data_yxz, + intensity_image=img.data_yxz, + properties=('label', 'area', 'intensity_mean') + )) + + om = np.zeros(labels.shape, labels.dtype) + def _label_object_class(la): + om[labels.data == la] = 1 + df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_object_class) + return InMemoryDataAccessor(om) + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + super().label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) + + class Error(Exception): pass + class BoundingBoxError(Error): pass -class DeserializeRoiSet(Error): + +class DataFrameQueryError(Error): + pass + + +class DeserializeRoiSetError(Error): pass + +class SerializeRoiSetError(Error): + pass + + class NoDeprojectChannelSpecifiedError(Error): pass + class DerivedChannelError(Error): pass + class MissingSegmentationError(Error): pass + class PatchMaskShapeError(Error): pass + class ShapeMismatchError(Error): pass + class MissingInstanceLabelsError(Error): pass \ No newline at end of file diff --git a/model_server/base/session.py b/model_server/base/session.py index 7e06003f8be9407f6a8ee4f1816b70fac12f1900..c2c596ed28cd43d7625e68589824aeeb5c7afa4d 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -12,6 +12,7 @@ import pandas as pd from ..conf import defaults from .accessors import GenericImageDataAccessor, PatchStack from .models import Model +from .roiset import RoiSet logger = logging.getLogger(__name__) @@ -38,10 +39,11 @@ class _Session(object): log_format = '%(asctime)s - %(levelname)s - %(message)s' - def __init__(self, root: str = None): - self.models = {} # model_id : model object - self.paths = self.make_paths(root) + def __init__(self): + self.models = {} # model_id : model object + self.paths = self.make_paths() self.accessors = OrderedDict() + self.rois = OrderedDict() self.logfile = self.paths['logs'] / f'session.log' logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format) @@ -93,7 +95,7 @@ class _Session(object): raise AccessorIdError(f'Access with ID {accessor_id} already exists') if accessor_id is None: idx = len(self.accessors) - accessor_id = f'auto_{idx:06d}' + accessor_id = f'acc_{idx:06d}' self.accessors[accessor_id] = {'loaded': True, 'object': acc, **acc.info} return accessor_id @@ -187,20 +189,99 @@ class _Session(object): self.accessors[acc_id]['filepath'] = fp.__str__() return fp.name + def add_roiset(self, roiset: RoiSet, roiset_id: str = None) -> str: + """ + Add an RoiSet to session context + :param roiset: RoiSet to add + :param roiset_id: unique ID, or autogenerate if None + :return: ID of RoiSet + """ + if roiset_id in self.rois.keys(): + raise AccessorIdError(f'RoiSet with ID {roiset_id} already exists') + if roiset_id is None: + idx = len(self.rois) + roiset_id = f'roiset_{idx:06d}' + self.rois[roiset_id] = {'loaded': True, 'object': roiset, **roiset.info} + return roiset_id + + def del_roiset(self, roiset_id: str) -> str: + """ + Remove RoiSet object but retain its info dictionary + :param roiset_id: RoiSet's ID + :return: ID of RoiSet + """ + if roiset_id not in self.rois.keys(): + raise RoiSetIdError(f'No RoiSet with ID {roiset_id} is registered') + v = self.rois[roiset_id] + if isinstance(v, dict) and v['loaded'] is False: + logger.warning(f'RoiSet {roiset_id} is already deleted') + else: + assert isinstance(v['object'], RoiSet) + v['loaded'] = False + v['object'] = None + return roiset_id + + def list_rois(self) -> dict: + """ + List information about all ROIs in JSON-readable format + """ + if len(self.rois): + return pd.DataFrame(self.rois).drop('object').to_dict() + else: + return {} + + def get_roiset_info(self, roiset_id: str) -> dict: + """ + Get information about a single RoiSe + """ + if roiset_id not in self.rois.keys(): + raise RoiSetIdError(f'No RoiSet with ID {roiset_id} is registered') + return self.list_rois()[roiset_id] + + def get_roiset(self, roiset_id: str, pop: bool = False) -> RoiSet: + """ + Return an RoiSet object + :param acc_id: RoiSet's ID + :param pop: remove object from session RoiSet registry if True + :return: RoiSet object + """ + if roiset_id not in self.rois.keys(): + raise RoiSetIdError(f'No RoiSet with ID {roiset_id} is registered') + acc = self.rois[roiset_id]['object'] + if pop: + self.del_roiset(roiset_id) + return acc + + def write_roiset(self, roiset_id: str, where: Union[str, None] = None) -> str: + """ + Write an RoiSet to file and unload it from the session + :param roiset_id: RoiSet's ID + :param filename: force use of a specific filename, raise InvalidPathError if this already exists + :return: name of file + """ + if where is None: + fp = self.paths['rois'] + else: + fp = Path(where) + + roiset = self.get_roiset(roiset_id, pop=True) + rec = roiset.serialize(fp, prefix=roiset_id, allow_overwrite=False) + self.rois[roiset_id]['files'] = rec + return rec + @staticmethod - def make_paths(root: str = None) -> dict: + def make_paths() -> dict: """ - Set paths where images, logs, etc. are located in this session - :param root: absolute path to top-level directory + Set paths where images, logs, etc. are located in this session; can set custom session root data directory + with SVLT_SESSION_ROOT environmental variable :return: dictionary of session paths """ - if root is None: - root_path = Path(defaults.root) - else: - root_path = Path(root) + + root = os.environ.get('SVLT_SESSION_ROOT', defaults.root) + root_path = Path(root) sid = _Session.create_session_id(root_path) - paths = {'root': root_path} - for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']: + paths = {'root': root_path, 'session': root_path / sid} + for pk in ['inbound_images', 'outbound_images', 'logs', 'rois', 'tables']: pa = root_path / sid / defaults.subdirectories[pk] paths[pk] = pa try: @@ -248,10 +329,13 @@ class _Session(object): Load an instance of a given model class and attach to this session's model registry :param ModelClass: subclass of Model :param key: unique identifier of model, or autogenerate if None - :param params: optional parameters that are passed to the model's construct + :param params: optional parameters that are passed to the model's constructor :return: model_id of loaded model """ - mi = ModelClass(params=params) + if params: + mi = ModelClass(**params.dict()) + else: + mi = ModelClass() assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' ii = 0 @@ -268,7 +352,7 @@ class _Session(object): self.models[key] = { 'object': mi, - 'params': getattr(mi, 'params', None) + 'info': getattr(mi, 'info', None) } self.log_info(f'Loaded model {key}') return key @@ -277,7 +361,7 @@ class _Session(object): return { k: { 'class': self.models[k]['object'].__class__.__name__, - 'params': self.models[k]['params'], + 'info': self.models[k]['info'], } for k in self.models.keys() } @@ -291,15 +375,15 @@ class _Session(object): models = self.describe_loaded_models() for mid, det in models.items(): if is_path: - if PureWindowsPath(det.get('params').get(key)).as_posix() == Path(value).as_posix(): + if PureWindowsPath(det.get('info').get(key)).as_posix() == Path(value).as_posix(): return mid else: - if det.get('params').get(key) == value: + if det.get('info').get(key) == value: return mid return None def restart(self, **kwargs): - self.__init__(**kwargs) + self.__init__() # create singleton instance @@ -318,6 +402,9 @@ class CouldNotInstantiateModelError(Error): class AccessorIdError(Error): pass +class RoiSetIdError(Error): + pass + class WriteAccessorError(Error): pass diff --git a/model_server/conf/defaults.py b/model_server/conf/defaults.py index ff4f9040ceb9d9d14d947bbbce19185491d490fd..879cdb6a95f0995cee7ee358653ad647059136d1 100644 --- a/model_server/conf/defaults.py +++ b/model_server/conf/defaults.py @@ -1,11 +1,12 @@ from pathlib import Path -root = Path.home() / 'model_server' / 'sessions' +root = Path.home() / 'svlt' / 'sessions' subdirectories = { 'logs': 'logs', 'inbound_images': 'images/inbound', 'outbound_images': 'images/outbound', + 'rois': 'rois', 'tables': 'tables', } server_conf = { diff --git a/model_server/conf/startup.py b/model_server/conf/startup.py new file mode 100644 index 0000000000000000000000000000000000000000..b340363769b2ce49f0c994e3812c4031a1ec4fda --- /dev/null +++ b/model_server/conf/startup.py @@ -0,0 +1,58 @@ +from multiprocessing import Process +import os +import requests +from requests.adapters import HTTPAdapter +from urllib3 import Retry +import uvicorn +import webbrowser + + +def main(host, port, confpath, reload, debug, root) -> None: + + if root: + os.environ['SVLT_SESSION_ROOT'] = root + + server_process = Process( + target=uvicorn.run, + args=(f'{confpath}:app',), + kwargs={ + 'app_dir': '..', + 'host': host, + 'port': int(port), + 'log_level': 'debug', + 'reload': reload, + }, + daemon=(reload is False), + ) + url = f'http://{host}:{int(port):04d}/' + print(url) + server_process.start() + + try: + sesh = requests.Session() + retries = Retry( + total=10, + backoff_factor=0.5, + ) + sesh.mount('http://', HTTPAdapter(max_retries=retries)) + resp = sesh.get(url + 'status') + assert resp.status_code == 200 + except Exception: + print('Error starting server') + server_process.terminate() + exit() + + webbrowser.open(url + 'status', new=1, autoraise=True) + + if debug: + print('Running in debug mode') + print('Type "STOP" to stop server') + input_str = '' + while input_str.upper() != 'STOP': + input_str = input() + session_path = requests.get(url + 'paths/session').text + + print(f'Terminating server.\nSession data are located in {session_path}') + + server_process.terminate() + return session_path diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 963527d7f4eb70b58af458757d3d0631bcfcf293..e5065c1050e3ed5c6a7bf3dc6720ad938e61eaf5 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -14,7 +14,7 @@ from urllib3 import Retry from .fastapi import app from ..base.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from ..base.models import SemanticSegmentationModel, InstanceSegmentationModel +from ..base.models import SemanticSegmentationModel, InstanceMaskSegmentationModel from ..base.session import session from ..base.accessors import generate_file_accessor @@ -52,7 +52,7 @@ def load_dummy_model() -> dict: @test_router.put('/models/dummy_instance/load/') def load_dummy_model() -> dict: - mid = session.load_model(DummyInstanceSegmentationModel) + mid = session.load_model(DummyInstanceMaskSegmentationModel) session.log_info(f'Loaded model {mid}') return {'model_id': mid} @@ -89,34 +89,54 @@ class TestServerBaseClass(unittest.TestCase): sesh.mount('http://', requests.adapters.HTTPAdapter(max_retries=retries)) return sesh - def _get(self, endpoint): - return self._get_sesh().get(self.uri + endpoint) + def assertGetSuccess(self, endpoint): + resp = self._get_sesh().get(self.uri + endpoint) + self.assertEqual(resp.status_code, 200, resp.text) + return resp.json() - def _put(self, endpoint, query=None, body=None): - return self._get_sesh().put( + def assertGetFailure(self, endpoint, code): + resp = self._get_sesh().get(self.uri + endpoint) + self.assertEqual(resp.status_code, code) + return resp + + def assertPutSuccess(self, endpoint, query={}, body={}): + resp = self._get_sesh().put( self.uri + endpoint, params=query, data=json.dumps(body) ) + self.assertEqual(resp.status_code, 200, resp.text) + return resp.json() + + def assertPutFailure(self, endpoint, code, query=None, body=None): + resp = self._get_sesh().put( + self.uri + endpoint, + params=query, + data=json.dumps(body) + ) + self.assertEqual(resp.status_code, code) + return resp + def tearDown(self) -> None: self.server_process.terminate() self.server_process.join() def copy_input_file_to_server(self): - resp = self._get('paths') - pa = resp.json()['inbound_images'] + r = self.assertGetSuccess('paths') + pa = r['inbound_images'] copyfile( self.input_data['path'], Path(pa) / self.input_data['name'] ) return self.input_data['name'] - def get_accessor(self, accessor_id, filename=None): - resp = self._put(f'/accessors/write_to_file/{accessor_id}', query={'filename': filename}) - where_out = self._get('paths').json()['outbound_images'] - fp_out = (Path(where_out) / resp.json()) + def get_accessor(self, accessor_id, filename=None, copy_to=None): + r = self.assertPutSuccess(f'/accessors/write_to_file/{accessor_id}', query={'filename': filename}) + fp_out = Path(self.assertGetSuccess('paths')['outbound_images']) / r self.assertTrue(fp_out.exists()) + if copy_to: + copyfile(fp_out, Path(copy_to) / f'normal_{fp_out.name}') return generate_file_accessor(fp_out) @@ -187,7 +207,7 @@ class DummySemanticSegmentationModel(SemanticSegmentationModel): def load(self): return True - def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): + def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor: super().infer(img) w = img.shape_dict['X'] h = img.shape_dict['Y'] @@ -201,7 +221,7 @@ class DummySemanticSegmentationModel(SemanticSegmentationModel): return mask -class DummyInstanceSegmentationModel(InstanceSegmentationModel): +class DummyInstanceMaskSegmentationModel(InstanceMaskSegmentationModel): model_id = 'dummy_pass_input_mask' @@ -221,5 +241,5 @@ class DummyInstanceSegmentationModel(InstanceSegmentationModel): """ Returns a trivial segmentation, i.e. the input mask with value 1 """ - super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) + super(DummyInstanceMaskSegmentationModel, self).label_instance_class(img, mask, **kwargs) return self.infer(img, mask) diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index d098e03983e02c978231ca555826b567726b8dd5..8ba5ce04f33217f30690f93dc6d0edc738f6b9f9 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -2,38 +2,28 @@ import json from logging import getLogger import os from pathlib import Path -from typing import Union import warnings import numpy as np -from pydantic import BaseModel, Field import vigra import model_server.extensions.ilastik.conf from ...base.accessors import PatchStack from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from ...base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel +from ...base.models import Model, ImageToImageModel, InstanceMaskSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel -class IlastikParams(BaseModel): - project_file: str = Field(description='(*.ilp) ilastik project filename') - duplicate: bool = Field( - True, - description='Load another instance of the same project file if True; return existing one if False' - ) - model_id: Union[str, None] = Field(None, description='Unique identifier of the model, or autogenerate if empty') class IlastikModel(Model): - def __init__(self, params: IlastikParams, autoload=True, enforce_embedded=True): + def __init__(self, project_file, autoload=True, enforce_embedded=True, **kwargs): """ Base class for models that run via ilastik shell API - :param params: - project_file: path to ilastik project file + :param: project_file: path to ilastik project file :param autoload: automatically load model into memory if true :param enforce_embedded: raise an error if all input data are not embedded in the project file, i.e. on the filesystem """ - pf = Path(params.project_file) + pf = Path(project_file) self.enforce_embedded = enforce_embedded if pf.is_absolute(): pap = pf @@ -46,7 +36,7 @@ class IlastikModel(Model): raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file') self.shell = None - super().__init__(autoload, params) + super().__init__(autoload, info={'project_file': project_file, **kwargs}) def load(self): # suppress warnings when loading ilastik app @@ -95,6 +85,14 @@ class IlastikModel(Model): dd[ci] = 1 return dd + @property + def labels(self): + return [] + + @property + def info(self): + return {**self._info, 'labels': self.labels} + @property def model_chroma(self): return self.model_shape_dict['C'] @@ -103,16 +101,18 @@ class IlastikModel(Model): def model_3d(self): return self.model_shape_dict['Z'] > 1 -class IlastikPixelClassifierParams(IlastikParams): - px_class: int = 0 - px_prob_threshold: float = 0.5 class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): - model_id = 'ilastik_pixel_classification' operations = ['segment', ] - def __init__(self, params: IlastikPixelClassifierParams, **kwargs): - super(IlastikPixelClassifierModel, self).__init__(params, **kwargs) + def __init__(self, px_class: int = 0, px_prob_threshold: float = 0.5, **kwargs): + self.px_class = px_class + self.px_prob_threshold = px_prob_threshold + super(IlastikPixelClassifierModel, self).__init__( + px_class=px_class, + px_prob_threshold=px_prob_threshold, + **kwargs + ) @staticmethod def get_workflow(): @@ -150,32 +150,42 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): [1, 2, 3, 0], [0, 1, 2, 3] ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + return InMemoryDataAccessor(data=yxcz) - def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict): + def infer_patch_stack(self, img: PatchStack, crop=True, normalize=False, **kwargs) -> (np.ndarray, dict): """ Iterative over a patch stack, call inference separately on each cropped patch + :param img: patch stack of input data + :param crop: pass list of cropped (generally non-uniform) patches to classifier if True + :param normalize: scale the inference result (generally 0.0 to 1.0) to the range of img if True """ - from ilastik.applets.featureSelection.opFeatureSelection import FeatureSelectionConstraintError - - nc = len(self.labels) - data = np.zeros((img.count, *img.hw, nc, img.nz), dtype=float) # interpret as PYXCZ - for i in range(0, img.count): - sl = img.get_slice_at(i) - try: - data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True))[0].data - except FeatureSelectionConstraintError: # occurs occasionally on small patches - continue - return PatchStack(data), {'success': True} + dsi = [ + { + 'Raw Data': self.PreloadedArrayDatasetInfo( + preloaded_array=vigra.taggedView(patch, 'yxcz')) + + } for patch in img.get_list(crop=crop) + ] + pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] + yxcz = [np.moveaxis(pm, [1, 2, 3, 0], [0, 1, 2, 3]) for pm in pxmaps] + if normalize: + return PatchStack(yxcz).apply( + lambda x: (x * img.dtype_max).astype(img.dtype) + ) + else: + return PatchStack(yxcz) def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs): - pxmap, _ = self.infer(img) - mask = pxmap.get_mono(self.params['px_class']).apply(lambda x: x > self.params['px_prob_threshold']) + pxmap = self.infer(img) + mask = pxmap.get_mono( + self.px_class, + ).apply( + lambda x: x > kwargs.get('px_prob_threshold', self.px_prob_threshold), + ) return mask -class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): - model_id = 'ilastik_object_classification_from_segmentation' +class IlastikObjectClassifierFromMaskSegmentationModel(IlastikModel, InstanceMaskSegmentationModel): @staticmethod def _make_8bit_mask(nda): @@ -189,6 +199,11 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowBinary + @property + def labels(self): + h5 = self.shell.projectManager.currentProjectFile + return [None] + [l.decode() for l in h5['ObjectClassification/LabelNames'][()]] + def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): if self.model_chroma != input_img.chroma: raise IlastikInputShapeError( @@ -233,29 +248,32 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment [0, 1, 2, 3, 4], [0, 4, 1, 2, 3] ) - return PatchStack(data=pyxcz), {'success': True} + return PatchStack(data=pyxcz) else: yxcz = np.moveaxis( obmaps[0], [1, 2, 3, 0], [0, 1, 2, 3] ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + return InMemoryDataAccessor(data=yxcz) def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): - super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) - obmap, _ = self.infer(img, mask) - return obmap + super(IlastikObjectClassifierFromMaskSegmentationModel, self).label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel): - model_id = 'ilastik_object_classification_from_pixel_predictions' @staticmethod def get_workflow(): from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction return ObjectClassificationWorkflowPrediction + @property + def labels(self): + h5 = self.shell.projectManager.currentProjectFile + return [None] + [l.decode() for l in h5['ObjectClassification/LabelNames'][()]] + def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): if self.model_chroma != input_img.chroma: raise IlastikInputShapeError( @@ -292,14 +310,14 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag [0, 1, 2, 3, 4], [0, 4, 1, 2, 3] ) - return PatchStack(data=pyxcz), {'success': True} + return PatchStack(data=pyxcz) else: yxcz = np.moveaxis( obmaps[0], [1, 2, 3, 0], [0, 1, 2, 3] ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + return InMemoryDataAccessor(data=yxcz) def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs): @@ -320,8 +338,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag pxch = kwargs.get('pixel_classification_channel', 0) pxtr = kwargs.get('pixel_classification_threshold', 0.5) mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) - obmap, _ = self.infer(img, mask) - return obmap + return self.infer(img, mask) class Error(Exception): diff --git a/model_server/extensions/ilastik/pipelines/px_then_ob.py b/model_server/extensions/ilastik/pipelines/px_then_ob.py index 1aa64615479cad29f01b7d7e416d1fec2f3c9568..335e4115d0b9d0cdb2f9e405bfa402e88efc5a36 100644 --- a/model_server/extensions/ilastik/pipelines/px_then_ob.py +++ b/model_server/extensions/ilastik/pipelines/px_then_ob.py @@ -42,23 +42,23 @@ def pixel_then_object_classification_pipeline( **k ) -> PxThenObRecord: - if not isinstance(models['px_model'], IlastikPixelClassifierModel): + if not isinstance(models['px_'], IlastikPixelClassifierModel): raise IncompatibleModelsError( f'Expecting px_model to be an ilastik pixel classification model' ) - if not isinstance(models['ob_model'], IlastikObjectClassifierFromPixelPredictionsModel): + if not isinstance(models['ob_'], IlastikObjectClassifierFromPixelPredictionsModel): raise IncompatibleModelsError( f'Expecting ob_model to be an ilastik object classification from pixel predictions model' ) - d = PipelineTrace(accessors['accessor']) + d = PipelineTrace(accessors['']) if (ch := k.get('channel')) is not None: channels = [ch] else: channels = range(0, d['input'].chroma) d['select_channels'] = d.last.get_channels(channels, mip=k.get('mip', False)) - d['pxmap'], _ = models['px_model'].infer(d.last) - d['ob_map'], _ = models['ob_model'].infer(d['select_channels'], d['pxmap']) + d['pxmap'] = models['px_'].infer(d.last) + d['ob_map'] = models['ob_'].infer(d['select_channels'], d['pxmap']) return d diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 20a1ef24235c0c4ee3d8cc101ffba9e92d86560b..b909e0f654256cf9e1d08d0f33dced1e401f14bc 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -1,4 +1,5 @@ from fastapi import APIRouter +from pydantic import BaseModel, Field from model_server.base.session import session @@ -13,43 +14,58 @@ router = APIRouter( import model_server.extensions.ilastik.pipelines.px_then_ob router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router) + +class IlastikParams(BaseModel): + project_file: str = Field(description='(*.ilp) ilastik project filename') + duplicate: bool = Field( + True, + description='Load another instance of the same project file if True; return existing one if False' + ) + +class IlastikPixelClassifierParams(IlastikParams): + px_class: int = 0 + px_prob_threshold: float = 0.5 + @router.put('/seg/load/') -def load_px_model(p: ilm.IlastikPixelClassifierParams) -> dict: +def load_px_model(p: IlastikPixelClassifierParams, model_id=None) -> dict: """ Load an ilastik pixel classifier model from its project file """ return load_ilastik_model( ilm.IlastikPixelClassifierModel, p, + model_id=model_id, ) @router.put('/pxmap_to_obj/load/') -def load_pxmap_to_obj_model(p: ilm.IlastikParams) -> dict: +def load_pxmap_to_obj_model(p: IlastikParams, model_id=None) -> dict: """ Load an ilastik object classifier from pixel predictions model from its project file """ return load_ilastik_model( ilm.IlastikObjectClassifierFromPixelPredictionsModel, p, + model_id=model_id, ) @router.put('/seg_to_obj/load/') -def load_seg_to_obj_model(p: ilm.IlastikParams) -> dict: +def load_seg_to_obj_model(p: IlastikParams, model_id=None) -> dict: """ Load an ilastik object classifier from segmentation model from its project file """ return load_ilastik_model( - ilm.IlastikObjectClassifierFromSegmentationModel, + ilm.IlastikObjectClassifierFromMaskSegmentationModel, p, + model_id=model_id, ) -def load_ilastik_model(model_class: ilm.IlastikModel, p: ilm.IlastikParams) -> dict: - project_file = p.project_file +def load_ilastik_model(model_class: ilm.IlastikModel, p: IlastikParams, model_id=None) -> dict: + pf = p.project_file if not p.duplicate: - existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) + existing_model_id = session.find_param_in_loaded_models('project_file', pf, is_path=True) if existing_model_id is not None: - session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') + session.log_info(f'An ilastik model from {pf} already existing exists; did not load a duplicate') return {'model_id': existing_model_id} - result = session.load_model(model_class, key=p.model_id, params=p) - session.log_info(f'Loaded ilastik model {result} from {project_file}') + result = session.load_model(model_class, key=model_id, params=p) + session.log_info(f'Loaded ilastik model {result} from {pf}') return {'model_id': result} \ No newline at end of file diff --git a/requirements.yml b/requirements.yml index fa5b53ab6818d693eeb428cc3008de839412a8c1..1cbbc32b7e9d05392f18ac8e9fad457e12d4ed1c 100644 --- a/requirements.yml +++ b/requirements.yml @@ -6,13 +6,14 @@ channels: dependencies: - czifile - fastapi>=0.101 - - ilastik=1.4.1b15 + - ilastik=1.4.1b6 - imagecodecs - jupyterlab - matplotlib - numpy - pandas - pillow + - protobuf ==4.25.3 - pydantic=1.10.* - pytorch=1.* - scikit-image>=0.21.0 diff --git a/scripts/run_server.py b/scripts/run_server.py index 4da04d5dde8e851e6c2e2aed586e8b94f2bb1c02..3710c86eaec383a4bd8816a4a8f81992db01914f 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -1,14 +1,7 @@ import argparse -from multiprocessing import Process -from pathlib import Path -import requests -from requests.adapters import HTTPAdapter -from urllib3 import Retry -import uvicorn -import webbrowser - -from model_server.conf.defaults import server_conf +from model_server.conf.defaults import root, server_conf +from model_server.conf.startup import main def parse_args(): parser = argparse.ArgumentParser( @@ -29,6 +22,11 @@ def parse_args(): default=str(server_conf['port']), help='bind socket to this port', ) + parser.add_argument( + '--root', + default=root.__str__(), + help='root directory of session data' + ) parser.add_argument( '--debug', action='store_true', @@ -39,55 +37,13 @@ def parse_args(): action='store_true', help='automatically restart server when changes are noticed, for development purposes' ) - return parser.parse_args() - + return parser.parse_args() -def main(args) -> None: +if __name__ == '__main__': + args = parse_args() print('CLI args:\n' + str(args)) - server_process = Process( - target=uvicorn.run, - args=(f'{args.confpath}:app',), - kwargs={ - 'app_dir': '..', - 'host': args.host, - 'port': int(args.port), - 'log_level': 'debug', - 'reload': args.reload, - }, - daemon=(args.reload is False), - ) - url = f'http://{args.host}:{int(args.port):04d}/status' - print(url) - server_process.start() - - try: - sesh = requests.Session() - retries = Retry( - total=5, - backoff_factor=0.1, - ) - sesh.mount('http://', HTTPAdapter(max_retries=retries)) - resp = sesh.get(url) - assert resp.status_code == 200 - except Exception: - print('Error starting server') - server_process.terminate() - exit() - - webbrowser.open(url, new=1, autoraise=True) - - if args.debug: - print('Running in debug mode') - print('Type "STOP" to stop server') - input_str = '' - while input_str.upper() != 'STOP': - input_str = input() - - server_process.terminate() + main(**args.__dict__) print('Finished') - -if __name__ == '__main__': - main(parse_args()) diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index 1df5863c71f1d157d612e74ba83a822632d9f221..e73e347029f5ec42cd01696ebdf01a0ae4c0cff0 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -207,6 +207,27 @@ class TestCziImageFileAccess(unittest.TestCase): pxs = cf.pixel_scale_in_micrometers self.assertAlmostEqual(pxs['X'], data['czifile']['um_per_pixel'], places=3) + def test_append_channels(self): + w = 256 + h = 512 + nc = 2 + nz = 11 + acc1 = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) + acc2 = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) + + app = acc1.append_channels(acc2) + self.assertEqual(app.shape, (h, w, 2 * nc, nz)) + self.assertTrue( + np.all( + app.get_channels([0, 1]).data == acc1.data + ) + ) + self.assertTrue( + np.all( + app.get_channels([2, 3]).data == acc2.data + ) + ) + class TestPatchStackAccessor(unittest.TestCase): def setUp(self) -> None: @@ -233,7 +254,6 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) return acc - def test_make_patch_stack_from_file(self): h = data['monozstackmask']['h'] w = data['monozstackmask']['w'] @@ -252,7 +272,7 @@ class TestPatchStackAccessor(unittest.TestCase): def test_make_3d_patch_stack_from_nonuniform_list(self): w = 256 h = 512 - c = 1 + c = 3 nz = 5 n = 4 @@ -270,6 +290,17 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(patches[i].shape, acc.iat(i, crop=True).shape) self.assertEqual(acc.shape[1:], acc.iat(i, crop=False).shape) + ps_list_cropped = acc.get_list(crop=True) + self.assertTrue(all([np.all(ps_list_cropped[i] == patches[i]) for i in range(0, n)])) + + ps_list_uncropped = acc.get_list(crop=False) + self.assertTrue(all([p.shape == acc.shape[1:] for p in ps_list_uncropped])) + + # test that this persists after channel selection + for i in range(0, acc.count): + self.assertEqual(patches[i].shape[0:2], acc.get_channels([0]).iat(i, crop=True).shape[0:2]) + self.assertEqual(patches[i].shape[3], acc.get_channels([0]).iat(i, crop=True).shape[3]) + def test_make_3d_patch_stack_from_list_force_long_dim(self): def _r(h, w): return np.random.randint(0, 2 ** 8, size=(h, w, 1, 1), dtype='uint8') @@ -301,6 +332,42 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(acc.count, n) self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w)) self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.get_mono(channel=0).data_yxz.shape, (n, h, w, nz)) + self.assertEqual(acc.get_mono(channel=0, mip=True).data_yx.shape, (n, h, w)) + return acc + + def test_can_mask(self): + w = 30 + h = 20 + n = 2 + nz = 3 + nc = 3 + acc1 = PatchStack(_random_int(n, h, w, nc, nz)) + acc2 = PatchStack(_random_int(n, 2*h, w, nc, nz)) + mask = PatchStack(_random_int(n, h, w, 1, nz) > 0.5) + self.assertFalse(acc1.is_mask()) + self.assertFalse(acc2.is_mask()) + self.assertTrue(mask.is_mask()) + self.assertFalse(acc1.can_mask(acc2)) + self.assertFalse(acc1.can_mask(acc2)) + self.assertTrue(mask.can_mask(acc1)) + self.assertFalse(mask.can_mask(acc2)) + + def test_object_df(self): + w = 30 + h = 20 + n = 2 + nz = 3 + nc = 3 + acc = PatchStack(_random_int(n, h, w, nc, nz)) + mask_data= np.zeros((n, h, w, 1, nz), dtype='uint8') + mask_data[0, 0:5, 0:5, :, :] = 255 + mask_data[1, 0:10, 0:10, :, :] = 255 + mask = PatchStack(mask_data) + df = acc.get_mono(0).get_object_df(mask) + # intensity means are centered around half of full range + self.assertTrue(np.all(((df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-2)) + self.assertTrue(df['area'][1] / df['area'][0] == 4.0) return acc def test_get_one_channel(self): @@ -331,4 +398,20 @@ class TestPatchStackAccessor(unittest.TestCase): fp = output_path / 'patch_hyperstack.tif' acc.export_pyxcz(fp) acc2 = make_patch_stack_from_file(fp) - self.assertEqual(acc.shape, acc2.shape) \ No newline at end of file + self.assertEqual(acc.shape, acc2.shape) + + def test_append_channels(self): + acc = self.test_pczyx() + app = acc.append_channels(acc) + p, h, w, nc, nz = acc.shape + self.assertEqual(app.shape, (p, h, w, 2 * nc, nz)) + self.assertTrue( + np.all( + app.get_channels([0, 1, 2]).data == acc.data + ) + ) + self.assertTrue( + np.all( + app.get_channels([3, 4, 5]).data == acc.data + ) + ) \ No newline at end of file diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 74c6a79e5dda0df8073335ce38205b22c9d570d2..a5df9e2f85539dc49e5700188395726bb06c9437 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -11,89 +11,64 @@ class TestApiFromAutomatedClient(TestServerBaseClass): input_data = czifile def test_trivial_api_response(self): - resp = self._get('') - self.assertEqual(resp.status_code, 200) + self.assertGetSuccess('') def test_bounceback_parameters(self): - resp = self._put('testing/bounce_back', body={'par1': 'hello', 'par2': ['ab', 'cd']}) - self.assertEqual(resp.status_code, 200, resp.content) - self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json()) - self.assertEqual(resp.json()['params']['par2'], ['ab', 'cd'], resp.json()) + r = self.assertPutSuccess('testing/bounce_back', body={'par1': 'hello', 'par2': ['ab', 'cd']}) + self.assertEqual(r['params']['par1'], 'hello', r) + self.assertEqual(r['params']['par2'], ['ab', 'cd'], r) def test_default_session_paths(self): import model_server.conf.defaults - resp = self._get('paths') + r = self.assertGetSuccess('paths') conf_root = model_server.conf.defaults.root for p in ['inbound_images', 'outbound_images', 'logs']: - self.assertTrue(resp.json()[p].startswith(conf_root.__str__())) + self.assertTrue(r[p].startswith(conf_root.__str__())) suffix = Path(model_server.conf.defaults.subdirectories[p]).__str__() - self.assertTrue(resp.json()[p].endswith(suffix)) + self.assertTrue(r[p].endswith(suffix)) def test_list_empty_loaded_models(self): - resp = self._get('models') - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.content, b'{}') + r = self.assertGetSuccess('models') + self.assertEqual(r, {}) def test_load_dummy_semantic_model(self): - resp_load = self._put(f'testing/models/dummy_semantic/load') - model_id = resp_load.json()['model_id'] - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - resp_list = self._get('models') - self.assertEqual(resp_list.status_code, 200) - rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'DummySemanticSegmentationModel') - return model_id + mid = self.assertPutSuccess(f'testing/models/dummy_semantic/load')['model_id'] + rl = self.assertGetSuccess('models') + self.assertEqual(rl[mid]['class'], 'DummySemanticSegmentationModel') + return mid def test_load_dummy_instance_model(self): - resp_load = self._put(f'testing/models/dummy_instance/load') - model_id = resp_load.json()['model_id'] - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - resp_list = self._get('models') - self.assertEqual(resp_list.status_code, 200) - rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'DummyInstanceSegmentationModel') - return model_id - - def test_respond_with_error_when_invalid_filepath_requested(self): - model_id = self.test_load_dummy_semantic_model() - - resp = self._put( - f'infer/from_image_file', - query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'} - ) - self.assertEqual(resp.status_code, 404, resp.content.decode()) + mid = self.assertPutSuccess(f'testing/models/dummy_instance/load')['model_id'] + rl = self.assertGetSuccess('models') + self.assertEqual(rl[mid]['class'], 'DummyInstanceMaskSegmentationModel') + return mid def test_pipeline_errors_when_ids_not_found(self): fname = self.copy_input_file_to_server() - model_id = self._put(f'testing/models/dummy_semantic/load').json()['model_id'] - in_acc_id = self._put(f'accessors/read_from_file/{fname}').json() + model_id = self.assertPutSuccess(f'testing/models/dummy_semantic/load')['model_id'] + in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}') # respond with 409 for invalid accessor_id - self.assertEqual( - self._put( - f'pipelines/segment', - body={'model_id': model_id, 'accessor_id': 'fake'} - ).status_code, - 409 + self.assertPutFailure( + f'pipelines/segment', + 409, + body={'model_id': model_id, 'accessor_id': 'fake'}, ) # respond with 409 for invalid model_id - self.assertEqual( - self._put( - f'pipelines/segment', - body={'model_id': 'fake', 'accessor_id': in_acc_id} - ).status_code, - 409 + self.assertPutFailure( + f'pipelines/segment', + 409, + body={'model_id': 'fake', 'accessor_id': in_acc_id} ) - def test_i2i_dummy_inference_by_api(self): fname = self.copy_input_file_to_server() - model_id = self._put(f'testing/models/dummy_semantic/load').json()['model_id'] - in_acc_id = self._put(f'accessors/read_from_file/{fname}').json() + model_id = self.assertPutSuccess(f'testing/models/dummy_semantic/load')['model_id'] + in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}') # run segmentation pipeline on preloaded accessor - resp_infer = self._put( + r = self.assertPutSuccess( f'pipelines/segment', body={ 'accessor_id': in_acc_id, @@ -102,109 +77,126 @@ class TestApiFromAutomatedClient(TestServerBaseClass): 'keep_interm': True, }, ) - self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) - out_acc_id = resp_infer.json()['output_accessor_id'] - self.assertTrue(self._get(f'accessors/{out_acc_id}').json()['loaded']) + out_acc_id = r['output_accessor_id'] + self.assertTrue(self.assertGetSuccess(f'accessors/{out_acc_id}')['loaded']) acc_out = self.get_accessor(out_acc_id, 'dummy_semantic_output.tif') self.assertEqual(acc_out.shape_dict['C'], 1) # validate intermediate data - resp_list = self._get(f'accessors').json() + resp_list = self.assertGetSuccess(f'accessors') self.assertEqual(len([k for k in resp_list.keys() if '_step' in k]), 2) def test_restarting_session_clears_loaded_models(self): - resp_load = self._put(f'testing/models/dummy_semantic/load',) - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - resp_list_0 = self._get('models') - self.assertEqual(resp_list_0.status_code, 200) - rj0 = resp_list_0.json() - self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}') - resp_restart = self._get('session/restart') - resp_list_1 = self._get('models') - rj1 = resp_list_1.json() - self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}') + self.assertPutSuccess(f'testing/models/dummy_semantic/load') + rl0 = self.assertGetSuccess('models') + self.assertEqual(len(rl0), 1, f'Unexpected models in response: {rl0}') + self.assertGetSuccess('session/restart') + rl1 = self.assertGetSuccess('models') + self.assertEqual(len(rl1), 0, f'Unexpected models in response: {rl1}') def test_change_inbound_path(self): - resp_inpath = self._get('paths') - resp_change = self._put( + r_inpath = self.assertGetSuccess('paths') + self.assertPutSuccess( f'paths/watch_output', - query={'path': resp_inpath.json()['inbound_images']} + query={'path': r_inpath['inbound_images']} ) - self.assertEqual(resp_change.status_code, 200) - resp_check = self._get('paths') - self.assertEqual(resp_check.json()['inbound_images'], resp_check.json()['outbound_images']) + r_check = self.assertGetSuccess('paths') + self.assertEqual(r_check['inbound_images'], r_check['outbound_images']) def test_exception_when_changing_inbound_path(self): - resp_inpath = self._get('paths') + r_inpath = self.assertGetSuccess('paths') fakepath = 'c:/fake/path/to/nowhere' - resp_change = self._put( + r_change = self.assertPutFailure( f'paths/watch_output', - query={'path': fakepath} + 404, + query={'path': fakepath}, ) - self.assertEqual(resp_change.status_code, 404) - self.assertIn(fakepath, resp_change.json()['detail']) - resp_check = self._get('paths') - self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images']) + self.assertIn(fakepath, r_change.json()['detail']) + r_check = self.assertGetSuccess('paths') + self.assertEqual(r_inpath['outbound_images'], r_check['outbound_images']) def test_no_change_inbound_path(self): - resp_inpath = self._get('paths') - resp_change = self._put( + r_inpath = self.assertGetSuccess('paths') + r_change = self.assertPutSuccess( f'paths/watch_output', - query={'path': resp_inpath.json()['outbound_images']} + query={'path': r_inpath['outbound_images']} ) - self.assertEqual(resp_change.status_code, 200) - resp_check = self._get('paths') - self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images']) + r_check = self.assertGetSuccess('paths') + self.assertEqual(r_inpath['outbound_images'], r_check['outbound_images']) def test_get_logs(self): - resp = self._get('session/logs') - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.json()[0]['message'], 'Initialized session') + r = self.assertGetSuccess('session/logs') + self.assertEqual(r[0]['message'], 'Initialized session') def test_add_and_delete_accessor(self): fname = self.copy_input_file_to_server() # add accessor to session - resp_add_acc = self._put( + acc_id = self.assertPutSuccess( f'accessors/read_from_file/{fname}', ) - acc_id = resp_add_acc.json() - self.assertTrue(acc_id.startswith('auto_')) + self.assertTrue(acc_id.startswith('acc_')) # confirm that accessor is listed in session context - resp_list_acc = self._get( + acc_list = self.assertGetSuccess( f'accessors', ) - self.assertEqual(len(resp_list_acc.json()), 1) - self.assertTrue(list(resp_list_acc.json().keys())[0].startswith('auto_')) - self.assertTrue(resp_list_acc.json()[acc_id]['loaded']) + self.assertEqual(len(acc_list), 1) + self.assertTrue(list(acc_list.keys())[0].startswith('acc_')) + self.assertTrue(acc_list[acc_id]['loaded']) # delete and check that its 'loaded' state changes - self.assertTrue(self._get(f'accessors/{acc_id}').json()['loaded']) - self.assertEqual(self._get(f'accessors/delete/{acc_id}').json(), acc_id) - self.assertFalse(self._get(f'accessors/{acc_id}').json()['loaded']) + self.assertTrue(self.assertGetSuccess(f'accessors/{acc_id}')['loaded']) + self.assertEqual(self.assertGetSuccess(f'accessors/delete/{acc_id}'), acc_id) + self.assertFalse(self.assertGetSuccess(f'accessors/{acc_id}')['loaded']) # and try a non-existent accessor ID - resp_wrong_acc = self._get('accessors/auto_123456') - self.assertEqual(resp_wrong_acc.status_code, 404) + self.assertGetFailure('accessors/auto_123456', 404) # load another... then remove all - self._put(f'accessors/read_from_file/{fname}') - self.assertEqual(sum([v['loaded'] for v in self._get('accessors').json().values()]), 1) - self.assertEqual(len(self._get(f'accessors/delete/*').json()), 1) - self.assertEqual(sum([v['loaded'] for v in self._get('accessors').json().values()]), 0) - + self.assertPutSuccess(f'accessors/read_from_file/{fname}') + self.assertEqual(sum([v['loaded'] for v in self.assertGetSuccess('accessors').values()]), 1) + self.assertEqual(len(self.assertGetSuccess(f'accessors/delete/*')), 1) + self.assertEqual(sum([v['loaded'] for v in self.assertGetSuccess('accessors').values()]), 0) def test_empty_accessor_list(self): - resp_list_acc = self._get( + r_list = self.assertGetSuccess( f'accessors', ) - self.assertEqual(len(resp_list_acc.json()), 0) + self.assertEqual(len(r_list), 0) def test_write_accessor(self): - acc_id = self._put('/testing/accessors/dummy_accessor/load').json() - self.assertTrue(self._get(f'accessors/{acc_id}').json()['loaded']) - sd = self._get(f'accessors/{acc_id}').json()['shape_dict'] - self.assertEqual(self._get(f'accessors/{acc_id}').json()['filepath'], '') + acc_id = self.assertPutSuccess('/testing/accessors/dummy_accessor/load') + self.assertTrue(self.assertGetSuccess(f'accessors/{acc_id}')['loaded']) + sd = self.assertGetSuccess(f'accessors/{acc_id}')['shape_dict'] + self.assertEqual(self.assertGetSuccess(f'accessors/{acc_id}')['filepath'], '') acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif') - self.assertEqual(sd, acc_out.shape_dict) \ No newline at end of file + self.assertEqual(sd, acc_out.shape_dict) + + def test_binary_segmentation_model(self): + mid = self.assertPutSuccess( + '/models/seg/threshold/load/', body={'tr': 0.1} + )['model_id'] + + fname = self.copy_input_file_to_server() + acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}') + r = self.assertPutSuccess( + f'pipelines/segment', + body={ + 'accessor_id': acc_id, + 'model_id': mid, + 'channel': 0, + 'keep_interm': True, + }, + ) + acc = self.get_accessor(r['output_accessor_id']) + self.assertTrue(all(acc.unique()[0] == [0, 255])) + self.assertTrue(all(acc.unique()[1] > 0)) + + def test_permissive_instance_segmentation_model(self): + self.assertPutSuccess( + '/models/classify/threshold/load', + body={} + ) + + diff --git a/tests/base/test_model.py b/tests/base/test_model.py index d975f7cd8725e0215391b4a526feab7cd69eeb31..912756b897b22763ca103d2c3f654f2c6df4249b 100644 --- a/tests/base/test_model.py +++ b/tests/base/test_model.py @@ -1,7 +1,7 @@ import unittest import model_server.conf.testing as conf -from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel +from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceMaskSegmentationModel from model_server.base.accessors import CziImageFileAccessor from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel @@ -13,7 +13,7 @@ class TestCziImageFileAccess(unittest.TestCase): self.cf = CziImageFileAccessor(czifile['path']) def test_instantiate_model(self): - model = DummySemanticSegmentationModel(params=None) + model = DummySemanticSegmentationModel() self.assertTrue(model.loaded) def test_instantiate_model_with_nondefault_kwarg(self): @@ -57,11 +57,12 @@ class TestCziImageFileAccess(unittest.TestCase): def test_binary_segmentation(self): model = BinaryThresholdSegmentationModel(tr=3e4) - img = self.cf.get_mono(0) - res = model.label_pixel_class(img) + res = model.label_pixel_class(self.cf) self.assertTrue(res.is_mask()) def test_dummy_instance_segmentation(self): img, mask = self.test_dummy_pixel_segmentation() - model = DummyInstanceSegmentationModel() + model = DummyInstanceMaskSegmentationModel() obmap = model.label_instance_class(img, mask) + self.assertTrue(all(obmap.unique()[0] == [0, 1])) + self.assertTrue(all(obmap.unique()[1] > 0)) diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py index 9f1b0303cbacf9fe9fe9969f7beadfe1ce942711..cbe4a5f873db2833ffd8847048689017764130ed 100644 --- a/tests/base/test_pipelines.py +++ b/tests/base/test_pipelines.py @@ -1,7 +1,10 @@ import unittest +import numpy as np + from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file from model_server.base.pipelines import router, segment, segment_zproj +from model_server.base.pipelines.shared import PipelineTrace import model_server.conf.testing as conf from model_server.conf.testing import DummySemanticSegmentationModel @@ -17,7 +20,7 @@ class TestSegmentationPipelines(unittest.TestCase): def test_call_segment_pipeline(self): acc = generate_file_accessor(czifile['path']) - trace = segment.segment_pipeline({'accessor': acc}, {'model': self.model}, channel=2, smooth=3) + trace = segment.segment_pipeline({'': acc}, {'': self.model}, channel=2, smooth=3) outfp = output_path / 'pipelines' / 'segment_binary_mask.tif' write_accessor_data_to_file(outfp, trace.last) @@ -53,14 +56,37 @@ class TestSegmentationPipelines(unittest.TestCase): def test_call_segment_zproj_pipeline(self): acc = generate_file_accessor(zstack['path']) - trace1 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}, channel=0, smooth=3, zi=4) + trace1 = segment_zproj.segment_zproj_pipeline({'': acc}, {'': self.model}, channel=0, smooth=3, zi=4) self.assertEqual(trace1.last.chroma, 1) self.assertEqual(trace1.last.nz, 1) - trace2 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}, channel=0, smooth=3) + trace2 = segment_zproj.segment_zproj_pipeline({'': acc}, {'': self.model}, channel=0, smooth=3) self.assertEqual(trace2.last.chroma, 1) self.assertEqual(trace2.last.nz, 1) - trace3 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}) + trace3 = segment_zproj.segment_zproj_pipeline({'': acc}, {'': self.model}) self.assertEqual(trace3.last.chroma, 1) # still == 1: model returns a single channel regardless of input self.assertEqual(trace3.last.nz, 1) + + def test_append_traces(self): + acc = generate_file_accessor(zstack['path']) + trace1 = PipelineTrace(acc) + trace1['halve'] = trace1.last.apply(lambda x: 0.5 * x) + + trace2 = PipelineTrace(trace1.last) + trace2['double'] = trace2.last.apply(lambda x: 2 * x) + trace3 = trace1.append(trace2, skip_first=False) + + self.assertEqual(len(trace3), len(trace1) + len(trace2)) + self.assertEqual(trace3['halve'], trace3['appended_input']) + self.assertTrue(np.all(trace3['input'].data == trace3['double'].data)) + + trace4 = trace1.append(trace2, skip_first=True) + self.assertEqual(len(trace4), len(trace1) + len(trace2) - 1) + self.assertTrue(np.all(trace4['input'].data == trace4['double'].data)) + + def test_add_nda_to_trace(self): + acc = generate_file_accessor(zstack['path']) + trace1 = PipelineTrace(acc, enforce_accessors=False) + trace1['halve'] = trace1.last.data * 0.5 + self.assertEqual(len(trace1), 2) diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 785358c92961662273f97f8f8c74037e03fc3cf3..ee73aa4b15d2a0309b625c1bbf9a197ebed92fae 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -7,11 +7,10 @@ from pathlib import Path import pandas as pd from model_server.base.process import smooth -from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSetExportParams, RoiSetMetaParams -from model_server.base.roiset import RoiSet -from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack +from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, IntensityThresholdInstanceMaskSegmentationModel, RoiSet, RoiSetExportParams, RoiSetMetaParams +from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file import model_server.conf.testing as conf -from model_server.conf.testing import DummyInstanceSegmentationModel +from model_server.conf.testing import DummyInstanceMaskSegmentationModel data = conf.meta['image_files'] output_path = conf.meta['output_path'] @@ -82,7 +81,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype)) roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap) self.assertEqual(roiset.count, 0) - roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) + roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel()) self.assertTrue('classify_by_dummy_class' in roiset.get_df().columns) def test_slices_are_valid(self): @@ -183,14 +182,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_classify_by(self): roiset = self._make_roi_set() - roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) + roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1])) return roiset def test_classify_by_multiple_channels(self): roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0)) - roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) + roiset.classify_by('dummy_class', [0, 1], DummyInstanceMaskSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1])) return roiset @@ -207,7 +206,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertGreater(total_iou, 0.6) # classify first RoiSet - roiset1.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) + roiset1.classify_by('dummy_class', [0, 1], DummyInstanceMaskSegmentationModel()) self.assertTrue('dummy_class' in roiset1.classification_columns) self.assertFalse('dummy_class' in roiset2.classification_columns) @@ -261,7 +260,11 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa params=RoiSetMetaParams( expand_box_by=(128, 2), mask_type='boxes', - filters={'area': {'min': 1e3, 'max': 1e4}}, + filters={ + 'area': {'min': 1e3, 'max': 1e4}, + 'diag': {'min': 1e1, 'max': 1e5}, + 'min_hw': {'min': 1e1, 'max': 1e4} + }, deproject_channel=0, ) ) @@ -455,7 +458,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa 'annotated_zstacks': {}, 'object_classes': True, }) - self.roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) + self.roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel()) interm = self.roiset.get_export_product_accessors( channel=3, params=p @@ -644,7 +647,7 @@ class TestRoiSetObjectDetection(unittest.TestCase): from skimage.measure import label, regionprops, regionprops_table mask = self.seg_mask_3d - labels = label(mask.data_xyz, connectivity=3) + labels = label(mask.data_yxz, connectivity=3) table = pd.DataFrame( regionprops_table(labels) ).rename( @@ -819,3 +822,45 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue((res.bbox_intersec == 2).all()) self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all()) self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all()) + +class TestIntensityThresholdObjectModel(BaseTestRoiSetMonoProducts, unittest.TestCase): + def test_instance_segmentation(self): + + img = self.stack.get_mono(channel=0, mip=True) + mask = self.seg_mask + + model = IntensityThresholdInstanceMaskSegmentationModel() + obmap = model.label_instance_class(img, mask) + + img.write(output_path / 'TestIntensityThresholdObjectModel' / 'img.tif') + mask.write(output_path / 'TestIntensityThresholdObjectModel' / 'mask.tif') + obmap.write(output_path / 'TestIntensityThresholdObjectModel' / 'obmap.tif') + + self.assertGreater((mask.data > 0).sum(), (obmap.data > 0).sum()) + + def test_roiset_with_instance_segmentation(self): + roiset = RoiSet.from_binary_mask( + self.stack, + self.seg_mask, + params=RoiSetMetaParams( + mask_type='countours', + filters={'area': {'min': 1e3, 'max': 1e4}}, + expand_box_by=(128, 2), + deproject_channel=0 + ) + ) + roiset.classify_by('permissive_model', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.0)) + self.assertEqual(roiset.get_df()['classify_by_permissive_model'].sum(), roiset.count) + roiset.classify_by('avg_intensity', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.5)) + self.assertLess(roiset.get_df()['classify_by_avg_intensity'].sum(), roiset.count) + return roiset + + def test_aggregate_classification_results(self): + roiset = self.test_roiset_with_instance_segmentation() + om_mod = roiset.get_object_class_map('permissive_model') + om_tr = roiset.get_object_class_map('avg_intensity') + om_fil = roiset.get_object_class_map('permissive_model', filter_by=['avg_intensity']) + self.assertTrue(np.all(om_fil.unique()[0] == [0, 1])) + self.assertEqual(om_fil.data.sum(), om_tr.data.sum()) + self.assertGreater(om_mod.data.sum(), om_fil.data.sum()) + diff --git a/tests/base/test_roiset_derived.py b/tests/base/test_roiset_derived.py index 156ef9fe42c472dc829085ed3b8c26bb46ac003a..7be964272a8956716ff7a1572f70b15112af01e5 100644 --- a/tests/base/test_roiset_derived.py +++ b/tests/base/test_roiset_derived.py @@ -7,7 +7,7 @@ from model_server.base.roiset import RoiSetWithDerivedChannelsExportParams, RoiS from model_server.base.roiset import RoiSetWithDerivedChannels from model_server.base.accessors import generate_file_accessor, PatchStack import model_server.conf.testing as conf -from model_server.conf.testing import DummyInstanceSegmentationModel +from model_server.conf.testing import DummyInstanceMaskSegmentationModel data = conf.meta['image_files'] params = conf.meta['roiset'] @@ -20,7 +20,7 @@ class TestDerivedChannels(unittest.TestCase): self.seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path']) def test_classify_by_with_derived_channel(self): - class ModelWithDerivedInputs(DummyInstanceSegmentationModel): + class ModelWithDerivedInputs(DummyInstanceMaskSegmentationModel): def infer(self, img, mask): return PatchStack(super().infer(img, mask).data * img.chroma) @@ -38,12 +38,13 @@ class TestDerivedChannels(unittest.TestCase): [0, 1], ModelWithDerivedInputs(), derived_channel_functions=[ - lambda acc: PatchStack(2 * acc.get_channels([0]).data), - lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8')) + lambda acc: acc.apply(lambda x: 2 * x), + lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint16')) ] ) - self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4])) - self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4])) + self.assertGreater(roiset.accs_derived[0].chroma, 1) + self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [5])) + self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 5])) self.assertEqual(len(roiset.accs_derived), 2) for di in roiset.accs_derived: diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/base/test_roiset_pipeline.py similarity index 59% rename from tests/test_ilastik/test_roiset_workflow.py rename to tests/base/test_roiset_pipeline.py index cec05dbdbdb1be3c97c1ed53cccf3d57984fa34c..a32ec47b07137666479277f384eb4d42ca2e1eb8 100644 --- a/tests/test_ilastik/test_roiset_workflow.py +++ b/tests/base/test_roiset_pipeline.py @@ -3,20 +3,13 @@ import unittest import numpy as np - from model_server.base.accessors import generate_file_accessor -from model_server.base.api import app import model_server.conf.testing as conf from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline -import model_server.extensions.ilastik.models as ilm -from model_server.extensions.ilastik.router import router - -app.include_router(router) data = conf.meta['image_files'] output_path = conf.meta['output_path'] test_params = conf.meta['roiset'] -classifiers = conf.meta['ilastik_classifiers'] class BaseTestRoiSetMonoProducts(object): @@ -64,27 +57,17 @@ class BaseTestRoiSetMonoProducts(object): 'deproject_channel': 0, } - def _get_models(self): # tests can either use model objects directly, or load in API via project file string - fp_px = classifiers['px']['path'].__str__() - fp_ob = classifiers['seg_to_obj']['path'].__str__() + def _get_models(self): + from model_server.base.models import BinaryThresholdSegmentationModel + from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel return { 'pixel_classifier_segmentation': { - 'name': 'ilastik_px_mod', - 'project_file': fp_px, - 'model': ilm.IlastikPixelClassifierModel( - ilm.IlastikPixelClassifierParams( - project_file=fp_px, - ) - ) + 'name': 'min_px_mod', + 'model': BinaryThresholdSegmentationModel(tr=0.2), }, 'object_classifier': { - 'name': 'ilastik_ob_mod', - 'project_file': fp_ob, - 'model': ilm.IlastikObjectClassifierFromSegmentationModel( - ilm.IlastikParams( - project_file=fp_ob - ) - ) + 'name': 'min_ob_mod', + 'model': IntensityThresholdInstanceMaskSegmentationModel(), }, } @@ -111,22 +94,21 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): **self._pipeline_params(), ) trace, rois = roiset_object_map_pipeline( - {'accessor': acc_in}, - {f'{k}_model': v['model'] for k, v in self._get_models().items()}, + {'': acc_in}, + {f'{k}_': v['model'] for k, v in self._get_models().items()}, **params.dict() ) - self.assertEqual(trace.pop('annotated_patches_2d').count, 13) - self.assertEqual(trace.pop('patches_2d').count, 13) + self.assertEqual(trace.pop('annotated_patches_2d').count, 22) + self.assertEqual(trace.pop('patches_2d').count, 22) trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False) self.assertTrue('ob_id' in trace.keys()) - self.assertEqual(len(trace['labeled'].unique()[0]), 14) - self.assertEqual(rois.count, 13) + self.assertEqual(len(trace['labeled'].unique()[0]), 40) + self.assertEqual(rois.count, 22) self.assertEqual(len(trace['ob_id'].unique()[0]), 2) class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts): - app_name = 'tests.test_ilastik.test_roiset_workflow:app' input_data = data['multichannel_zstack_raw'] @@ -136,33 +118,30 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd return conf.TestServerBaseClass.setUp(self) def test_trivial_api_response(self): - resp = self._get('') - self.assertEqual(resp.status_code, 200) + self.assertGetSuccess('') def test_load_input_accessor(self): fname = self.copy_input_file_to_server() - return self._put(f'accessors/read_from_file/{fname}').json() + return self.assertPutSuccess(f'accessors/read_from_file/{fname}') def test_load_pixel_classifier(self): - resp = self._put( - 'ilastik/seg/load/', - body={'project_file': self._get_models()['pixel_classifier_segmentation']['project_file']}, - ) - model_id = resp.json()['model_id'] - self.assertTrue(model_id.startswith('IlastikPixelClassifierModel')) - return model_id + mid = self.assertPutSuccess( + 'models/seg/threshold/load/', + query={'tr': 0.2}, + )['model_id'] + self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel')) + return mid def test_load_object_classifier(self): - resp = self._put( - 'ilastik/seg_to_obj/load/', - body={'project_file': self._get_models()['object_classifier']['project_file']}, - ) - model_id = resp.json()['model_id'] - self.assertTrue(model_id.startswith('IlastikObjectClassifierFromSegmentationModel')) - return model_id + mid = self.assertPutSuccess( + 'models/classify/threshold/load/', + body={'tr': 0} + )['model_id'] + self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation')) + return mid def _object_map_workflow(self, ob_classifer_id): - resp = self._put( + res = self.assertPutSuccess( 'pipelines/roiset_to_obmap/infer', body={ 'accessor_id': self.test_load_input_accessor(), @@ -174,18 +153,33 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd 'export_params': self._get_export_params(), }, ) - self.assertEqual(resp.status_code, 200, resp.json()) - oid = resp.json()['output_accessor_id'] - obmap_fn = self._put(f'/accessors/write_to_file/{oid}').json() - where_out = self._get('paths').json()['outbound_images'] - obmap_fp = Path(where_out) / obmap_fn - self.assertTrue(obmap_fp.exists()) - return generate_file_accessor(obmap_fp) + + # check on automatically written RoiSet + roiset_id = res['roiset_id'] + roiset_info = self.assertGetSuccess(f'rois/{roiset_id}') + self.assertGreater(roiset_info['count'], 0) + return res def test_workflow_with_object_classifier(self): - acc = self._object_map_workflow(self.test_load_object_classifier()) - self.assertTrue(np.all(acc.unique()[0] == [0, 1, 2])) + obmod_id = self.test_load_object_classifier() + res = self._object_map_workflow(obmod_id) + acc_obmap = self.get_accessor(res['output_accessor_id']) + self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1])) + + # get object map via RoiSet API + roiset_id = res['roiset_id'] + obmap_id = self.assertPutSuccess(f'rois/obmap/{roiset_id}/{obmod_id}', query={'object_classes': True}) + acc_obmap_roiset = self.get_accessor(obmap_id) + self.assertTrue(np.all(acc_obmap_roiset.data == acc_obmap.data)) + + # check serialize RoiSet + self.assertPutSuccess(f'rois/write/{roiset_id}') + self.assertFalse( + self.assertGetSuccess(f'rois/{roiset_id}')['loaded'] + ) + def test_workflow_without_object_classifier(self): - acc = self._object_map_workflow(None) - self.assertTrue(np.all(acc.unique()[0] == [0, 1])) + res = self._object_map_workflow(None) + acc_obmap = self.get_accessor(res['output_accessor_id']) + self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1])) diff --git a/tests/base/test_session.py b/tests/base/test_session.py index 6843f36a011851638ce3b4739ce9ff6caf15b301..7a90a70505482351b4fcfe653f7e68fcf78a462d 100644 --- a/tests/base/test_session.py +++ b/tests/base/test_session.py @@ -3,8 +3,10 @@ import pathlib import unittest import numpy as np -from model_server.base.accessors import InMemoryDataAccessor +from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor +from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSet, RoiSetExportParams, RoiSetMetaParams from model_server.base.session import session +import model_server.conf.testing as conf class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: @@ -15,14 +17,17 @@ class TestGetSessionObject(unittest.TestCase): self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place') def test_changing_session_root_creates_new_directory(self): - from model_server.conf.defaults import root - old_paths = self.sesh.get_paths() - newroot = root / 'subdir' - self.sesh.restart(root=newroot) + self.sesh.restart() new_paths = self.sesh.get_paths() - for k in old_paths.keys(): - self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__())) + self.assertEqual( + old_paths['logs'].parent.parent, + new_paths['logs'].parent.parent + ) + self.assertNotEqual( + old_paths['logs'].parent, + new_paths['logs'].parent + ) def test_change_session_subdirectory(self): old_paths = self.sesh.get_paths() @@ -88,19 +93,33 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual(dfv.columns[1], 'Y') - def test_add_and_remove_accessor(self): - acc = InMemoryDataAccessor( - np.random.randint( - 0, - 2 ** 8, - size=(512, 256, 3, 7), - dtype='uint8' +class TestSessionPersistentData(unittest.TestCase): + + def setUp(self) -> None: + session.restart() + self.sesh = session + + data = conf.meta['image_files'] + + self.acc_in = generate_file_accessor(data['multichannel_zstack_raw']['path']) + self.mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path']) + + self.roiset = RoiSet.from_binary_mask( + self.acc_in, + self.mask, + params=RoiSetMetaParams( + filters={'area': {'min': 1e3, 'max': 1e4}}, + expand_box_by=(128, 2), + deproject_channel=0, ) ) - shd = acc.shape_dict + + + def test_add_and_remove_accessor(self): + shd = self.acc_in.shape_dict # add accessor to session registry - acc_id = session.add_accessor(acc) + acc_id = session.add_accessor(self.acc_in) self.assertEqual(session.get_accessor_info(acc_id)['shape_dict'], shd) self.assertTrue(session.get_accessor_info(acc_id)['loaded']) @@ -136,3 +155,19 @@ class TestGetSessionObject(unittest.TestCase): self.assertIsInstance(acc_get, InMemoryDataAccessor) self.assertEqual(acc_get.shape_dict, shd) self.assertFalse(session.get_accessor_info(acc_id)['loaded']) + + def test_add_roiset(self): + roiset_id = session.add_roiset(self.roiset) + info = session.get_roiset_info(roiset_id) + self.assertEqual(info['raw_shape_dict'], self.acc_in.shape_dict) + self.assertTrue(info['loaded']) + + # serialize roiset and hence unload from registry + rec = session.write_roiset(roiset_id) + info = session.get_roiset_info(roiset_id) + self.assertFalse(info['loaded']) + self.assertTrue((session.paths['rois'] / info['files']['dataframe']).exists()) + + # read RoiSet + des_roiset = RoiSet.deserialize(self.acc_in, session.paths['rois'], roiset_id) + self.assertEqual(des_roiset.get_patches_acc().shape, self.roiset.get_patches_acc().shape) \ No newline at end of file diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index e7744acb0c53e02353654602e7916b69de61679e..2fdfdfad9d191a077c2b1efc6f15ad348aa1eaa2 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -29,14 +29,16 @@ class TestIlastikPixelClassification(unittest.TestCase): self.cf = CziImageFileAccessor(czifile['path']) self.channel = 0 self.model = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__()) + project_file=ilastik_classifiers['px']['path'].__str__(), + px_class=0, + px_prob_threshold=0.5 ) self.mono_image = self.cf.get_mono(self.channel) def test_raise_error_if_autoload_disabled(self): model = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__()), + project_file=ilastik_classifiers['px']['path'].__str__(), autoload=False ) w = 512 @@ -80,11 +82,9 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_label_pixels_with_params(self): def _run_seg(tr, sig): mod = ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams( - project_file=ilastik_classifiers['px']['path'].__str__(), - px_prob_threshold=tr, - px_smoothing=sig, - ), + project_file=ilastik_classifiers['px']['path'].__str__(), + px_class=0, + px_prob_threshold=tr, ) mask = mod.label_pixel_class(self.mono_image) write_accessor_data_to_file( @@ -143,21 +143,24 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(mask.nz, acc.nz) self.assertEqual(mask.count, acc.count) - pxmap, _ = self.model.infer_patch_stack(acc) - self.assertEqual(pxmap.dtype, float) + pxmap = self.model.infer_patch_stack(acc) + self.assertEqual(pxmap.dtype, 'float32') self.assertEqual(pxmap.chroma, len(self.model.labels)) self.assertEqual(pxmap.hw, acc.hw) self.assertEqual(pxmap.nz, acc.nz) self.assertEqual(pxmap.count, acc.count) + norm_pxmap = self.model.infer_patch_stack(acc, normalize=True) + self.assertEqual(norm_pxmap.dtype, 'uint8') + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( - params=ilm.IlastikParams(project_file=ilastik_classifiers['pxmap_to_obj']['path'].__str__()) + project_file=ilastik_classifiers['pxmap_to_obj']['path'].__str__() ) mask = self.model.label_pixel_class(self.mono_image) - objmap, _ = model.infer(self.mono_image, mask) + objmap = model.infer(self.mono_image, mask) self.assertTrue( write_accessor_data_to_file( @@ -171,8 +174,8 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_run_object_classifier_from_segmentation(self): self.test_run_pixel_classifier() fp = czifile['path'] - model = ilm.IlastikObjectClassifierFromSegmentationModel( - params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj']['path'].__str__()) + model = ilm.IlastikObjectClassifierFromMaskSegmentationModel( + project_file=ilastik_classifiers['seg_to_obj']['path'].__str__() ) mask = self.model.label_pixel_class(self.mono_image) objmap = model.label_instance_class(self.mono_image, mask) @@ -188,13 +191,13 @@ class TestIlastikPixelClassification(unittest.TestCase): def test_ilastik_pixel_classification_as_workflow(self): res = segment.segment_pipeline( accessors={ - 'accessor': generate_file_accessor(czifile['path']) + '': generate_file_accessor(czifile['path']) }, models={ - 'model': ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams( - project_file=ilastik_classifiers['px']['path'].__str__() - ), + '': ilm.IlastikPixelClassifierModel( + project_file=ilastik_classifiers['px']['path'].__str__(), + px_class=0, + px_prob_threshold=0.5, ), }, channel=0, @@ -209,42 +212,35 @@ class TestServerTestCase(conf.TestServerBaseClass): class TestIlastikOverApi(TestServerTestCase): def test_httpexception_if_incorrect_project_file_loaded(self): - resp_load = self._put( + self.assertPutFailure( 'ilastik/seg/load/', + 500, body={'project_file': 'improper.ilp'}, ) - self.assertEqual(resp_load.status_code, 500) def test_load_ilastik_pixel_model(self): - resp_load = self._put( + mid = self.assertPutSuccess( 'ilastik/seg/load/', body={'project_file': str(ilastik_classifiers['px']['path'])}, - ) - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - model_id = resp_load.json()['model_id'] - resp_list = self._get('models') - self.assertEqual(resp_list.status_code, 200) - rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel') - return model_id + )['model_id'] + rl = self.assertGetSuccess('models') + self.assertEqual(rl[mid]['class'], 'IlastikPixelClassifierModel') + return mid def test_load_another_ilastik_pixel_model(self): - model_id = self.test_load_ilastik_pixel_model() - resp_list_1st = self._get('models').json() - self.assertEqual(len(resp_list_1st), 1, resp_list_1st) - resp_load_2nd = self._put( + self.test_load_ilastik_pixel_model() + self.assertEqual(len(self.assertGetSuccess('models')), 1) + self.assertPutSuccess( 'ilastik/seg/load/', body={'project_file': str(ilastik_classifiers['px']['path']), 'duplicate': True}, ) - resp_list_2nd = self._get('models').json() - self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) - resp_load_3rd = self._put( + self.assertEqual(len(self.assertGetSuccess('models')), 2) + self.assertPutSuccess( 'ilastik/seg/load/', body={'project_file': str(ilastik_classifiers['px']['path']), 'duplicate': False}, ) - resp_list_3rd = self._get('models').json() - self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) + self.assertEqual(len(self.assertGetSuccess('models')), 2) def test_load_ilastik_pixel_model_with_params(self): params = { @@ -252,67 +248,55 @@ class TestIlastikOverApi(TestServerTestCase): 'px_class': 0, 'px_prob_threshold': 0.5 } - resp_load = self._put( + mid = self.assertPutSuccess( 'ilastik/seg/load/', body=params, - ) - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - model_id = resp_load.json()['model_id'] - mods = self._get('models').json() + )['model_id'] + mods = self.assertGetSuccess('models') self.assertEqual(len(mods), 1) - self.assertEqual(mods[model_id]['params']['px_prob_threshold'], 0.5) + self.assertEqual(mods[mid]['info']['px_prob_threshold'], 0.5) def test_load_ilastik_pxmap_to_obj_model(self): - resp_load = self._put( + mid = self.assertPutSuccess( 'ilastik/pxmap_to_obj/load/', body={'project_file': str(ilastik_classifiers['pxmap_to_obj']['path'])}, - ) - model_id = resp_load.json()['model_id'] - - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - resp_list = self._get('models') - self.assertEqual(resp_list.status_code, 200) - rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel') - return model_id + )['model_id'] + rl = self.assertGetSuccess('models') + self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel') + return mid def test_load_ilastik_model_with_model_id(self): - mid = 'new_model_id' - resp_load = self._put( + nmid = 'new_model_id' + rmid = self.assertPutSuccess( 'ilastik/pxmap_to_obj/load/', + query={ + 'model_id': nmid, + }, body={ 'project_file': str(ilastik_classifiers['pxmap_to_obj']['path']), - 'model_id': mid, }, - ) - res_mid = resp_load.json()['model_id'] - self.assertEqual(res_mid, mid) + )['model_id'] + self.assertEqual(rmid, nmid) def test_load_ilastik_seg_to_obj_model(self): - resp_load = self._put( + mid = self.assertPutSuccess( 'ilastik/seg_to_obj/load/', body={'project_file': str(ilastik_classifiers['seg_to_obj']['path'])}, - ) - model_id = resp_load.json()['model_id'] - - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - resp_list = self._get('models') - self.assertEqual(resp_list.status_code, 200) - rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromSegmentationModel') - return model_id + )['model_id'] + rl = self.assertGetSuccess('models') + self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromMaskSegmentationModel') + return mid def test_ilastik_infer_pixel_probability(self): fname = self.copy_input_file_to_server() - model_id = self.test_load_ilastik_pixel_model() - in_acc_id = self._put(f'accessors/read_from_file/{fname}').json() + mid = self.test_load_ilastik_pixel_model() + acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}') - resp_infer = self._put( + self.assertPutSuccess( f'pipelines/segment', - body={'model_id': model_id, 'accessor_id': in_acc_id, 'channel': 0}, + body={'model_id': mid, 'accessor_id': acc_id, 'channel': 0}, ) - self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) def test_ilastik_infer_px_then_ob(self): @@ -320,9 +304,9 @@ class TestIlastikOverApi(TestServerTestCase): px_model_id = self.test_load_ilastik_pixel_model() ob_model_id = self.test_load_ilastik_pxmap_to_obj_model() - in_acc_id = self._put(f'accessors/read_from_file/{fname}').json() + in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}') - resp_infer = self._put( + self.assertPutSuccess( 'ilastik/pipelines/pixel_then_object_classification/infer/', body={ 'px_model_id': px_model_id, @@ -331,7 +315,6 @@ class TestIlastikOverApi(TestServerTestCase): 'channel': 0, } ) - self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) class TestIlastikOnMultichannelInputs(TestServerTestCase): @@ -347,8 +330,8 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): def test_classify_pixels(self): img = generate_file_accessor(self.pa_input_image) self.assertGreater(img.chroma, 1) - mod = ilm.IlastikPixelClassifierModel(ilm.IlastikPixelClassifierParams(project_file=self.pa_px_classifier.__str__())) - pxmap = mod.infer(img)[0] + mod = ilm.IlastikPixelClassifierModel(project_file=self.pa_px_classifier.__str__()) + pxmap = mod.infer(img) self.assertEqual(pxmap.hw, img.hw) self.assertEqual(pxmap.nz, img.nz) return pxmap @@ -357,9 +340,9 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): pxmap = self.test_classify_pixels() img = generate_file_accessor(self.pa_input_image) mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel( - ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()) + project_file=self.pa_ob_pxmap_classifier.__str__(), ) - obmap = mod.infer(img, pxmap)[0] + obmap = mod.infer(img, pxmap) self.assertEqual(obmap.hw, img.hw) self.assertEqual(obmap.nz, img.nz) @@ -370,14 +353,14 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): def _call_workflow(channel): return px_then_ob.pixel_then_object_classification_pipeline( accessors={ - 'accessor': generate_file_accessor(self.pa_input_image) + '': generate_file_accessor(self.pa_input_image) }, models={ - 'px_model': ilm.IlastikPixelClassifierModel( - ilm.IlastikParams(project_file=self.pa_px_classifier.__str__()), + 'px_': ilm.IlastikPixelClassifierModel( + project_file=self.pa_px_classifier.__str__(), ), - 'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel( - ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()), + 'ob_': ilm.IlastikObjectClassifierFromPixelPredictionsModel( + project_file=self.pa_ob_pxmap_classifier.__str__(), ) }, channel=channel, @@ -398,38 +381,32 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): """ copyfile( self.pa_input_image, - Path(self._get('paths').json()['inbound_images']) / self.pa_input_image.name + Path(self.assertGetSuccess('paths')['inbound_images']) / self.pa_input_image.name ) - in_acc_id = self._put(f'accessors/read_from_file/{self.pa_input_image.name}').json() + in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{self.pa_input_image.name}') - resp_load_px = self._put( + px_model_id = self.assertPutSuccess( 'ilastik/seg/load/', body={'project_file': str(self.pa_px_classifier)}, - ) - self.assertEqual(resp_load_px.status_code, 200, resp_load_px.json()) - px_model_id = resp_load_px.json()['model_id'] + )['model_id'] - resp_load_ob = self._put( + ob_model_id = self.assertPutSuccess( 'ilastik/pxmap_to_obj/load/', body={'project_file': str(self.pa_ob_pxmap_classifier)}, - ) - self.assertEqual(resp_load_ob.status_code, 200, resp_load_ob.json()) - ob_model_id = resp_load_ob.json()['model_id'] + )['model_id'] # run the pipeline - resp_infer = self._put( + obmap_id = self.assertPutSuccess( 'ilastik/pipelines/pixel_then_object_classification/infer/', body={ 'accessor_id': in_acc_id, 'px_model_id': px_model_id, 'ob_model_id': ob_model_id, } - ) - self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + )['output_accessor_id'] # save output object map to file and compare - obmap_id = resp_infer.json()['output_accessor_id'] obmap_acc = self.get_accessor(obmap_id) self.assertEqual(obmap_acc.shape_dict['C'], 1) @@ -453,8 +430,8 @@ class TestIlastikObjectClassification(unittest.TestCase): ) ) - self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel( - params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj']['path'].__str__()), + self.classifier = ilm.IlastikObjectClassifierFromMaskSegmentationModel( + project_file=ilastik_classifiers['seg_to_obj']['path'].__str__(), ) self.raw = self.roiset.get_patches_acc() self.masks = self.roiset.get_patch_masks_acc()