diff --git a/model_server/base/api.py b/model_server/base/api.py index 6fe8a7aff229c9bad4da83bd70e71180796a040f..96620aefb05e30fedcf06dadbca9b76ffddbdf2a 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,36 +1,31 @@ +from typing import Union + from fastapi import FastAPI, HTTPException -from pydantic import BaseModel +from .accessors import generate_file_accessor +from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError -from .models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel -from .session import session, InvalidPathError -from .validators import validate_workflow_inputs -from .workflows import classify_pixels app = FastAPI(debug=True) -from ..extensions.ilastik.router import router as ilastik_router -app.include_router(ilastik_router) +from .pipelines.router import router +app.include_router(router) + @app.on_event("startup") def startup(): pass + @app.get('/') def read_root(): return {'success': True} -class BounceBackParams(BaseModel): - par1: str - par2: list - -@app.put('/bounce_back') -def list_bounce_back(params: BounceBackParams): - return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}} @app.get('/paths') def list_session_paths(): return session.get_paths() + @app.get('/status') def show_session_status(): return { @@ -39,61 +34,79 @@ def show_session_status(): 'paths': session.get_paths(), } -def change_path(key, path): + +def _change_path(key, path): try: - if session.get_paths()[key] == path: - return session.get_paths() session.set_data_directory(key, path) except InvalidPathError as e: - raise HTTPException( - status_code=404, - detail=e.__str__(), - ) - session.log_info(f'Change {key} path to {path}') - return session.get_paths() + raise HTTPException(404, f'Did not find valid folder at: {path}') + @app.put('/paths/watch_input') def watch_input_path(path: str): - return change_path('inbound_images', path) + return _change_path('inbound_images', path) + @app.put('/paths/watch_output') def watch_output_path(path: str): - return change_path('outbound_images', path) + return _change_path('outbound_images', path) + @app.get('/session/restart') def restart_session(root: str = None) -> dict: session.restart(root=root) return session.describe_loaded_models() + @app.get('/session/logs') def list_session_log() -> list: return session.get_log_data() + @app.get('/models') def list_active_models(): return session.describe_loaded_models() -@app.put('/models/dummy_semantic/load/') -def load_dummy_model() -> dict: - mid = session.load_model(DummySemanticSegmentationModel) - session.log_info(f'Loaded model {mid}') - return {'model_id': mid} - -@app.put('/models/dummy_instance/load/') -def load_dummy_model() -> dict: - mid = session.load_model(DummyInstanceSegmentationModel) - session.log_info(f'Loaded model {mid}') - return {'model_id': mid} - -@app.put('/workflows/segment') -def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: - inpath = session.paths['inbound_images'] / input_filename - validate_workflow_inputs([model_id], [inpath]) - record = classify_pixels( - inpath, - session.models[model_id]['object'], - session.paths['outbound_images'], - channel=channel, - ) - session.log_info(f'Completed segmentation of {input_filename}') - return record \ No newline at end of file + +@app.get('/accessors') +def list_accessors(): + return session.list_accessors() + + +def _session_accessor(func, acc_id): + try: + return func(acc_id) + except AccessorIdError as e: + raise HTTPException(404, f'Did not find accessor with ID {acc_id}') + + +@app.get('/accessors/{accessor_id}') +def get_accessor(accessor_id: str): + return _session_accessor(session.get_accessor_info, accessor_id) + + +@app.get('/accessors/delete/{accessor_id}') +def delete_accessor(accessor_id: str): + if accessor_id == '*': + return session.del_all_accessors() + else: + return _session_accessor(session.del_accessor, accessor_id) + + +@app.put('/accessors/read_from_file/{filename}') +def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None): + fp = session.paths['inbound_images'] / filename + if not fp.exists(): + raise HTTPException(status_code=404, detail=f'Could not find file:\n{filename}') + acc = generate_file_accessor(fp) + return session.add_accessor(acc, accessor_id=accessor_id) + + +@app.put('/accessors/write_to_file/{accessor_id}') +def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) -> str: + try: + return session.write_accessor(accessor_id, filename) + except AccessorIdError as e: + raise HTTPException(404, f'Did not find accessor with ID {accessor_id}') + except WriteAccessorError as e: + raise HTTPException(409, str(e)) \ No newline at end of file diff --git a/model_server/base/models.py b/model_server/base/models.py index 07c00ad457fe785882f753ca9ff62ff74860b529..eaded2c9f24c8793600019d86a1dcbb5cdab7c52 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -48,6 +48,11 @@ class Model(ABC): def reload(self): self.load() + @property + def name(self): + return f'{self.__class__.__name__}' + + class ImageToImageModel(Model): """ @@ -124,48 +129,20 @@ class InstanceSegmentationModel(ImageToImageModel): return PatchStack(data) -class DummySemanticSegmentationModel(SemanticSegmentationModel): +class BinaryThresholdSegmentationModel(SemanticSegmentationModel): - model_id = 'dummy_make_white_square' - - def load(self): - return True + def __init__(self, tr: float = 0.5): + self.tr = tr def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): - super().infer(img) - w = img.shape_dict['X'] - h = img.shape_dict['Y'] - result = np.zeros([h, w], dtype='uint8') - result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 - return InMemoryDataAccessor(data=result), {'success': True} - - def label_pixel_class( - self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: - mask, _ = self.infer(img) - return mask - -class DummyInstanceSegmentationModel(InstanceSegmentationModel): + return img.apply(lambda x: x > self.tr), {'success': True} - model_id = 'dummy_pass_input_mask' + def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + return self.infer(img, **kwargs)[0] def load(self): - return True - - def infer( - self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor - ) -> (GenericImageDataAccessor, dict): - return img.__class__( - (mask.data / mask.data.max()).astype('uint16') - ) + pass - def label_instance_class( - self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs - ) -> GenericImageDataAccessor: - """ - Returns a trivial segmentation, i.e. the input mask with value 1 - """ - super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) - return self.infer(img, mask) class Error(Exception): pass diff --git a/model_server/base/pipelines/__init__.py b/model_server/base/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py new file mode 100644 index 0000000000000000000000000000000000000000..037fb7fc015c5646152854f5a7fa4f2258d9808d --- /dev/null +++ b/model_server/base/pipelines/roiset_obmap.py @@ -0,0 +1,121 @@ +from typing import Dict, Union + +from pydantic import BaseModel, Field, validator + +from ..accessors import GenericImageDataAccessor +from .router import router +from .segment_zproj import segment_zproj_pipeline +from .shared import call_pipeline +from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams +from ..session import session + +from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord + +from ..models import Model, InstanceSegmentationModel + + +class RoiSetObjectMapParams(PipelineParams): + class _SegmentationParams(BaseModel): + channel: int = Field( + None, + description='Channel of input image to use for solving segmentation; use all channels if empty' + ) + zi: Union[int, None] = Field( + None, + description='z coordinate to use on input image when solving segmentation; apply MIP if empty', + ) + + accessor_id: str = Field( + description='ID(s) of previously loaded accessor(s) to use as pipeline input' + ) + pixel_classifier_segmentation_model_id: str = Field( + description='Pixel classifier applied to segmentation_channel(s) to segment objects' + ) + object_classifier_model_id: Union[str, None] = Field( + None, + description='Object classifier used to classify segmented objectss' + ) + pixel_classifier_derived_model_id: Union[str, None] = Field( + None, + description='Pixel classifier used to derive channel(s) as additional inputs to object classification' + ) + patches_channel: int = Field( + description='Channel of input image used in patches sent to object classifier' + ) + segmentation: _SegmentationParams = Field( + _SegmentationParams(), + description='Parameters used to solve segmentation' + ) + roi_params: RoiSetMetaParams = RoiSetMetaParams(**{ + 'mask_type': 'boxes', + 'filters': { + 'area': {'min': 1e3, 'max': 1e8} + }, + 'expand_box_by': [128, 2], + 'deproject_channel': None, + }) + export_params: RoiSetExportParams = RoiSetExportParams() + derived_channels_input_channel: Union[int, None] = Field( + None, + description='Channel of input image from which to compute derived channels; use all if empty' + ) + derived_channels_output_channels: Union[int, list] = Field( + None, + description='Derived channels to send to object classifier; use all if empty' + ) + export_label_interm: bool = False + + +class RoiSetToObjectMapRecord(PipelineRecord): + roiset_table: dict + +@router.put('/roiset_to_obmap/infer') +def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: + """ + Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification. + """ + record, rois = call_pipeline(roiset_object_map_pipeline, p) + + table = rois.get_serializable_dataframe() + + session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table) + ret = RoiSetToObjectMapRecord( + roiset_table=table.to_dict(), + **record.dict() + ) + return ret + + +def roiset_object_map_pipeline( + accessors: Dict[str, GenericImageDataAccessor], + models: Dict[str, Model], + **k +) -> (PipelineTrace, RoiSet): + d = PipelineTrace(accessors['accessor']) + + d['mask'] = segment_zproj_pipeline( + accessors, + {'model': models['pixel_classifier_segmentation_model']}, + **k['segmentation'], + ).last + + d['labeled'] = get_label_ids(d.last) + rois = RoiSet.from_object_ids(d['input'], d['labeled'], RoiSetMetaParams(**k['roi_params'])) + + # optionally append RoiSet products + for ki, vi in rois.get_export_product_accessors(k['patches_channel'], RoiSetExportParams(**k['export_params'])).items(): + d[ki] = vi + + # optionally run an object classifier if specified + if obmod := models.get('object_classifier_model'): + obmod_name = k['object_classifier_model_id'] + assert isinstance(obmod, InstanceSegmentationModel) + rois.classify_by( + obmod_name, + [k['patches_channel']], + obmod, + ) + d[obmod_name] = rois.get_object_class_map(obmod_name) + else: + d['objects_unclassified'] = d.last.apply(lambda x: ((x > 0) * 1).astype('uint16')) + return d, rois diff --git a/model_server/base/pipelines/router.py b/model_server/base/pipelines/router.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7850da0c7101db8bdfb04487a2498ec5cfdcfd --- /dev/null +++ b/model_server/base/pipelines/router.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter + +router = APIRouter( + prefix='/pipelines', + tags=['pipelines'], +) + +# this completes routing in individual pipeline modules +from . import roiset_obmap, segment, segment_zproj \ No newline at end of file diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py new file mode 100644 index 0000000000000000000000000000000000000000..fac1835111872ab563b71d6795982e170d52e61f --- /dev/null +++ b/model_server/base/pipelines/segment.py @@ -0,0 +1,46 @@ +from typing import Dict + +from .shared import call_pipeline, IncompatibleModelsError, PipelineTrace, PipelineParams, PipelineRecord +from ..accessors import GenericImageDataAccessor +from ..models import Model, SemanticSegmentationModel +from ..process import smooth +from .router import router + +from pydantic import Field + + +class SegmentParams(PipelineParams): + accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input') + model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)') + channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.') + + +class SegmentRecord(PipelineRecord): + pass + + +@router.put('/segment') +def segment(p: SegmentParams) -> SegmentRecord: + """ + Run a semantic segmentation model to compute a binary mask from an input image + """ + return call_pipeline(segment_pipeline, p) + +def segment_pipeline( + accessors: Dict[str, GenericImageDataAccessor], + models: Dict[str, Model], + **k +) -> PipelineTrace: + d = PipelineTrace(accessors.get('accessor')) + model = models.get('model') + + if not isinstance(model, SemanticSegmentationModel): + raise IncompatibleModelsError('Expecting a pixel classification model') + + if ch := k.get('channel') is not None: + d['mono'] = d['input'].get_mono(ch) + d['inference'] = model.label_pixel_class(d.last) + if sm := k.get('smooth') is not None: + d['smooth'] = d.last.apply(lambda x: smooth(x, sm)) + d['output'] = d.last + return d \ No newline at end of file diff --git a/model_server/base/pipelines/segment_zproj.py b/model_server/base/pipelines/segment_zproj.py new file mode 100644 index 0000000000000000000000000000000000000000..5f5ed9ba755433ef8445198b80885412e2161d98 --- /dev/null +++ b/model_server/base/pipelines/segment_zproj.py @@ -0,0 +1,40 @@ +from typing import Dict + +from .router import router +from .segment import SegmentParams, SegmentRecord, segment_pipeline +from .shared import call_pipeline, PipelineTrace +from ..accessors import GenericImageDataAccessor +from ..models import Model + +from pydantic import Field + +class SegmentZStackParams(SegmentParams): + zi: int = Field(None, description='z coordinate to use on input stack; apply MIP if empty') + + +class SegmentZStackRecord(SegmentRecord): + pass + + +@router.put('/segment_zproj') +def segment_zproj(p: SegmentZStackParams) -> SegmentZStackRecord: + """ + Run a semantic segmentation model to compute a binary mask from a projected input zstack + """ + return call_pipeline(segment_zproj(), p) + + +def segment_zproj_pipeline( + accessors: Dict[str, GenericImageDataAccessor], + models: Dict[str, Model], + **k +) -> PipelineTrace: + d = PipelineTrace(accessors.get('accessor')) + + if isinstance(k.get('zi'), int): + assert 0 < k['zi'] < d.last.nz + d['mip'] = d.last.get_zi(k['zi']) + else: + d['mip'] = d.last.get_mip() + return segment_pipeline({'accessor': d.last}, models, **k) + diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py new file mode 100644 index 0000000000000000000000000000000000000000..df43cd3c6b55af8b99a2425be568c69786fb1ba9 --- /dev/null +++ b/model_server/base/pipelines/shared.py @@ -0,0 +1,218 @@ +from collections import OrderedDict +from pathlib import Path +from time import perf_counter +from typing import List, Union + +from fastapi import HTTPException +from pydantic import BaseModel, Field, root_validator + +from ..accessors import GenericImageDataAccessor +from ..session import session, AccessorIdError + + +class PipelineParams(BaseModel): + keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session') + api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True') + + @root_validator(pre=False) + def models_are_loaded(cls, dd): + for k, v in dd.items(): + if dd['api'] and k.endswith('model_id') and v is not None: + if v not in session.describe_loaded_models().keys(): + raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded') + return dd + + @root_validator(pre=False) + def accessors_are_loaded(cls, dd): + for k, v in dd.items(): + if dd['api'] and k.endswith('accessor_id'): + try: + info = session.get_accessor_info(v) + except AccessorIdError as e: + raise HTTPException(status_code=409, detail=str(e)) + if not info['loaded']: + raise HTTPException(status_code=409, detail=f'Accessor with {k} = {v} has not been loaded') + return dd + + +class PipelineRecord(BaseModel): + output_accessor_id: str + interm_accessor_ids: Union[List[str], None] + success: bool + timer: dict + + +def call_pipeline(func, p: PipelineParams) -> PipelineRecord: + # match accessor IDs to loaded accessor objects + accessors_in = {} + for k, v in p.dict().items(): + if k.endswith('accessor_id'): + accessors_in[k.split('_id')[0]] = session.get_accessor(v, pop=True) + if len(accessors_in) == 0: + raise NoAccessorsFoundError('Expecting as least one valid accessor to run pipeline') + + # use first validated accessor ID to name log file entries and derived accessors + input_description = [p.dict()[pk] for pk in p.dict().keys() if pk.endswith('accessor_id')][0] + + models = {} + for k, v in p.dict().items(): + if k.endswith('model_id') and v is not None: + models[k.split('_id')[0]] = session.models[v]['object'] + + # call the actual pipeline; expect a single PipelineTrace or a tuple where first element is PipelineTrace + ret = func( + accessors_in, + models, + **p.dict(), + ) + if isinstance(ret, PipelineTrace): + steps = ret + misc = None + elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace): + steps = ret[0] + misc = ret[1:] + else: + raise UnexpectedPipelineReturnError( + f'{func.__name__} returned unexpected value of {type(ret)}' + ) + session.log_info(f'Completed {func.__name__} on {input_description}.') + + # map intermediate data accessors to accessor IDs + if p.keep_interm: + interm_ids = [] + acc_interm = steps.accessors(skip_first=True, skip_last=True).items() + for i, item in enumerate(acc_interm): + stk, acc = item + interm_ids.append( + session.add_accessor( + acc, + accessor_id=f'{input_description}_{func.__name__}_step{(i + 1):02d}_{stk}' + ) + ) + else: + interm_ids = None + + # map final result to an accessor ID + result_id = session.add_accessor( + steps.last, + accessor_id=f'{p.accessor_id}_{func.__name__}_result' + ) + + record = PipelineRecord( + output_accessor_id=result_id, + interm_accessor_ids=interm_ids, + success=True, + timer=steps.times + ) + + # return miscellaneous objects if pipeline returns these + if misc: + return record, *misc + else: + return record + + +class PipelineTrace(OrderedDict): + tfunc = perf_counter + + def __init__(self, start_acc: GenericImageDataAccessor = None, enforce_accessors=True, allow_overwrite=False): + """ + A container and timer for data at each stage of a pipeline. + :param start_acc: (optional) accessor to initialize as 'input' step + :param enforce_accessors: if True, only allow accessors to be appended as items + :param allow_overwrite: if True, allow an item to be overwritten + """ + self.enforce_accessors = enforce_accessors + self.allow_overwrite = allow_overwrite + self.last_time = self.tfunc() + self.markers = None + self.timer = OrderedDict() + super().__init__() + if start_acc is not None: + self['input'] = start_acc + + def __setitem__(self, key, value: GenericImageDataAccessor): + if self.enforce_accessors: + assert isinstance(value, GenericImageDataAccessor), f'Pipeline trace expects data accessor type' + if not self.allow_overwrite and key in self.keys(): + raise KeyAlreadyExists(f'key {key} already exists in pipeline trace') + self.timer.__setitem__(key, self.tfunc() - self.last_time) + self.last_time = self.tfunc() + return super().__setitem__(key, value) + + @property + def times(self): + """ + Return an ordered dictionary of incremental times for each item that is appended + """ + return {k: self.timer[k] for k in self.keys()} + + @property + def last(self): + """ + Return most recently appended item + :return: + """ + return list(self.values())[-1] + + def accessors(self, skip_first=True, skip_last=True) -> dict: + """ + Return subset ordered dictionary that guarantees items are accessors + :param skip_first: if True, exclude first item in trace + :param skip_last: if False, exclude last item in trace + :return: dictionary of accessors that meet input criteria + """ + res = OrderedDict() + for i, item in enumerate(self.items()): + k, v = item + if not isinstance(v, GenericImageDataAccessor): + continue + if skip_first and k == list(self.keys())[0]: + continue + if skip_last and k == list(self.keys())[-1]: + continue + res[k] = v + return res + + def write_interm( + self, + where: Path, + prefix: str = 'interm', + skip_first=True, + skip_last=True, + debug=False + ) -> List[Path]: + """ + Write accessor data to TIF files under specified path + :param where: directory in which to write image files + :param prefix: (optional) file prefix + :param skip_first: if True, do not write first item in trace + :param skip_last: if False, do not write last item in trace + :param debug: if True, report destination filepaths but do not write files + :return: list of destination filepaths + """ + paths = [] + accessors = self.accessors(skip_first=skip_first, skip_last=skip_last) + for i, item in enumerate(accessors.items()): + k, v = item + fp = where / f'{prefix}_{i:02d}_{k}.tif' + paths.append(fp) + if not debug: + v.write(fp) + return paths + + +class Error(Exception): + pass + +class IncompatibleModelsError(Error): + pass + +class KeyAlreadyExists(Error): + pass + +class NoAccessorsFoundError(Error): + pass + +class UnexpectedPipelineReturnError(Error): + pass diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 4eeda4909fa4e4c813698b0dc315f7c3b2c980e8..8b7642774a4e84c9f476b5aa744ec62652ed551e 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1,3 +1,4 @@ +from collections import OrderedDict import itertools from math import sqrt, floor from pathlib import Path @@ -262,7 +263,6 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_chann cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy return cropped - df['binary_mask'] = df.apply( _make_binary_mask, axis=1, @@ -370,19 +370,20 @@ class RoiSet(object): """Expose ROI meta information via the Pandas.DataFrame API""" return self._df.itertuples(name='Roi') - @staticmethod + @classmethod def from_object_ids( + cls, acc_raw: GenericImageDataAccessor, acc_obj_ids: GenericImageDataAccessor, params: RoiSetMetaParams = RoiSetMetaParams(), ): """ - - :param acc_raw: + Create an RoiSet from an object identities map + :param acc_raw: accessor to a generally multichannel z-stack :param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity labels its membership in a connected object - :param params: - :return: + :param params: optional arguments that influence the definition and representation of ROIs + :return: RoiSet object """ assert acc_obj_ids.chroma == 1 @@ -395,15 +396,23 @@ class RoiSet(object): params.filters, ) - return RoiSet(acc_raw, df, params) + return cls(acc_raw, df, params) - @staticmethod + @classmethod def from_bounding_boxes( + cls, acc_raw: GenericImageDataAccessor, bbox_yxhw: List[Dict], bbox_zi: Union[List[int], int] = None, params: RoiSetMetaParams = RoiSetMetaParams() ): + """ + Create and RoiSet from bounding boxes + :param acc_raw: accessor to a generally a multichannel z-stack + :param yxhw_list: list of bounding boxing coordinates [corner X, corner Y, height, width] + :param params: optional arguments that influence the definition and representation of ROIs + :return: RoiSet object + """ bbox_df = pd.DataFrame(bbox_yxhw) if list(bbox_df.columns.str.upper().sort_values()) != ['H', 'W', 'X', 'Y']: raise BoundingBoxError(f'Expecting bounding box coordinates Y, X, H, and W, not {list(bbox_df.columns)}') @@ -444,11 +453,12 @@ class RoiSet(object): axis=1, result_type='reduce', ) - return RoiSet(acc_raw, df, params) + return cls(acc_raw, df, params) - @staticmethod + @classmethod def from_binary_mask( + cls, acc_raw: GenericImageDataAccessor, acc_seg: GenericImageDataAccessor, allow_3d=False, @@ -463,7 +473,7 @@ class RoiSet(object): :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False :param params: optional arguments that influence the definition and representation of ROIs """ - return RoiSet.from_object_ids( + return cls.from_object_ids( acc_raw, get_label_ids( acc_seg, @@ -473,8 +483,9 @@ class RoiSet(object): params ) - @staticmethod + @classmethod def from_polygons_2d( + cls, acc_raw, polygons: List[np.ndarray], params: RoiSetMetaParams = RoiSetMetaParams() @@ -489,7 +500,7 @@ class RoiSet(object): for p in polygons: sl = draw.polygon(p[:, 1], p[:, 0]) mask[sl] = True - return RoiSet.from_binary_mask( + return cls.from_binary_mask( acc_raw, InMemoryDataAccessor(mask), allow_3d=False, @@ -559,43 +570,21 @@ class RoiSet(object): def classify_by( self, name: str, channels: list[int], object_classification_model: InstanceSegmentationModel, - derived_channel_functions: list[callable] = None ): """ Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing - specified inputs through an instance segmentation classifier. Optionally derive additional inputs for object - classification by passing a raw input channel through one or more functions. + specified inputs through an instance segmentation classifier. :param name: name of column to insert :param channels: list of nc raw input channels to send to classifier :param object_classification_model: InstanceSegmentation model object - :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and - that return a single-channel PatchStack accessor of the same shape :return: None """ + if self.count == 0: + self._df['classify_by_' + name] = None + return True - raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels - if derived_channel_functions is not None: - mono_data = [raw_acc.get_mono(c).data for c in range(0, raw_acc.chroma)] - for fcn in derived_channel_functions: - der = fcn(raw_acc) # returns patch stack - if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']: - raise DerivedChannelError( - f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}' - ) - self.accs_derived.append(der) - - # combine channels - data_derived = [acc.data for acc in self.accs_derived] - input_acc = PatchStack( - np.concatenate( - [*mono_data, *data_derived], - axis=raw_acc._ga('C') - ) - ) - - else: - input_acc = raw_acc + input_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels # do this on a patch basis, i.e. only one object per frame obmap_patches = object_classification_model.label_patch_stack( @@ -665,6 +654,13 @@ class RoiSet(object): return InMemoryDataAccessor(om) + def get_serializable_dataframe(self) -> pd.DataFrame: + return self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1) + + def export_dataframe(self, csv_path: Path) -> str: + csv_path.parent.mkdir(parents=True, exist_ok=True) + self.get_serializable_dataframe().to_csv(csv_path, index=False) + return csv_path.name def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> pd.DataFrame: patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy() @@ -925,7 +921,40 @@ class RoiSet(object): return record - def serialize(self, where: Path, prefix='roiset') -> dict: + def get_export_product_accessors(self, channel, params: RoiSetExportParams) -> dict: + """ + Return various representations of ROIs, e.g. patches, annotated stacks, and object maps, as accessors + :param channel: color channel of products to export + :param params: RoiSetExportParams object describing which products to export and with which parameters + :return: ordered dict of accessors containing the specified products + """ + interm = OrderedDict() + if not self.count: + return interm + + for k, kp in params.dict().items(): + if kp is None: + continue + if k == 'patches_3d': + interm[k] = self.get_patches_acc([channel], make_3d=True, **kp) + if k == 'annotated_patches_2d': + interm[k] = self.get_patches_acc( + make_3d=False, white_channel=channel, + bounding_box_channel=1, bounding_box_linewidth=2, **kp + ) + if k == 'patches_2d': + interm[k] = self.get_patches_acc(make_3d=False, white_channel=channel, **kp) + if k == 'annotated_zstacks': + interm[k] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kp)) + if k == 'object_classes': + pr = 'classify_by_' + cnames = [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] + for n in cnames: + interm[f'{k}_{n}'] = self.get_object_class_map(n) + + return interm + + def serialize(self, where: Path, prefix='') -> dict: """ Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks :param where: path of directory in which to write files @@ -949,10 +978,7 @@ class RoiSet(object): csv_path = where / 'dataframe' / (prefix + '.csv') csv_path.parent.mkdir(parents=True, exist_ok=True) - self._df.drop( - ['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], - axis=1 - ).to_csv(csv_path, index=False) + self.export_dataframe(csv_path) record['dataframe'] = str(Path('dataframe') / csv_path.name) @@ -994,12 +1020,13 @@ class RoiSet(object): return self._df.apply(_poly_from_mask, axis=1) + @property def acc_obj_ids(self): return make_object_ids_from_df(self._df, self.acc_raw.shape_dict) - @staticmethod - def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self: + @classmethod + def deserialize(cls, acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self: """ Create an RoiSet object from saved files and an image accessor :param acc_raw: accessor to image that contains ROIs @@ -1025,20 +1052,106 @@ class RoiSet(object): df['binary_mask'] = df.apply(_read_binary_mask, axis=1) id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) - return RoiSet.from_object_ids(acc_raw, id_mask) + return cls.from_object_ids(acc_raw, id_mask) else: # assume bounding boxes only df['y'] = df['y0'] df['x'] = df['x0'] df['h'] = df['y1'] - df['y0'] df['w'] = df['x1'] - df['x0'] - return RoiSet.from_bounding_boxes( + return cls.from_bounding_boxes( acc_raw, df[['y', 'x', 'h', 'w']].to_dict(orient='records'), list(df['zi']) ) +class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams): + derived_channels: bool = False + +class RoiSetWithDerivedChannels(RoiSet): + + def __init__(self, *a, **k): + self.accs_derived = [] + super().__init__(*a, **k) + + def classify_by( + self, name: str, channels: list[int], + object_classification_model: InstanceSegmentationModel, + derived_channel_functions: list[callable] = None + ): + """ + Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing + specified inputs through an instance segmentation classifier. Derive additional inputs for object + classification by passing a raw input channel through one or more functions. + + :param name: name of column to insert + :param channels: list of nc raw input channels to send to classifier + :param object_classification_model: InstanceSegmentation model object + :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and + that return a single-channel PatchStack accessor of the same shape + :return: None + """ + + raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels + if derived_channel_functions is not None: + mono_data = [raw_acc.get_mono(c).data for c in range(0, raw_acc.chroma)] + for fcn in derived_channel_functions: + der = fcn(raw_acc) # returns patch stack + if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']: + raise DerivedChannelError( + f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}' + ) + self.accs_derived.append(der) + + # combine channels + data_derived = [acc.data for acc in self.accs_derived] + input_acc = PatchStack( + np.concatenate( + [*mono_data, *data_derived], + axis=raw_acc._ga('C') + ) + ) + + else: + input_acc = raw_acc + + # do this on a patch basis, i.e. only one object per frame + obmap_patches = object_classification_model.label_patch_stack( + input_acc, + self.get_patch_masks_acc(expanded=False, pad_to=None) + ) + + self._df['classify_by_' + name] = pd.Series(dtype='Int64') + + for i, roi in enumerate(self): + oc = np.unique( + mask_largest_object( + obmap_patches.iat(i).data + ) + )[-1] + self._df.loc[roi.Index, 'classify_by_' + name] = oc + + def run_exports(self, where: Path, channel, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict: + """ + Export various representations of ROIs, e.g. patches, annotated stacks, and object maps. + :param where: path of directory in which to write all export products + :param channel: color channel of products to export + :param prefix: prefix of the name of each product's file or subfolder + :param params: RoiSetExportParams object describing which products to export and with which parameters + :return: nested dict of Path objects describing the location of export products + """ + record = super().run_exports(where, channel, prefix, params) + + k = 'derived_channels' + if k in params.dict().keys(): + record[k] = [] + for di, dacc in enumerate(self.accs_derived): + fp = where / k / f'dc{di:01d}.tif' + fp.parent.mkdir(exist_ok=True, parents=True) + dacc.export_pyxcz(fp) + record[k].append(str(fp)) + return record class Error(Exception): pass diff --git a/model_server/base/session.py b/model_server/base/session.py index 771ed28a47b3fde328c6f780a99d488622f6364f..7e06003f8be9407f6a8ee4f1816b70fac12f1900 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -1,3 +1,4 @@ +from collections import OrderedDict import logging import os @@ -9,11 +10,11 @@ from typing import Union import pandas as pd from ..conf import defaults +from .accessors import GenericImageDataAccessor, PatchStack from .models import Model logger = logging.getLogger(__name__) - class CsvTable(object): def __init__(self, fpath: Path): self.path = fpath @@ -40,6 +41,7 @@ class _Session(object): def __init__(self, root: str = None): self.models = {} # model_id : model object self.paths = self.make_paths(root) + self.accessors = OrderedDict() self.logfile = self.paths['logs'] / f'session.log' logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format) @@ -80,6 +82,111 @@ class _Session(object): raise InvalidPathError(f'Could not find {path}') self.paths[key] = Path(path) + def add_accessor(self, acc: GenericImageDataAccessor, accessor_id: str = None) -> str: + """ + Add an accessor to session context + :param acc: accessor to add + :param accessor_id: unique ID, or autogenerate if None + :return: ID of accessor + """ + if accessor_id in self.accessors.keys(): + raise AccessorIdError(f'Access with ID {accessor_id} already exists') + if accessor_id is None: + idx = len(self.accessors) + accessor_id = f'auto_{idx:06d}' + self.accessors[accessor_id] = {'loaded': True, 'object': acc, **acc.info} + return accessor_id + + def del_accessor(self, accessor_id: str) -> str: + """ + Remove accessor object but retain its info dictionary + :param accessor_id: accessor's ID + :return: ID of accessor + """ + if accessor_id not in self.accessors.keys(): + raise AccessorIdError(f'No accessor with ID {accessor_id} is registered') + v = self.accessors[accessor_id] + if isinstance(v, dict) and v['loaded'] is False: + logger.warning(f'Accessor {accessor_id} is already deleted') + else: + assert isinstance(v['object'], GenericImageDataAccessor) + v['loaded'] = False + v['object'] = None + return accessor_id + + def del_all_accessors(self) -> list[str]: + """ + Remove (unload) all accessors but keep their info in dictionary + :return: list of removed accessor IDs + """ + res = [] + for k, v in self.accessors.items(): + if v['loaded']: + v['object'] = None + v['loaded'] = False + res.append(k) + return res + + + def list_accessors(self) -> dict: + """ + List information about all accessors in JSON-readable format + """ + if len(self.accessors): + return pd.DataFrame(self.accessors).drop('object').to_dict() + else: + return {} + + def get_accessor_info(self, acc_id: str) -> dict: + """ + Get information about a single accessor + """ + if acc_id not in self.accessors.keys(): + raise AccessorIdError(f'No accessor with ID {acc_id} is registered') + return self.list_accessors()[acc_id] + + def get_accessor(self, acc_id: str, pop: bool = True) -> GenericImageDataAccessor: + """ + Return an accessor object + :param acc_id: accessor's ID + :param pop: remove object from session accessor registry if True + :return: accessor object + """ + if acc_id not in self.accessors.keys(): + raise AccessorIdError(f'No accessor with ID {acc_id} is registered') + acc = self.accessors[acc_id]['object'] + if pop: + self.del_accessor(acc_id) + return acc + + def write_accessor(self, acc_id: str, filename: Union[str, None] = None) -> str: + """ + Write an accessor to file and unload it from the session + :param acc_id: accessor's ID + :param filename: force use of a specific filename, raise InvalidPathError if this already exists + :return: name of file + """ + if filename is None: + fp = self.paths['outbound_images'] / f'{acc_id}.tif' + else: + fp = self.paths['outbound_images'] / filename + if fp.exists(): + raise InvalidPathError(f'Cannot overwrite file {filename} when writing accessor') + acc = self.get_accessor(acc_id, pop=True) + + old_fp = self.accessors[acc_id]['filepath'] + if old_fp != '': + raise WriteAccessorError( + f'Cannot overwrite accessor that is already written to {old_fp}' + ) + + if isinstance(acc, PatchStack): + acc.export_pyxcz(fp) + else: + acc.write(fp) + self.accessors[acc_id]['filepath'] = fp.__str__() + return fp.name + @staticmethod def make_paths(root: str = None) -> dict: """ @@ -131,10 +238,16 @@ class _Session(object): def log_error(self, msg): logger.error(msg) - def load_model(self, ModelClass: Model, params: Union[BaseModel, None] = None) -> dict: + def load_model( + self, + ModelClass: Model, + key: Union[str, None] = None, + params: Union[BaseModel, None] = None, + ) -> dict: """ Load an instance of a given model class and attach to this session's model registry :param ModelClass: subclass of Model + :param key: unique identifier of model, or autogenerate if None :param params: optional parameters that are passed to the model's construct :return: model_id of loaded model """ @@ -142,13 +255,17 @@ class _Session(object): assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' ii = 0 - def mid(i): - return f'{ModelClass.__name__}_{i:02d}' + if key is None: + def mid(i): + return f'{mi.name}_{i:02d}' - while mid(ii) in self.models.keys(): - ii += 1 + while mid(ii) in self.models.keys(): + ii += 1 + + key = mid(ii) + elif key in self.models.keys(): + raise CouldNotInstantiateModelError(f'Model with key {key} already exists.') - key = mid(ii) self.models[key] = { 'object': mi, 'params': getattr(mi, 'params', None) @@ -198,6 +315,12 @@ class InferenceRecordError(Error): class CouldNotInstantiateModelError(Error): pass +class AccessorIdError(Error): + pass + +class WriteAccessorError(Error): + pass + class CouldNotCreateDirectory(Error): pass diff --git a/model_server/base/util.py b/model_server/base/util.py index 81736d485cd322d9ee42af0dcc6d77604314e154..c1556cb5f4c30004d789b05cb539a7a6b770d67e 100644 --- a/model_server/base/util.py +++ b/model_server/base/util.py @@ -1,14 +1,15 @@ +from collections import OrderedDict from math import ceil from pathlib import Path import re from time import localtime, strftime from typing import List +from time import perf_counter import pandas as pd -from .accessors import InMemoryDataAccessor, write_accessor_data_to_file -from .models import Model -from .roiset import filter_df_overlap_seg, RoiSet +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file +from model_server.base.models import Model def autonumber_new_directory(where: str, prefix: str) -> str: """ @@ -164,4 +165,5 @@ def loop_workflow( ) if len(failures) > 0: - pd.DataFrame(failures).to_csv(Path(output_folder_path) / 'failures.csv') \ No newline at end of file + pd.DataFrame(failures).to_csv(Path(output_folder_path) / 'failures.csv') + diff --git a/model_server/base/validators.py b/model_server/base/validators.py deleted file mode 100644 index f1a092e9f5c65105953c22da716291b359f0f402..0000000000000000000000000000000000000000 --- a/model_server/base/validators.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi import HTTPException - -from .session import session - -def validate_workflow_inputs(model_ids, inpaths): - for mid in model_ids: - if mid and mid not in session.describe_loaded_models().keys(): - raise HTTPException( - status_code=409, - detail=f'Model {mid} has not been loaded' - ) - for inpa in inpaths: - if not inpa.exists(): - raise HTTPException( - status_code=404, - detail=f'Could not find file:\n{inpa}' - ) \ No newline at end of file diff --git a/model_server/base/workflows.py b/model_server/base/workflows.py deleted file mode 100644 index 91f5696863473442a78046e53884982833f26502..0000000000000000000000000000000000000000 --- a/model_server/base/workflows.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session. -""" -from collections import OrderedDict -from pathlib import Path -from time import perf_counter -from typing import Dict - -from .accessors import generate_file_accessor, write_accessor_data_to_file -from .models import SemanticSegmentationModel - -from pydantic import BaseModel - -class Timer(object): - tfunc = perf_counter - - def __init__(self): - self.events = OrderedDict() - self.last = self.tfunc() - - def click(self, key): - self.events[key] = self.tfunc() - self.last - self.last = self.tfunc() - -class WorkflowRunRecord(BaseModel): - model_id: str - input_filepath: str - output_filepath: str - success: bool - timer_results: Dict[str, float] - - -def classify_pixels(fpi: Path, model: SemanticSegmentationModel, where_output: Path, **kwargs) -> WorkflowRunRecord: - """ - Run a semantic segmentation model to compute a binary mask from an input image - - :param fpi: Path object that references input image file - :param model: semantic segmentation model instance - :param where_output: Path object that references output image directory - :param kwargs: variable-length keyword arguments - :return: record object - """ - ti = Timer() - ch = kwargs.get('channel') - img = generate_file_accessor(fpi).get_mono(ch) - ti.click('file_input') - - outdata = model.label_pixel_class(img) - ti.click('inference') - - outpath = where_output / (model.model_id + '_' + fpi.stem + '.tif') - write_accessor_data_to_file(outpath, outdata) - ti.click('file_output') - - return WorkflowRunRecord( - model_id=model.model_id, - input_filepath=str(fpi), - output_filepath=str(outpath), - success=True, - timer_results=ti.events, - ) \ No newline at end of file diff --git a/model_server/conf/defaults.py b/model_server/conf/defaults.py index bdf7cfd0cf2786783b8f4c16dafdf7418af05976..ff4f9040ceb9d9d14d947bbbce19185491d490fd 100644 --- a/model_server/conf/defaults.py +++ b/model_server/conf/defaults.py @@ -8,8 +8,7 @@ subdirectories = { 'outbound_images': 'images/outbound', 'tables': 'tables', } - server_conf = { 'host': '127.0.0.1', 'port': 8000, -} \ No newline at end of file +} diff --git a/model_server/conf/servers/__init__.py b/model_server/conf/servers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_server/conf/servers/ilastik.py b/model_server/conf/servers/ilastik.py new file mode 100644 index 0000000000000000000000000000000000000000..17445d3dd2af5cc3dd0fe91260d2c0c555a84213 --- /dev/null +++ b/model_server/conf/servers/ilastik.py @@ -0,0 +1,7 @@ +import importlib + +from model_server.base.api import app + +for ex in ['ilastik']: + m = importlib.import_module(f'extensions.{ex}.router') + app.include_router(m.router) diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 16b656c225bcc7f1373e025c1659b8c0893a6ca7..12d93c8b884dfad672a5e7adf0f5cbe05c60eeba 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -3,10 +3,12 @@ import os import unittest from multiprocessing import Process from pathlib import Path +from shutil import copyfile import requests from urllib3 import Retry +from ..base.accessors import generate_file_accessor class TestServerBaseClass(unittest.TestCase): """ @@ -52,6 +54,22 @@ class TestServerBaseClass(unittest.TestCase): self.server_process.terminate() self.server_process.join() + def copy_input_file_to_server(self): + resp = self._get('paths') + pa = resp.json()['inbound_images'] + copyfile( + self.input_data['path'], + Path(pa) / self.input_data['name'] + ) + return self.input_data['name'] + + def get_accessor(self, accessor_id, filename=None): + resp = self._put(f'/accessors/write_to_file/{accessor_id}', query={'filename': filename}) + where_out = self._get('paths').json()['outbound_images'] + fp_out = (Path(where_out) / resp.json()) + self.assertTrue(fp_out.exists()) + return generate_file_accessor(fp_out) + def setup_test_data(): """ diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 5e7a0a758eddb2f6cbb6db3c8cd74a0436e81065..d098e03983e02c978231ca555826b567726b8dd5 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -2,21 +2,25 @@ import json from logging import getLogger import os from pathlib import Path +from typing import Union import warnings import numpy as np -from pydantic import BaseModel +from pydantic import BaseModel, Field import vigra import model_server.extensions.ilastik.conf from ...base.accessors import PatchStack from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from ...base.process import smooth from ...base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel class IlastikParams(BaseModel): - project_file: str - duplicate: bool = True + project_file: str = Field(description='(*.ilp) ilastik project filename') + duplicate: bool = Field( + True, + description='Load another instance of the same project file if True; return existing one if False' + ) + model_id: Union[str, None] = Field(None, description='Unique identifier of the model, or autogenerate if empty') class IlastikModel(Model): @@ -102,12 +106,11 @@ class IlastikModel(Model): class IlastikPixelClassifierParams(IlastikParams): px_class: int = 0 px_prob_threshold: float = 0.5 - px_smoothing: float = 0.0 class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' operations = ['segment', ] - + def __init__(self, params: IlastikPixelClassifierParams, **kwargs): super(IlastikPixelClassifierModel, self).__init__(params, **kwargs) @@ -167,13 +170,8 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs): pxmap, _ = self.infer(img) - sig = self.params['px_smoothing'] - if sig > 0.0: - proc = smooth(img.data, sig) - else: - proc = pxmap.data - mask = proc[:, :, self.params['px_class'], :] > self.params['px_prob_threshold'] - return InMemoryDataAccessor(mask) + mask = pxmap.get_mono(self.params['px_class']).apply(lambda x: x > self.params['px_prob_threshold']) + return mask class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): @@ -307,7 +305,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs): """ Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is - assigned a class. + assigned a class :param img: input image :param pxmap: map of pixel probabilities :param kwargs: @@ -317,9 +315,11 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag """ if not img.shape == pxmap.shape: raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') + if not pxmap.data.min() >= 0.0 and pxmap.data.max() <= 1.0: + raise InvalidInputImageError('Pixel probability values must be between 0.0 and 1.0') pxch = kwargs.get('pixel_classification_channel', 0) pxtr = kwargs.get('pixel_classification_threshold', 0.5) - mask = pxmap.get_mono(pxch).apply(lambda x: x > pxtr) + mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) obmap, _ = self.infer(img, mask) return obmap diff --git a/model_server/extensions/ilastik/pipelines/px_then_ob.py b/model_server/extensions/ilastik/pipelines/px_then_ob.py new file mode 100644 index 0000000000000000000000000000000000000000..1aa64615479cad29f01b7d7e416d1fec2f3c9568 --- /dev/null +++ b/model_server/extensions/ilastik/pipelines/px_then_ob.py @@ -0,0 +1,69 @@ +from typing import Dict + +from fastapi import APIRouter, HTTPException +from pydantic import Field + +from ....base.accessors import GenericImageDataAccessor +from ....base.models import Model +from ....base.pipelines.shared import call_pipeline, PipelineTrace, PipelineParams, PipelineRecord + +from ..models import IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel + +router = APIRouter( + prefix='/pipelines', +) + +class PxThenObParams(PipelineParams): + accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input') + px_model_id: str = Field(description='ID of model for pixel classification') + ob_model_id: str = Field(description='ID of model for object classification') + channel: int = Field(None, description='Image channel to pass to pixel classification, or all channels if empty.') + mip: bool = Field(False, description='Use maximum intensity projection of input image if True') + + +class PxThenObRecord(PipelineRecord): + pass + +@router.put('/pixel_then_object_classification/infer') +def pixel_then_object_classification(p: PxThenObParams) -> PxThenObRecord: + """ + Workflow that specifically runs an ilastik pixel classifier, then passes results to an object classifier. + """ + + try: + return call_pipeline(pixel_then_object_classification_pipeline, p) + except IncompatibleModelsError as e: + raise HTTPException(status_code=409, detail=str(e)) + + +def pixel_then_object_classification_pipeline( + accessors: Dict[str, GenericImageDataAccessor], + models: Dict[str, Model], + **k +) -> PxThenObRecord: + + if not isinstance(models['px_model'], IlastikPixelClassifierModel): + raise IncompatibleModelsError( + f'Expecting px_model to be an ilastik pixel classification model' + ) + if not isinstance(models['ob_model'], IlastikObjectClassifierFromPixelPredictionsModel): + raise IncompatibleModelsError( + f'Expecting ob_model to be an ilastik object classification from pixel predictions model' + ) + + d = PipelineTrace(accessors['accessor']) + if (ch := k.get('channel')) is not None: + channels = [ch] + else: + channels = range(0, d['input'].chroma) + d['select_channels'] = d.last.get_channels(channels, mip=k.get('mip', False)) + d['pxmap'], _ = models['px_model'].infer(d.last) + d['ob_map'], _ = models['ob_model'].infer(d['select_channels'], d['pxmap']) + + return d + +class Error(Exception): + pass + +class IncompatibleModelsError(Error): + pass \ No newline at end of file diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 4645de18ebc4d06e11a17248b7f248a00c6ff72f..20a1ef24235c0c4ee3d8cc101ffba9e92d86560b 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -1,10 +1,8 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter -from ...base.session import session -from ...base.validators import validate_workflow_inputs +from model_server.base.session import session -from . import models as ilm -from .workflows import infer_px_then_ob_model +from model_server.extensions.ilastik import models as ilm router = APIRouter( prefix='/ilastik', @@ -12,59 +10,46 @@ router = APIRouter( ) -def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict: - """ - Load an ilastik model of a given class and project filename. - :param model_class: - :param project_file: (*.ilp) ilastik project filename - :param duplicate: load another instance of the same project file if True; return existing one if false - :return: dict containing model's ID - """ - project_file = params.project_file - if not params.duplicate: - existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) - if existing_model_id is not None: - session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') - return {'model_id': existing_model_id} - result = session.load_model(model_class, params) - session.log_info(f'Loaded ilastik model {result} from {project_file}') - return {'model_id': result} +import model_server.extensions.ilastik.pipelines.px_then_ob +router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router) @router.put('/seg/load/') -def load_px_model(params: ilm.IlastikPixelClassifierParams) -> dict: +def load_px_model(p: ilm.IlastikPixelClassifierParams) -> dict: + """ + Load an ilastik pixel classifier model from its project file + """ return load_ilastik_model( ilm.IlastikPixelClassifierModel, - params, + p, ) @router.put('/pxmap_to_obj/load/') -def load_pxmap_to_obj_model(params: ilm.IlastikParams) -> dict: +def load_pxmap_to_obj_model(p: ilm.IlastikParams) -> dict: + """ + Load an ilastik object classifier from pixel predictions model from its project file + """ return load_ilastik_model( ilm.IlastikObjectClassifierFromPixelPredictionsModel, - params, + p, ) @router.put('/seg_to_obj/load/') -def load_seg_to_obj_model(params: ilm.IlastikParams) -> dict: +def load_seg_to_obj_model(p: ilm.IlastikParams) -> dict: + """ + Load an ilastik object classifier from segmentation model from its project file + """ return load_ilastik_model( ilm.IlastikObjectClassifierFromSegmentationModel, - params, + p, ) -@router.put('/pixel_then_object_classification/infer') -def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None, mip: bool = False) -> dict: - inpath = session.paths['inbound_images'] / input_filename - validate_workflow_inputs([px_model_id, ob_model_id], [inpath]) - try: - record = infer_px_then_ob_model( - inpath, - session.models[px_model_id]['object'], - session.models[ob_model_id]['object'], - session.paths['outbound_images'], - channel=channel, - mip=mip, - ) - session.log_info(f'Completed pixel and object classification of {input_filename}') - except AssertionError: - raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') - return record \ No newline at end of file +def load_ilastik_model(model_class: ilm.IlastikModel, p: ilm.IlastikParams) -> dict: + project_file = p.project_file + if not p.duplicate: + existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) + if existing_model_id is not None: + session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') + return {'model_id': existing_model_id} + result = session.load_model(model_class, key=p.model_id, params=p) + session.log_info(f'Loaded ilastik model {result} from {project_file}') + return {'model_id': result} \ No newline at end of file diff --git a/model_server/extensions/ilastik/workflows.py b/model_server/extensions/ilastik/workflows.py deleted file mode 100644 index f096cba13523deb5e1baecefda5f5ebd9abbb16e..0000000000000000000000000000000000000000 --- a/model_server/extensions/ilastik/workflows.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session. -""" -from pathlib import Path -from typing import Dict, Union - -from pydantic import BaseModel - -from .models import IlastikPixelClassifierModel, IlastikObjectClassifierFromPixelPredictionsModel -from ...base.accessors import generate_file_accessor, write_accessor_data_to_file -from ...base.workflows import Timer - -class WorkflowRunRecord(BaseModel): - pixel_model_id: str - object_model_id: Union[str, None] = None - input_filepath: str - pixel_map_filepath: str - object_map_filepath: str - success: bool - timer_results: Dict[str, float] - - -def infer_px_then_ob_model( - fpi: Path, - px_model: IlastikPixelClassifierModel, - ob_model: IlastikObjectClassifierFromPixelPredictionsModel, - where_output: Path, - channel: int = None, - **kwargs -) -> WorkflowRunRecord: - """ - Workflow that specifically runs an ilastik pixel classifier, then passes results to an object classifier, - saving intermediate images - :param fpi: Path object that references input image file - :param px_model: model instance for pixel classification - :param ob_model: model instance for object classification - :param where_output: Path object that references output image directory - :param channel: input image channel to pass to pixel classification, or all channels if None - :param kwargs: variable-length keyword arguments - :return: - """ - assert isinstance(px_model, IlastikPixelClassifierModel) - assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel) - - ti = Timer() - raw_acc = generate_file_accessor(fpi) - if channel is not None: - channels = [channel] - else: - channels = range(0, raw_acc.chroma) - img = raw_acc.get_channels(channels, mip=kwargs.get('mip', False)) - ti.click('file_input') - - px_map, _ = px_model.infer(img) - ti.click('pixel_probability_inference') - - px_map_path = where_output / (px_model.model_id + '_pxmap_' + fpi.stem + '.tif') - write_accessor_data_to_file(px_map_path, px_map) - ti.click('pixel_map_output') - - ob_map, _ = ob_model.infer(img, px_map) - ti.click('object_classification') - - ob_map_path = where_output / (ob_model.model_id + '_obmap_' + fpi.stem + '.tif') - write_accessor_data_to_file(ob_map_path, ob_map) - ti.click('object_map_output') - - return WorkflowRunRecord( - pixel_model_id=px_model.model_id, - object_model_id=ob_model.model_id, - input_filepath=str(fpi), - pixel_map_filepath=str(px_map_path), - object_map_filepath=str(ob_map_path), - success=True, - timer_results=ti.events, - ) - diff --git a/model_server/scripts/run_server.py b/model_server/scripts/run_server.py index b4720cc6b61819c05d13097c5f43af6b7fe132be..b02f1effd96168d194c8493f1b8d912fa0e68481 100644 --- a/model_server/scripts/run_server.py +++ b/model_server/scripts/run_server.py @@ -6,12 +6,18 @@ from urllib3 import Retry import uvicorn import webbrowser -from ..conf.defaults import server_conf +from conf.defaults import server_conf + def parse_args(): parser = argparse.ArgumentParser( description='Start model server with optional arguments', ) + parser.add_argument( + '--confpath', + default='conf.servers.ilastik', + help='path to server startup configuration', + ) parser.add_argument( '--host', default=server_conf['host'], @@ -35,12 +41,13 @@ def parse_args(): return parser.parse_args() -def main(args, app_name='model_server.base.api:app') -> None: + +def main(args) -> None: print('CLI args:\n' + str(args)) server_process = Process( target=uvicorn.run, - args=(app_name,), + args=(f'{args.confpath}:app',), kwargs={ 'app_dir': '.', 'host': args.host, diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index 9bb3b7a8c5a50a4f9aedc1818df3febc2590eaa3..1df5863c71f1d157d612e74ba83a822632d9f221 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -94,6 +94,17 @@ class TestCziImageFileAccess(unittest.TestCase): self.assertTrue(np.all(sz.data[:, :, :, 0] == cf.data[:, :, :, zi])) + def test_get_mip(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + zi = 5 + cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) + sm = cf.get_mip() + self.assertEqual(sm.shape_dict['Z'], 1) + self.assertTrue(np.all(cf.data.max(axis=-1, keepdims=True) == sm.data)) + def test_crop_yx(self): w = 256 h = 512 diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 46f614be2303c184950ad4b78cb2dba9bde58b96..f368dc05212c59f1b2ece208a719b4e249cedf70 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -1,29 +1,74 @@ from pathlib import Path +from fastapi import APIRouter, FastAPI +import numpy as np +from pydantic import BaseModel + import model_server.conf.testing as conf +from model_server.base.accessors import InMemoryDataAccessor +from model_server.base.api import app +from model_server.base.session import session +from tests.base.test_model import DummyInstanceSegmentationModel, DummySemanticSegmentationModel czifile = conf.meta['image_files']['czifile'] -class TestApiFromAutomatedClient(conf.TestServerBaseClass): +""" +Configure additional endpoints for testing +""" +test_router = APIRouter(prefix='/testing', tags=['testing']) - def copy_input_file_to_server(self): - from shutil import copyfile +class BounceBackParams(BaseModel): + par1: str + par2: list - resp = self._get('paths') - pa = resp.json()['inbound_images'] - outpath = Path(pa) / czifile['name'] - copyfile( - czifile['path'], - Path(pa) / czifile['name'] +@test_router.put('/bounce_back') +def list_bounce_back(params: BounceBackParams): + return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}} + +@test_router.put('/accessors/dummy_accessor/load') +def load_dummy_accessor() -> str: + acc = InMemoryDataAccessor( + np.random.randint( + 0, + 2 ** 8, + size=(512, 256, 3, 7), + dtype='uint8' ) + ) + return session.add_accessor(acc) + +@test_router.put('/models/dummy_semantic/load/') +def load_dummy_model() -> dict: + mid = session.load_model(DummySemanticSegmentationModel) + session.log_info(f'Loaded model {mid}') + return {'model_id': mid} + +@test_router.put('/models/dummy_instance/load/') +def load_dummy_model() -> dict: + mid = session.load_model(DummyInstanceSegmentationModel) + session.log_info(f'Loaded model {mid}') + return {'model_id': mid} + +app.include_router(test_router) + +""" +Implement unit testing on extended base app +""" + +class TestServerTestCase(conf.TestServerBaseClass): + app_name = 'tests.base.test_api:app' + input_data = czifile + + +class TestApiFromAutomatedClient(TestServerTestCase): def test_trivial_api_response(self): resp = self._get('') self.assertEqual(resp.status_code, 200) def test_bounceback_parameters(self): - resp = self._put('bounce_back', body={'par1': 'hello', 'par2': ['ab', 'cd']}) - self.assertEqual(resp.status_code, 200, resp.json()) + resp = self._put('testing/bounce_back', body={'par1': 'hello', 'par2': ['ab', 'cd']}) + self.assertEqual(resp.status_code, 200, resp.content) self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json()) self.assertEqual(resp.json()['params']['par2'], ['ab', 'cd'], resp.json()) @@ -42,7 +87,7 @@ class TestApiFromAutomatedClient(conf.TestServerBaseClass): self.assertEqual(resp.content, b'{}') def test_load_dummy_semantic_model(self): - resp_load = self._put(f'models/dummy_semantic/load') + resp_load = self._put(f'testing/models/dummy_semantic/load') model_id = resp_load.json()['model_id'] self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = self._get('models') @@ -52,7 +97,7 @@ class TestApiFromAutomatedClient(conf.TestServerBaseClass): return model_id def test_load_dummy_instance_model(self): - resp_load = self._put(f'models/dummy_instance/load') + resp_load = self._put(f'testing/models/dummy_instance/load') model_id = resp_load.json()['model_id'] self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = self._get('models') @@ -70,30 +115,57 @@ class TestApiFromAutomatedClient(conf.TestServerBaseClass): ) self.assertEqual(resp.status_code, 404, resp.content.decode()) + def test_pipeline_errors_when_ids_not_found(self): + fname = self.copy_input_file_to_server() + model_id = self._put(f'testing/models/dummy_semantic/load').json()['model_id'] + in_acc_id = self._put(f'accessors/read_from_file/{fname}').json() - def test_i2i_inference_errors_when_model_not_found(self): - model_id = 'not_a_real_model' - resp = self._put( - f'workflows/segment', - query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'} + # respond with 409 for invalid accessor_id + self.assertEqual( + self._put( + f'pipelines/segment', + body={'model_id': model_id, 'accessor_id': 'fake'} + ).status_code, + 409 ) - self.assertEqual(resp.status_code, 409, resp.content.decode()) + + # respond with 409 for invalid model_id + self.assertEqual( + self._put( + f'pipelines/segment', + body={'model_id': 'fake', 'accessor_id': in_acc_id} + ).status_code, + 409 + ) + def test_i2i_dummy_inference_by_api(self): - model_id = self.test_load_dummy_semantic_model() - self.copy_input_file_to_server() + fname = self.copy_input_file_to_server() + model_id = self._put(f'testing/models/dummy_semantic/load').json()['model_id'] + in_acc_id = self._put(f'accessors/read_from_file/{fname}').json() + + # run segmentation pipeline on preloaded accessor resp_infer = self._put( - f'workflows/segment', - query={ + f'pipelines/segment', + body={ + 'accessor_id': in_acc_id, 'model_id': model_id, - 'input_filename': czifile['name'], 'channel': 2, + 'keep_interm': True, }, ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + out_acc_id = resp_infer.json()['output_accessor_id'] + self.assertTrue(self._get(f'accessors/{out_acc_id}').json()['loaded']) + acc_out = self.get_accessor(out_acc_id, 'dummy_semantic_output.tif') + self.assertEqual(acc_out.shape_dict['C'], 1) + + # validate intermediate data + resp_list = self._get(f'accessors').json() + self.assertEqual(len([k for k in resp_list.keys() if '_step' in k]), 2) def test_restarting_session_clears_loaded_models(self): - resp_load = self._put(f'models/dummy_semantic/load',) + resp_load = self._put(f'testing/models/dummy_semantic/load',) self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list_0 = self._get('models') self.assertEqual(resp_list_0.status_code, 200) @@ -139,4 +211,52 @@ class TestApiFromAutomatedClient(conf.TestServerBaseClass): def test_get_logs(self): resp = self._get('session/logs') self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.json()[0]['message'], 'Initialized session') \ No newline at end of file + self.assertEqual(resp.json()[0]['message'], 'Initialized session') + + def test_add_and_delete_accessor(self): + fname = self.copy_input_file_to_server() + + # add accessor to session + resp_add_acc = self._put( + f'accessors/read_from_file/{fname}', + ) + acc_id = resp_add_acc.json() + self.assertTrue(acc_id.startswith('auto_')) + + # confirm that accessor is listed in session context + resp_list_acc = self._get( + f'accessors', + ) + self.assertEqual(len(resp_list_acc.json()), 1) + self.assertTrue(list(resp_list_acc.json().keys())[0].startswith('auto_')) + self.assertTrue(resp_list_acc.json()[acc_id]['loaded']) + + # delete and check that its 'loaded' state changes + self.assertTrue(self._get(f'accessors/{acc_id}').json()['loaded']) + self.assertEqual(self._get(f'accessors/delete/{acc_id}').json(), acc_id) + self.assertFalse(self._get(f'accessors/{acc_id}').json()['loaded']) + + # and try a non-existent accessor ID + resp_wrong_acc = self._get('accessors/auto_123456') + self.assertEqual(resp_wrong_acc.status_code, 404) + + # load another... then remove all + self._put(f'accessors/read_from_file/{fname}') + self.assertEqual(sum([v['loaded'] for v in self._get('accessors').json().values()]), 1) + self.assertEqual(len(self._get(f'accessors/delete/*').json()), 1) + self.assertEqual(sum([v['loaded'] for v in self._get('accessors').json().values()]), 0) + + + def test_empty_accessor_list(self): + resp_list_acc = self._get( + f'accessors', + ) + self.assertEqual(len(resp_list_acc.json()), 0) + + def test_write_accessor(self): + acc_id = self._put('/testing/accessors/dummy_accessor/load').json() + self.assertTrue(self._get(f'accessors/{acc_id}').json()['loaded']) + sd = self._get(f'accessors/{acc_id}').json()['shape_dict'] + self.assertEqual(self._get(f'accessors/{acc_id}').json()['filepath'], '') + acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif') + self.assertEqual(sd, acc_out.shape_dict) \ No newline at end of file diff --git a/tests/base/test_model.py b/tests/base/test_model.py index d5b25c17c7528de730260e5845a7d7b6f21b3ae0..fb5a8f21e263ea5ab5f08d31f727ddcd24386e63 100644 --- a/tests/base/test_model.py +++ b/tests/base/test_model.py @@ -1,11 +1,59 @@ +from math import floor import unittest -from model_server.base.accessors import CziImageFileAccessor -from model_server.base.models import DummySemanticSegmentationModel, DummyInstanceSegmentationModel, CouldNotLoadModelError +import numpy as np + import model_server.conf.testing as conf +from model_server.base.accessors import CziImageFileAccessor, GenericImageDataAccessor, InMemoryDataAccessor +from model_server.base.models import CouldNotLoadModelError, InstanceSegmentationModel, SemanticSegmentationModel, BinaryThresholdSegmentationModel czifile = conf.meta['image_files']['czifile'] +class DummySemanticSegmentationModel(SemanticSegmentationModel): + + model_id = 'dummy_make_white_square' + + def load(self): + return True + + def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): + super().infer(img) + w = img.shape_dict['X'] + h = img.shape_dict['Y'] + result = np.zeros([h, w], dtype='uint8') + result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 + return InMemoryDataAccessor(data=result), {'success': True} + + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + mask, _ = self.infer(img) + return mask + + +class DummyInstanceSegmentationModel(InstanceSegmentationModel): + + model_id = 'dummy_pass_input_mask' + + def load(self): + return True + + def infer( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor + ) -> (GenericImageDataAccessor, dict): + return img.__class__( + (mask.data / mask.data.max()).astype('uint16') + ) + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Returns a trivial segmentation, i.e. the input mask with value 1 + """ + super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) + + class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) @@ -53,6 +101,12 @@ class TestCziImageFileAccess(unittest.TestCase): ) return img, mask + def test_binary_segmentation(self): + model = BinaryThresholdSegmentationModel(tr=3e4) + img = self.cf.get_mono(0) + res = model.label_pixel_class(img) + self.assertTrue(res.is_mask()) + def test_dummy_instance_segmentation(self): img, mask = self.test_dummy_pixel_segmentation() model = DummyInstanceSegmentationModel() diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..3c643c0a8faac926b9303c0c1edf2ad34494ca27 --- /dev/null +++ b/tests/base/test_pipelines.py @@ -0,0 +1,67 @@ +from pathlib import Path +import unittest + +from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file +from model_server.base.pipelines import segment, segment_zproj + +import model_server.conf.testing as conf +from tests.base.test_model import DummySemanticSegmentationModel + +czifile = conf.meta['image_files']['czifile'] +zstack = conf.meta['image_files']['tifffile'] +output_path = conf.meta['output_path'] + + +class TestSegmentationPipelines(unittest.TestCase): + def setUp(self) -> None: + self.model = DummySemanticSegmentationModel() + + def test_call_segment_pipeline(self): + acc = generate_file_accessor(czifile['path']) + trace = segment.segment_pipeline({'accessor': acc}, {'model': self.model}, channel=2, smooth=3) + outfp = output_path / 'pipelines' / 'segment_binary_mask.tif' + write_accessor_data_to_file(outfp, trace.last) + + import tifffile + img = tifffile.imread(outfp) + w = czifile['w'] + h = czifile['h'] + + self.assertEqual( + img.shape, + (h, w), + 'Inferred image is not the expected shape' + ) + + self.assertEqual( + img[int(w/2), int(h/2)], + 255, + 'Middle pixel is not white as expected' + ) + + self.assertEqual( + img[0, 0], + 0, + 'First pixel is not black as expected' + ) + + interm_fps = trace.write_interm( + output_path / 'pipelines' / 'segment_interm', + prefix=czifile['name'] + ) + self.assertTrue([ofp.stem.split('_')[-1] for ofp in interm_fps] == ['mono', 'inference', 'smooth']) + + def test_call_segment_zproj_pipeline(self): + acc = generate_file_accessor(zstack['path']) + + trace1 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}, channel=0, smooth=3, zi=4) + self.assertEqual(trace1.last.chroma, 1) + self.assertEqual(trace1.last.nz, 1) + + trace2 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}, channel=0, smooth=3) + self.assertEqual(trace2.last.chroma, 1) + self.assertEqual(trace2.last.nz, 1) + + trace3 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}) + self.assertEqual(trace3.last.chroma, 1) # still == 1: model returns a single channel regardless of input + self.assertEqual(trace3.last.nz, 1) diff --git a/tests/base/test_process.py b/tests/base/test_process.py index abd5777deb02834942689af0802b0c6ce1b65a90..56838fb1f755f9543f3fd8dfffdc21a4de70edb5 100644 --- a/tests/base/test_process.py +++ b/tests/base/test_process.py @@ -3,7 +3,7 @@ import unittest import numpy as np from model_server.base.annotators import draw_contours_on_patch -from model_server.base.process import get_safe_contours, mask_largest_object, pad +from model_server.base.process import get_safe_contours, mask_largest_object, pad, smooth class TestProcessingUtilityMethods(unittest.TestCase): def setUp(self) -> None: @@ -77,4 +77,19 @@ class TestSafeContours(unittest.TestCase): self.assertEqual((patch == 0).sum(), 0) patch = draw_contours_on_patch(self.patch, con) self.assertEqual((patch == 0).sum(), 20) - self.assertEqual((patch[0, :] == 0).sum(), 20) \ No newline at end of file + self.assertEqual((patch[0, :] == 0).sum(), 20) + +class TestSmooth(unittest.TestCase): + + def test_smooth_uint8_binary_mask(self): + mask = np.zeros([4, 4], dtype='uint8') + mask[1:3, 1:3] = 255 + mask[2, 2] = 0 + res = smooth(mask, sig=3) + + # assert type and range match + self.assertEqual(mask.dtype, res.dtype) + self.assertTrue(np.all(np.unique(mask) == np.unique(res))) + + # trivial case with sig=0 just returns input array + self.assertTrue(np.all(mask == smooth(mask, sig=0))) diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 93e96df72b99640c418ece38cbb467fb0aeace92..0a03973e78ca824083f2821157c70bb9adc77265 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -6,12 +6,12 @@ from pathlib import Path import pandas as pd +from model_server.base.process import smooth from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSetExportParams, RoiSetMetaParams from model_server.base.roiset import RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack -from model_server.base.models import DummyInstanceSegmentationModel -from model_server.base.process import smooth import model_server.conf.testing as conf +from tests.base.test_model import DummyInstanceSegmentationModel data = conf.meta['image_files'] output_path = conf.meta['output_path'] @@ -82,6 +82,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype)) roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap) self.assertEqual(roiset.count, 0) + roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) + self.assertTrue('classify_by_dummy_class' in roiset.get_df().columns) def test_slices_are_valid(self): roiset = self._make_roi_set() @@ -217,42 +219,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ) - def test_classify_by_with_derived_channel(self): - class ModelWithDerivedInputs(DummyInstanceSegmentationModel): - def infer(self, img, mask): - return PatchStack(super().infer(img, mask).data * img.chroma) - - roiset = RoiSet.from_binary_mask( - self.stack, - self.seg_mask, - params=RoiSetMetaParams( - filters={'area': {'min': 1e3, 'max': 1e4}}, - deproject_channel=0, - ) - ) - roiset.classify_by( - 'multiple_input_model', - [0, 1], - ModelWithDerivedInputs(), - derived_channel_functions=[ - lambda acc: PatchStack(2 * acc.get_channels([0]).data), - lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8')) - ] - ) - self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4])) - self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4])) - - self.assertEqual(len(roiset.accs_derived), 2) - for di in roiset.accs_derived: - self.assertEqual(roiset.get_patches_acc().hw, di.hw) - self.assertEqual(roiset.get_patches_acc().nz, di.nz) - self.assertEqual(roiset.get_patches_acc().count, di.count) - - dpas = roiset.run_exports(output_path / 'derived_channels', 0, 'der', RoiSetExportParams(derived_channels=True)) - for fp in dpas['derived_channels']: - assert Path(fp).exists() - return roiset - def test_export_object_classes(self): record = self.test_classify_by().run_exports( output_path / 'object_class_maps', @@ -473,6 +439,46 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa for f in test_df[c]: self.assertTrue((where / f).exists(), where / f) + def test_get_interm_prods(self): + p = RoiSetExportParams(**{ + 'patches_3d': None, + 'annotated_patches_2d': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': [3, None, None], + 'rgb_overlay_weights': [0.2, 1.0, 1.0], + 'pad_to': 512, + }, + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + }, + 'annotated_zstacks': {}, + 'object_classes': True, + }) + self.roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) + interm = self.roiset.get_export_product_accessors( + channel=3, + params=p + ) + self.assertNotIn('patches_3d', interm.keys()) + self.assertEqual( + interm['annotated_patches_2d'].hw, + (self.roiset.get_df().h.max(), self.roiset.get_df().w.max()) + ) + self.assertEqual( + interm['patches_2d'].hw, + (self.roiset.get_df().h.max(), self.roiset.get_df().w.max()) + ) + self.assertEqual( + interm['annotated_zstacks'].hw, + self.stack.hw + ) + self.assertEqual( + interm['object_classes_dummy_class'].hw, + self.stack.hw + ) + self.assertTrue(np.all(interm['object_classes_dummy_class'].unique()[0] == [0, 1])) + def test_run_export_expanded_2d_patch(self): p = RoiSetExportParams(**{ 'patches_2d': { diff --git a/tests/base/test_roiset_derived.py b/tests/base/test_roiset_derived.py new file mode 100644 index 0000000000000000000000000000000000000000..52e7f6fc0a917d719262dcd0c19f3170414f2f18 --- /dev/null +++ b/tests/base/test_roiset_derived.py @@ -0,0 +1,60 @@ +from pathlib import Path +import unittest + +import numpy as np + +from model_server.base.roiset import RoiSetWithDerivedChannelsExportParams, RoiSetMetaParams +from model_server.base.roiset import RoiSetWithDerivedChannels +from model_server.base.accessors import generate_file_accessor, PatchStack +import model_server.conf.testing as conf +from tests.base.test_model import DummyInstanceSegmentationModel + +data = conf.meta['image_files'] +params = conf.meta['roiset'] +output_path = conf.meta['output_path'] + +class TestDerivedChannels(unittest.TestCase): + def setUp(self) -> None: + self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path']) + self.stack_ch_pa = self.stack.get_mono(params['patches_channel']) + self.seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path']) + + def test_classify_by_with_derived_channel(self): + class ModelWithDerivedInputs(DummyInstanceSegmentationModel): + def infer(self, img, mask): + return PatchStack(super().infer(img, mask).data * img.chroma) + + roiset = RoiSetWithDerivedChannels.from_binary_mask( + self.stack, + self.seg_mask, + params=RoiSetMetaParams( + filters={'area': {'min': 1e3, 'max': 1e4}}, + deproject_channel=0, + ) + ) + self.assertIsInstance(roiset, RoiSetWithDerivedChannels) + roiset.classify_by( + 'multiple_input_model', + [0, 1], + ModelWithDerivedInputs(), + derived_channel_functions=[ + lambda acc: PatchStack(2 * acc.get_channels([0]).data), + lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8')) + ] + ) + self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4])) + self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4])) + + self.assertEqual(len(roiset.accs_derived), 2) + for di in roiset.accs_derived: + self.assertEqual(roiset.get_patches_acc().hw, di.hw) + self.assertEqual(roiset.get_patches_acc().nz, di.nz) + self.assertEqual(roiset.get_patches_acc().count, di.count) + + dpas = roiset.run_exports( + output_path / 'derived_channels', 0, 'der', + RoiSetWithDerivedChannelsExportParams(derived_channels=True) + ) + for fp in dpas['derived_channels']: + assert Path(fp).exists() + return roiset \ No newline at end of file diff --git a/tests/base/test_session.py b/tests/base/test_session.py index d8b3f27081559e175f56de2951f1a90c8247cae7..6843f36a011851638ce3b4739ce9ff6caf15b301 100644 --- a/tests/base/test_session.py +++ b/tests/base/test_session.py @@ -1,9 +1,9 @@ from os.path import exists import pathlib -from pydantic import BaseModel import unittest -from model_server.base.models import DummySemanticSegmentationModel +import numpy as np +from model_server.base.accessors import InMemoryDataAccessor from model_server.base.session import session class TestGetSessionObject(unittest.TestCase): @@ -64,56 +64,6 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual(logs[1]['level'], 'WARNING') self.assertEqual(logs[-1]['message'], 'Initialized session') - def test_session_loads_model(self): - MC = DummySemanticSegmentationModel - success = self.sesh.load_model(MC) - self.assertTrue(success) - loaded_models = self.sesh.describe_loaded_models() - self.assertTrue( - (MC.__name__ + '_00') in loaded_models.keys() - ) - self.assertEqual( - loaded_models[MC.__name__ + '_00']['class'], - MC.__name__ - ) - - def test_session_loads_second_instance_of_same_model(self): - MC = DummySemanticSegmentationModel - self.sesh.load_model(MC) - self.sesh.load_model(MC) - self.assertIn(MC.__name__ + '_00', self.sesh.models.keys()) - self.assertIn(MC.__name__ + '_01', self.sesh.models.keys()) - - def test_session_loads_model_with_params(self): - MC = DummySemanticSegmentationModel - class _PM(BaseModel): - p: str - p1 = _PM(p='abc') - success = self.sesh.load_model(MC, params=p1) - self.assertTrue(success) - loaded_models = self.sesh.describe_loaded_models() - mid = MC.__name__ + '_00' - self.assertEqual(loaded_models[mid]['params'], p1) - - # load a second model and confirm that the first is locatable by its param entry - p2 = _PM(p='def') - self.sesh.load_model(MC, params=p2) - find_mid = self.sesh.find_param_in_loaded_models('p', 'abc') - self.assertEqual(mid, find_mid) - self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1) - - def test_session_finds_existing_model_with_different_path_formats(self): - MC = DummySemanticSegmentationModel - class _PM(BaseModel): - path: str - - mod_pw = _PM(path='c:\\windows\\dummy.pa') - mod_pu = _PM(path='c:/windows/dummy.pa') - - mid = self.sesh.load_model(MC, params=mod_pw) - find_mid = self.sesh.find_param_in_loaded_models('path', mod_pu.path, is_path=True) - self.assertEqual(mid, find_mid) - def test_change_output_path(self): pa = self.sesh.get_paths()['inbound_images'] self.assertIsInstance(pa, pathlib.Path) @@ -136,3 +86,53 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual(len(dfv), len(data)) self.assertEqual(dfv.columns[0], 'X') self.assertEqual(dfv.columns[1], 'Y') + + + def test_add_and_remove_accessor(self): + acc = InMemoryDataAccessor( + np.random.randint( + 0, + 2 ** 8, + size=(512, 256, 3, 7), + dtype='uint8' + ) + ) + shd = acc.shape_dict + + # add accessor to session registry + acc_id = session.add_accessor(acc) + self.assertEqual(session.get_accessor_info(acc_id)['shape_dict'], shd) + self.assertTrue(session.get_accessor_info(acc_id)['loaded']) + + # remove accessor from session registry + session.del_accessor(acc_id) + self.assertEqual(session.get_accessor_info(acc_id)['shape_dict'], shd) + self.assertFalse(session.get_accessor_info(acc_id)['loaded']) + + def test_add_and_use_accessor(self): + acc = InMemoryDataAccessor( + np.random.randint( + 0, + 2 ** 8, + size=(512, 256, 3, 7), + dtype='uint8' + ) + ) + shd = acc.shape_dict + + # add accessor to session registry + acc_id = session.add_accessor(acc) + self.assertEqual(session.get_accessor_info(acc_id)['shape_dict'], shd) + self.assertTrue(session.get_accessor_info(acc_id)['loaded']) + + # get accessor from session registry without popping + acc_get = session.get_accessor(acc_id, pop=False) + self.assertIsInstance(acc_get, InMemoryDataAccessor) + self.assertEqual(acc_get.shape_dict, shd) + self.assertTrue(session.get_accessor_info(acc_id)['loaded']) + + # get accessor from session registry with popping + acc_get = session.get_accessor(acc_id) + self.assertIsInstance(acc_get, InMemoryDataAccessor) + self.assertEqual(acc_get.shape_dict, shd) + self.assertFalse(session.get_accessor_info(acc_id)['loaded']) diff --git a/tests/base/test_workflow.py b/tests/base/test_workflow.py deleted file mode 100644 index 54b8326845f03de394bba254c3948169d9b273a7..0000000000000000000000000000000000000000 --- a/tests/base/test_workflow.py +++ /dev/null @@ -1,39 +0,0 @@ -import unittest - -from model_server.base.models import DummySemanticSegmentationModel -from model_server.base.workflows import classify_pixels -import model_server.conf.testing as conf - -czifile = conf.meta['image_files']['czifile'] -output_path = conf.meta['output_path'] - -class TestGetSessionObject(unittest.TestCase): - def setUp(self) -> None: - self.model = DummySemanticSegmentationModel() - - def test_single_session_instance(self): - result = classify_pixels(czifile['path'], self.model, output_path, channel=2) - self.assertTrue(result.success) - - import tifffile - img = tifffile.imread(result.output_filepath) - w = czifile['w'] - h = czifile['h'] - - self.assertEqual( - img.shape, - (h, w), - 'Inferred image is not the expected shape' - ) - - self.assertEqual( - img[int(w/2), int(h/2)], - 255, - 'Middle pixel is not white as expected' - ) - - self.assertEqual( - img[0, 0], - 0, - 'First pixel is not black as expected' - ) \ No newline at end of file diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index fd6893a55feff793facb931efb6c4265abd9961d..e7744acb0c53e02353654602e7916b69de61679e 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -1,12 +1,16 @@ +from pathlib import Path +from shutil import copyfile import unittest import numpy as np from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file +from model_server.base.api import app from model_server.extensions.ilastik import models as ilm -from model_server.extensions.ilastik.workflows import infer_px_then_ob_model -from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams -from model_server.base.workflows import classify_pixels +from model_server.extensions.ilastik.pipelines import px_then_ob +from model_server.extensions.ilastik.router import router +from model_server.base.roiset import RoiSet, RoiSetMetaParams +from model_server.base.pipelines import segment import model_server.conf.testing as conf data = conf.meta['image_files'] @@ -15,6 +19,8 @@ params = conf.meta['roiset'] czifile = conf.meta['image_files']['czifile'] ilastik_classifiers = conf.meta['ilastik_classifiers'] +app.include_router(router) + def _random_int(*args): return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') @@ -180,30 +186,28 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(objmap.data.max(), 2) def test_ilastik_pixel_classification_as_workflow(self): - result = classify_pixels( - czifile['path'], - ilm.IlastikPixelClassifierModel( - params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__()), - ), - output_path, + res = segment.segment_pipeline( + accessors={ + 'accessor': generate_file_accessor(czifile['path']) + }, + models={ + 'model': ilm.IlastikPixelClassifierModel( + params=ilm.IlastikPixelClassifierParams( + project_file=ilastik_classifiers['px']['path'].__str__() + ), + ), + }, channel=0, ) - self.assertTrue(result.success) - self.assertGreater(result.timer_results['inference'], 0.1) + self.assertGreater(res.times['inference'], 0.1) -class TestIlastikOverApi(conf.TestServerBaseClass): - def _copy_input_file_to_server(self): - from pathlib import Path - from shutil import copyfile +class TestServerTestCase(conf.TestServerBaseClass): + app_name = 'tests.test_ilastik.test_ilastik:app' + input_data = czifile - resp = self._get('paths') - pa = resp.json()['inbound_images'] - copyfile( - czifile['path'], - Path(pa) / czifile['name'] - ) +class TestIlastikOverApi(TestServerTestCase): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = self._put( 'ilastik/seg/load/', @@ -273,6 +277,18 @@ class TestIlastikOverApi(conf.TestServerBaseClass): self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel') return model_id + def test_load_ilastik_model_with_model_id(self): + mid = 'new_model_id' + resp_load = self._put( + 'ilastik/pxmap_to_obj/load/', + body={ + 'project_file': str(ilastik_classifiers['pxmap_to_obj']['path']), + 'model_id': mid, + }, + ) + res_mid = resp_load.json()['model_id'] + self.assertEqual(res_mid, mid) + def test_load_ilastik_seg_to_obj_model(self): resp_load = self._put( 'ilastik/seg_to_obj/load/', @@ -288,34 +304,37 @@ class TestIlastikOverApi(conf.TestServerBaseClass): return model_id def test_ilastik_infer_pixel_probability(self): - self._copy_input_file_to_server() + fname = self.copy_input_file_to_server() model_id = self.test_load_ilastik_pixel_model() + in_acc_id = self._put(f'accessors/read_from_file/{fname}').json() resp_infer = self._put( - f'workflows/segment', - query={'model_id': model_id, 'input_filename': czifile['name'], 'channel': 0}, + f'pipelines/segment', + body={'model_id': model_id, 'accessor_id': in_acc_id, 'channel': 0}, ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) def test_ilastik_infer_px_then_ob(self): - self._copy_input_file_to_server() + fname = self.copy_input_file_to_server() px_model_id = self.test_load_ilastik_pixel_model() ob_model_id = self.test_load_ilastik_pxmap_to_obj_model() + in_acc_id = self._put(f'accessors/read_from_file/{fname}').json() + resp_infer = self._put( - 'ilastik/pixel_then_object_classification/infer/', - query={ + 'ilastik/pipelines/pixel_then_object_classification/infer/', + body={ 'px_model_id': px_model_id, 'ob_model_id': ob_model_id, - 'input_filename': czifile['name'], + 'accessor_id': in_acc_id, 'channel': 0, } ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) -class TestIlastikOnMultichannelInputs(conf.TestServerBaseClass): +class TestIlastikOnMultichannelInputs(TestServerTestCase): def setUp(self) -> None: super(TestIlastikOnMultichannelInputs, self).setUp() self.pa_px_classifier = ilastik_classifiers['px_color_zstack']['path'] @@ -324,6 +343,7 @@ class TestIlastikOnMultichannelInputs(conf.TestServerBaseClass): self.pa_input_image = data['multichannel_zstack_raw']['path'] self.pa_mask = data['multichannel_zstack_mask3d']['path'] + def test_classify_pixels(self): img = generate_file_accessor(self.pa_input_image) self.assertGreater(img.chroma, 1) @@ -343,56 +363,78 @@ class TestIlastikOnMultichannelInputs(conf.TestServerBaseClass): self.assertEqual(obmap.hw, img.hw) self.assertEqual(obmap.nz, img.nz) - def _call_workflow(self, channel): - return infer_px_then_ob_model( - self.pa_input_image, - ilm.IlastikPixelClassifierModel( - ilm.IlastikParams(project_file=self.pa_px_classifier.__str__()), - ), - ilm.IlastikObjectClassifierFromPixelPredictionsModel( - ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()), - ), - output_path, - channel=channel, - ) - def test_workflow(self): + """ + Test calling pixel then object map classification pipeline function directly + """ + def _call_workflow(channel): + return px_then_ob.pixel_then_object_classification_pipeline( + accessors={ + 'accessor': generate_file_accessor(self.pa_input_image) + }, + models={ + 'px_model': ilm.IlastikPixelClassifierModel( + ilm.IlastikParams(project_file=self.pa_px_classifier.__str__()), + ), + 'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel( + ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()), + ) + }, + channel=channel, + ) + with self.assertRaises(ilm.IlastikInputShapeError): - self._call_workflow(channel=0) - res = self._call_workflow(channel=None) + _call_workflow(channel=0) + res = _call_workflow(channel=None) acc_input = generate_file_accessor(self.pa_input_image) - acc_obmap = generate_file_accessor(res.object_map_filepath) + acc_obmap = res['ob_map'] self.assertEqual(acc_obmap.hw, acc_input.hw) self.assertEqual(len(acc_obmap.unique()[1]), 3) def test_api(self): - resp_load = self._put( + """ + Test calling pixel then object map classification pipeline over API + """ + copyfile( + self.pa_input_image, + Path(self._get('paths').json()['inbound_images']) / self.pa_input_image.name + ) + + in_acc_id = self._put(f'accessors/read_from_file/{self.pa_input_image.name}').json() + + resp_load_px = self._put( 'ilastik/seg/load/', body={'project_file': str(self.pa_px_classifier)}, ) - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - px_model_id = resp_load.json()['model_id'] + self.assertEqual(resp_load_px.status_code, 200, resp_load_px.json()) + px_model_id = resp_load_px.json()['model_id'] - resp_load = self._put( + resp_load_ob = self._put( 'ilastik/pxmap_to_obj/load/', body={'project_file': str(self.pa_ob_pxmap_classifier)}, ) - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - ob_model_id = resp_load.json()['model_id'] + self.assertEqual(resp_load_ob.status_code, 200, resp_load_ob.json()) + ob_model_id = resp_load_ob.json()['model_id'] + # run the pipeline resp_infer = self._put( - 'ilastik/pixel_then_object_classification/infer/', - query={ + 'ilastik/pipelines/pixel_then_object_classification/infer/', + body={ + 'accessor_id': in_acc_id, 'px_model_id': px_model_id, 'ob_model_id': ob_model_id, - 'input_filename': self.pa_input_image.__str__(), } ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) - acc_input = generate_file_accessor(self.pa_input_image) - acc_obmap = generate_file_accessor(resp_infer.json()['object_map_filepath']) - self.assertEqual(acc_obmap.hw, acc_input.hw) + + # save output object map to file and compare + obmap_id = resp_infer.json()['output_accessor_id'] + obmap_acc = self.get_accessor(obmap_id) + self.assertEqual(obmap_acc.shape_dict['C'], 1) + + # compare dimensions to input image + self.assertEqual(obmap_acc.hw, generate_file_accessor(self.pa_input_image).hw) class TestIlastikObjectClassification(unittest.TestCase): diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..cec05dbdbdb1be3c97c1ed53cccf3d57984fa34c --- /dev/null +++ b/tests/test_ilastik/test_roiset_workflow.py @@ -0,0 +1,191 @@ +from pathlib import Path +import unittest + +import numpy as np + + +from model_server.base.accessors import generate_file_accessor +from model_server.base.api import app +import model_server.conf.testing as conf +from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline +import model_server.extensions.ilastik.models as ilm +from model_server.extensions.ilastik.router import router + +app.include_router(router) + +data = conf.meta['image_files'] +output_path = conf.meta['output_path'] +test_params = conf.meta['roiset'] +classifiers = conf.meta['ilastik_classifiers'] + + +class BaseTestRoiSetMonoProducts(object): + + @property + def fpi(self): + return data['multichannel_zstack_raw']['path'].__str__() + + @property + def stack(self): + return generate_file_accessor(self.fpi) + + @property + def stack_ch_pa(self): + return self.stack.get_mono(test_params['patches_channel']) + + @property + def seg_mask(self): + return generate_file_accessor(data['multichannel_zstack_mask2d']['path']) + + def _get_export_params(self): + return { + 'patches_3d': None, + 'annotated_patches_2d': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': [3, None, None], + 'rgb_overlay_weights': [0.2, 1.0, 1.0], + 'pad_to': 512, + }, + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + }, + 'annotated_zstacks': None, + 'object_classes': True, + } + + def _get_roi_params(self): + return { + 'mask_type': 'boxes', + 'filters': { + 'area': {'min': 1e0, 'max': 1e8} + }, + 'expand_box_by': [128, 2], + 'deproject_channel': 0, + } + + def _get_models(self): # tests can either use model objects directly, or load in API via project file string + fp_px = classifiers['px']['path'].__str__() + fp_ob = classifiers['seg_to_obj']['path'].__str__() + return { + 'pixel_classifier_segmentation': { + 'name': 'ilastik_px_mod', + 'project_file': fp_px, + 'model': ilm.IlastikPixelClassifierModel( + ilm.IlastikPixelClassifierParams( + project_file=fp_px, + ) + ) + }, + 'object_classifier': { + 'name': 'ilastik_ob_mod', + 'project_file': fp_ob, + 'model': ilm.IlastikObjectClassifierFromSegmentationModel( + ilm.IlastikParams( + project_file=fp_ob + ) + ) + }, + } + + +class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): + + def _pipeline_params(self): + return { + 'api': False, + 'accessor_id': 'acc_id', + 'pixel_classifier_segmentation_model_id': 'px_id', + 'object_classifier_model_id': 'ob_id', + 'segmentation': { + 'channel': test_params['segmentation_channel'], + }, + 'patches_channel': test_params['patches_channel'], + 'roi_params': self._get_roi_params(), + 'export_params': self._get_export_params(), + } + + def test_object_map_workflow(self): + acc_in = generate_file_accessor(self.fpi) + params = RoiSetObjectMapParams( + **self._pipeline_params(), + ) + trace, rois = roiset_object_map_pipeline( + {'accessor': acc_in}, + {f'{k}_model': v['model'] for k, v in self._get_models().items()}, + **params.dict() + ) + self.assertEqual(trace.pop('annotated_patches_2d').count, 13) + self.assertEqual(trace.pop('patches_2d').count, 13) + trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False) + self.assertTrue('ob_id' in trace.keys()) + self.assertEqual(len(trace['labeled'].unique()[0]), 14) + self.assertEqual(rois.count, 13) + self.assertEqual(len(trace['ob_id'].unique()[0]), 2) + + +class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts): + + app_name = 'tests.test_ilastik.test_roiset_workflow:app' + input_data = data['multichannel_zstack_raw'] + + + def setUp(self) -> None: + self.where_out = output_path / 'roiset' + self.where_out.mkdir(parents=True, exist_ok=True) + return conf.TestServerBaseClass.setUp(self) + + def test_trivial_api_response(self): + resp = self._get('') + self.assertEqual(resp.status_code, 200) + + def test_load_input_accessor(self): + fname = self.copy_input_file_to_server() + return self._put(f'accessors/read_from_file/{fname}').json() + + def test_load_pixel_classifier(self): + resp = self._put( + 'ilastik/seg/load/', + body={'project_file': self._get_models()['pixel_classifier_segmentation']['project_file']}, + ) + model_id = resp.json()['model_id'] + self.assertTrue(model_id.startswith('IlastikPixelClassifierModel')) + return model_id + + def test_load_object_classifier(self): + resp = self._put( + 'ilastik/seg_to_obj/load/', + body={'project_file': self._get_models()['object_classifier']['project_file']}, + ) + model_id = resp.json()['model_id'] + self.assertTrue(model_id.startswith('IlastikObjectClassifierFromSegmentationModel')) + return model_id + + def _object_map_workflow(self, ob_classifer_id): + resp = self._put( + 'pipelines/roiset_to_obmap/infer', + body={ + 'accessor_id': self.test_load_input_accessor(), + 'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(), + 'object_classifier_model_id': ob_classifer_id, + 'segmentation': {'channel': 0}, + 'patches_channel': 1, + 'roi_params': self._get_roi_params(), + 'export_params': self._get_export_params(), + }, + ) + self.assertEqual(resp.status_code, 200, resp.json()) + oid = resp.json()['output_accessor_id'] + obmap_fn = self._put(f'/accessors/write_to_file/{oid}').json() + where_out = self._get('paths').json()['outbound_images'] + obmap_fp = Path(where_out) / obmap_fn + self.assertTrue(obmap_fp.exists()) + return generate_file_accessor(obmap_fp) + + def test_workflow_with_object_classifier(self): + acc = self._object_map_workflow(self.test_load_object_classifier()) + self.assertTrue(np.all(acc.unique()[0] == [0, 1, 2])) + + def test_workflow_without_object_classifier(self): + acc = self._object_map_workflow(None) + self.assertTrue(np.all(acc.unique()[0] == [0, 1]))