From a295a72dba3fc52d4834af56200103638ebc264e Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 13 Aug 2024 10:47:15 +0200 Subject: [PATCH] First test passes on direct pipeline call with accessors --- model_server/base/pipelines/params.py | 5 +- model_server/base/pipelines/roiset_obmap.py | 179 ++++++++++ model_server/base/session.py | 1 + .../ilastik/pipelines/roiset_obmap.py | 266 --------------- tests/base/test_pipelines.py | 2 +- tests/test_ilastik/test_roiset_workflow.py | 317 ++++++++++++++++++ 6 files changed, 501 insertions(+), 269 deletions(-) create mode 100644 model_server/base/pipelines/roiset_obmap.py delete mode 100644 model_server/extensions/ilastik/pipelines/roiset_obmap.py create mode 100644 tests/test_ilastik/test_roiset_workflow.py diff --git a/model_server/base/pipelines/params.py b/model_server/base/pipelines/params.py index ad9585a5..b81ccf41 100644 --- a/model_server/base/pipelines/params.py +++ b/model_server/base/pipelines/params.py @@ -7,18 +7,19 @@ from ..session import session, AccessorIdError class PipelineParams(BaseModel): keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session') + api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True') @root_validator(pre=False) def models_are_loaded(cls, dd): for k, v in dd.items(): - if k.endswith('model_id') and v not in session.describe_loaded_models().keys(): + if dd['api'] and k.endswith('model_id') and v not in session.describe_loaded_models().keys(): raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded') return dd @root_validator(pre=False) def accessors_are_loaded(cls, dd): for k, v in dd.items(): - if k.endswith('accessor_id'): + if dd['api'] and k.endswith('accessor_id'): try: info = session.get_accessor_info(v) except AccessorIdError as e: diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py new file mode 100644 index 00000000..a25d740d --- /dev/null +++ b/model_server/base/pipelines/roiset_obmap.py @@ -0,0 +1,179 @@ +from typing import Dict, Union + +from fastapi import APIRouter +import pandas as pd +from pydantic import Field + + +from model_server.base.accessors import GenericImageDataAccessor, PatchStack +from model_server.base.pipelines.params import PipelineParams, PipelineRecord +from model_server.base.pipelines.util import call_pipeline +from model_server.base.roiset import RoiSet, RoiSetMetaParams, RoiSetExportParams +from model_server.base.session import session +from model_server.base.util import PipelineTrace + +from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel + +router = APIRouter( + prefix='/pipelines', +) + +class RoiSetObjectMapParams(PipelineParams): + accessor_id: str = Field( + description='ID(s) of previously loaded accessor(s) to use as pipeline input' + ) + pixel_classifier_segmentation_model_id: str = Field( + description='Pixel classifier applied to segmentation_channel(s) to segment objects' + ) + object_classifier_model_id: Union[str, None] = Field( + None, + description='Object classifier used to classify segmented objectss' + ) + pixel_classifier_derived_model_id: Union[str, None] = Field( + None, + description='Pixel classifier used to derive channel(s) as additional inputs to object classification' + ) + segmentation_channel: int = Field( + description='Channel of input image to use for solving segmentation' + ) + 
patches_channel: int = Field( + description='Channel of input image used in patches sent to object classifier' + ) + zmask_zindex: Union[int, None] = Field( + None, + description='z coordinate to use on input image when solving segmentation; apply MIP if empty', + ) + roi_params: RoiSetMetaParams = RoiSetMetaParams(**{ + 'mask_type': 'boxes', + 'filters': { + 'area': {'min': 1e3, 'max': 1e8} + }, + 'expand_box_by': [128, 2] + }) + # TODO: maybe don't support all these exports here; instead leverage interm accessors + export_params: RoiSetExportParams = RoiSetExportParams(**{ + 'annotated_patches_2d': { + 'draw_bounding_box': True, + 'pad_to': 256, + }, + 'patches_2d': { + 'draw_bounding_box': True, + 'draw_mask': False, + }, + 'patch_masks': {}, + 'object_classes': True, + 'dataframe': True, + }) + derived_channels_input_channel: Union[int, None] = Field( + None, + description='Channel of input image from which to compute derived channels; use all if empty' + ) + derived_channels_output_channels: Union[int, list] = Field( + None, + description='Derived channels to send to object classifier; use all if empty' + ) + # TODO: move to subclassed pipeline + # label_params: Union[LabelFromBoundarySegParams, None] = None + export_label_interm: bool = False + + +class RoiSetToObjectMapRecord(PipelineRecord): + pass + +@router.put('/roiset_to_obmap/infer') +def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: + """ + Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification. + """ + + + record, dataframe = call_pipeline(roiset_object_map_pipeline, p) + + # TODO: try labeling this correctly as input accessor_id instead of filename + session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, dataframe) + return record + + +def roiset_object_map_pipeline( + accessors: Dict[str, GenericImageDataAccessor], + models: Dict[str, Model], + **k +) -> (RoiSetToObjectMapRecord, pd.DataFrame): + if not isinstance(models['pixel_classifier_segmentation_model'], SemanticSegmentationModel): + raise IncompatibleModelsError('Expecting a pixel classification model') + + d = PipelineTrace() + d['raw'] = accessors['accessor'] + + # MIP if no zmask z-index is given, then classify pixels + zmi = k['zmask_zindex'] + sch = k['segmentation_channel'] + + if isinstance(zmi, int): + assert 0 < zmi < d.last.nz + d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi) + else: + d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True)) + + d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last) + + # TODO: subclass this boundary pipeline in separate module + # # optionally label objects using a boundary segmentation pipeline + # if p.label_params: + # la_nda, interm = label_from_boundary_seg( + # mip_mask.data[:, :, 0, 0], + # p.label_params, + # return_interm=True + # ) + # labels = InMemoryDataAccessor(la_nda) + # rois = RoiSet(stack, labels, params=p.roi_params) + # else: + # rois = RoiSet.from_segmentation(stack, mip_mask, params=p.roi_params) + # ti.click('generate_zmasks') + + rois = RoiSet.from_binary_mask(d['raw'], d.last, RoiSetMetaParams(**k['roi_params'])) + + # optionally derive additional inputs for object classification + if dpmod := models.get('pixel_classifier_derived'): + ic = k['derived_channels_input_channel'] + ocs = k['derived_channels_output_channels'] + assert ic < d['raw'].chroma + if not isinstance(dpmod, SemanticSegmentationModel): + raise IncompatibleModelsError('Expecting 
pixel_classifier_derived to be a pixel classification model') + + def _derive(acc: GenericImageDataAccessor, oc: int): + acc_mono = acc.get_channels([ic]) + pxmap = dpmod.infer_patch_stack(acc_mono)[0] + assert oc < pxmap.chroma + return PatchStack((pxmap.get_channels([oc]).data * 255).astype('uint8')) + derived_channel_handles = [ + lambda a: _derive(a, oc) for oc in ocs + ] + else: + derived_channel_handles = None + + # optionally classify if an object classifier is passed + if obmod := models.get('object_classifier_model'): + obmod_name = k['object_classifier_model_id'] + assert isinstance(obmod, InstanceSegmentationModel) + rois.classify_by( + obmod_name, + [k['patches_channel']], + obmod, + derived_channel_functions=derived_channel_handles + ) + d[obmod_name] = rois.get_object_class_map(obmod_name) + + # TODO: subclass this boundary pipeline in separate module + # if p.label_params and p.export_label_interm: + # fp_la_interm = Path(output_folder_path) / 'mip_masks' / f'interm_{fstem}.png' + # plot_image_sequence_with_markers(interm, fp_la_interm, n_rows=3) + # record['label_interm'] = fp_la_interm + + return d + +class Error(Exception): + pass + +class IncompatibleModelsError(Error): + pass \ No newline at end of file diff --git a/model_server/base/session.py b/model_server/base/session.py index ea1b3bf9..44e5855d 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -235,6 +235,7 @@ class _Session(object): if key is None: def mid(i): + # TODO: give model the option to report its own name return f'{ModelClass.__name__}_{i:02d}' while mid(ii) in self.models.keys(): diff --git a/model_server/extensions/ilastik/pipelines/roiset_obmap.py b/model_server/extensions/ilastik/pipelines/roiset_obmap.py deleted file mode 100644 index fec4f748..00000000 --- a/model_server/extensions/ilastik/pipelines/roiset_obmap.py +++ /dev/null @@ -1,266 +0,0 @@ -from pathlib import Path -from typing import Dict, List, Union - -from fastapi import APIRouter, HTTPException -import numpy as np -import pandas as pd -from pydantic import Field - - -from skimage.measure import label -from skimage.morphology import dilation -from sklearn.model_selection import train_test_split - -# from extensions.chaeo.plotting import plot_image_sequence_with_markers -# from extensions.chaeo.process import label_from_boundary_seg, LabelFromBoundarySegParams - -from model_server.base.accessors import GenericImageDataAccessor, PatchStack -from model_server.base.models import Model, SemanticSegmentationModel -from model_server.base.pipelines.params import PipelineParams, PipelineRecord -from model_server.base.pipelines.util import call_pipeline -from model_server.base.roiset import RoiSetMetaParams, RoiSetExportParams -from model_server.base.process import mask_largest_object -from model_server.base.roiset import RoiSet -from model_server.base.session import session -from model_server.base.util import PipelineTrace - -from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file -from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel -# from model_server.base.workflows import Timer - -router = APIRouter( - prefix='/pipelines', -) - -class RoiSetObjectMapParams(PipelineParams): - accessor_id: str = Field( - description='ID(s) of previously loaded accessor(s) to use as pipeline input' - ) - pixel_classifier_segmentation_model_id: str = Field( - description='Pixel classifier applied to segmentation_channel(s) to segment 
objects' - ) - object_classifier_model_id: Union[str, None] = Field( - None, - description='Object classifier used to classify segmented objectss' - ) - pixel_classifier_derived_model_id: Union[str, None] = Field( - None, - description='Pixel classifier used to derive channel(s) as additional inputs to object classification' - ) - segmentation_channel: int = Field( - description='Channel of input image to use for solving segmentation' - ) - patches_channel: int = Field( - description='Channel of input image used in patches sent to object classifier' - ) - zmask_zindex: Union[int, None] = Field( - None, - description='z coordinate to use on input image when solving segmentation; apply MIP if empty', - ) - roi_params: RoiSetMetaParams = RoiSetMetaParams(**{ - 'mask_type': 'boxes', - 'filters': { - 'area': {'min': 1e3, 'max': 1e8} - }, - 'expand_box_by': [128, 2] - }) - # TODO: maybe don't support all these exports here; instead leverage interm accessors - export_params: RoiSetExportParams = RoiSetExportParams(**{ - 'annotated_patches_2d': { - 'draw_bounding_box': True, - 'pad_to': 256, - }, - 'patches_2d': { - 'draw_bounding_box': True, - 'draw_mask': False, - }, - 'patch_masks': {}, - 'object_classes': True, - 'dataframe': True, - }) - derived_channels_input_channel: Union[int, None] = Field( - None, - description='Channel of input image from which to compute derived channels; use all if empty' - ) - derived_channels_output_channels: Union[int, list] = Field( - None, - description='Derived channels to send to object classifier; use all if empty' - ) - # TODO: move to subclassed pipeline - # label_params: Union[LabelFromBoundarySegParams, None] = None - export_label_interm: bool = False - - -class RoiSetToObjectMapRecord(PipelineRecord): - pass -# class _ModelIds(BaseModel): -# pixel_classifier_segmentation: str -# object_classifier: Union[str, None] -# pixel_classifier_derived: Union[str, None] -# input_filename: str -# model_ids: _ModelIds - -@router.put('/roiset_to_obmap/infer') -def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: - """ - Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification. 
- """ - # inpath = session.paths['inbound_images'] / p.input_filename - # validate_workflow_inputs([p.model_ids.pixel_classifier_segmentation, p.model_ids.object_classifier], [inpath]) - - # def _get_model_dict(mid): - # return { - # 'name': mid, - # 'model': session.models[mid]['object'], - # 'params': session.models[mid]['params'], - # } - - - record, dataframe = call_pipeline(roiset_object_map_pipeline, p) - - # TODO: try labeling this correctly as input accessor_id instead of filename - session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, dataframe) - return record - - - # models = {'pixel_classifier_segmentation': _get_model_dict(p.model_ids.pixel_classifier_segmentation)} - # if p.model_ids.object_classifier is not None: - # models['object_classifier'] = _get_model_dict(p.model_ids.object_classifier) - # if p.model_ids.pixel_classifier_derived is not None: - # models['pixel_classifier_derived'] = _get_model_dict(p.model_ids.pixel_classifier_derived) - - # record = export_zstack_roiset(inpath, session.paths['outbound_images'], models, p) - # object_map_filepaths = [record['interm'][k] for k in record['interm'].keys() if k.startswith('object_classes_')] - - # df = record['dataframe'] - # session.write_to_table('RoiSet', {'input_filename': p.input_filename}, df) - # session.log_info(f'Completed classification of {p.input_filename}, recorded {len(df)} ROIs') - # - # resp = WorkflowRunRecord( - # pixel_model_id=p.model_ids.pixel_classifier_segmentation, - # object_model_id=p.model_ids.object_classifier, - # input_filepath=p.input_filename, - # pixel_map_filepath=record['interm']['mask'].__str__(), - # object_map_filepath=object_map_filepaths[0].__str__(), - # success=True, - # timer_results=record['timer_results'], - # ).dict() - # resp['interm'] = record['interm'] - # return resp - - - - -def roiset_object_map_pipeline( - accessors: Dict[str, GenericImageDataAccessor], - models: Dict[str, Model], - **k -) -> (RoiSetToObjectMapRecord, pd.DataFrame): - if not isinstance(models['pixel_classifier_segmentation'], SemanticSegmentationModel): - raise IncompatibleModelsError('Expecting a pixel classification model') - # if not isinstance(models['ob_model'], IlastikObjectClassifierFromPixelPredictionsModel): - # raise IncompatibleModelsError('Expecting an ilastik object classification from pixel predictions model') - # assert isinstance(models['pixel_classifier_segmentation']['model'], SemanticSegmentationModel) - - # ti = Timer() - # stack = generate_file_accessor(input_file_path) - # fstem = Path(input_file_path).stem - # ti.click('file_input') - - d = PipelineTrace() - d['raw'] = accessors['accessor'] - - # MIP if no zmask z-index is given, then classify pixels - zmi = k.get('zmask_zindex') - sch = k.get('segmentation_channel') - - if isinstance(zmi, int): - assert 0 < zmi < d.last.nz - d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi) - else: - d['mip'] = d.last.get_mono(channel=sch).data.max(axis=-1, keepdims=True) - - d['mip_mask'] = models['pixel_classifier_segmentation'].label_pixel_class(d.last) - - # TODO: subclass this boundary pipeline in separate module - # # optionally label objects using a boundary segmentation pipeline - # if p.label_params: - # la_nda, interm = label_from_boundary_seg( - # mip_mask.data[:, :, 0, 0], - # p.label_params, - # return_interm=True - # ) - # labels = InMemoryDataAccessor(la_nda) - # rois = RoiSet(stack, labels, params=p.roi_params) - # else: - # rois = RoiSet.from_segmentation(stack, mip_mask, params=p.roi_params) - # 
ti.click('generate_zmasks') - - rois = RoiSet.from_segmentation(d['raw'], d.last, params=k.get('roi_params')) - - # optionally derive additional inputs for object classification - if dmod := models.get('pixel_classifier_derived'): - ic = k.get('derived_channels_input_channel') - ocs = k.get('derived_channels_output_channels') - assert ic < d['raw'].chroma - if not isinstance(dmod, SemanticSegmentationModel): - raise IncompatibleModelsError('Expecting pixel_classifier_derived to be a pixel classification model') - - def _derive(acc: GenericImageDataAccessor, oc: int): - acc_mono = acc.get_channels([ic]) - pxmap = dmod.infer_patch_stack(acc_mono)[0] - assert oc < pxmap.chroma - return PatchStack((pxmap.get_channels([oc]).data * 255).astype('uint8')) - derived_channel_handles = [ - lambda a: _derive(a, oc) for oc in ocs - ] - else: - derived_channel_handles = None - - # optionally classify if an object classifier is passed - if 'object_classifier' in models.keys(): - assert isinstance(models['object_classifier']['model'], InstanceSegmentationModel) - rois.classify_by( - models['object_classifier']['name'], # TODO: does model need a name? - [p.patches_channel], - models['object_classifier']['model'], - derived_channel_functions=derived_channel_handles - ) - ti.click('classify_objects') - - record = rois.run_exports( - Path(output_folder_path), - p.patches_channel, - fstem, - p.export_params, - ) - ti.click('export_roi_products') - - if p.label_params and p.export_label_interm: - fp_la_interm = Path(output_folder_path) / 'mip_masks' / f'interm_{fstem}.png' - plot_image_sequence_with_markers(interm, fp_la_interm, n_rows=3) - record['label_interm'] = fp_la_interm - - fp_mask = Path(output_folder_path) / 'mip_masks' / f'mask_{fstem}.tif' - write_accessor_data_to_file(fp_mask, mip_mask) - - unit_mask_data = (mip_mask.data > 0).astype('uint8') - fp_unit_mask = Path(output_folder_path) / 'unit_masks' / f'mask_{fstem}.tif' - write_accessor_data_to_file(fp_unit_mask, InMemoryDataAccessor(unit_mask_data)) - record['object_classes_all'] = fp_unit_mask.__str__() - record['mask'] = fp_mask - - # return { - # 'timer_results': ti.events, - # 'dataframe': rois.get_df(), - # 'interm': record, - # 'output_path': output_folder_path, - # } - - return record, df - -class Error(Exception): - pass - -class IncompatibleModelsError(Error): - pass \ No newline at end of file diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py index bbcd2305..84563127 100644 --- a/tests/base/test_pipelines.py +++ b/tests/base/test_pipelines.py @@ -17,7 +17,7 @@ class TestSegmentationPipeline(unittest.TestCase): def test_call_pipeline_function(self): acc = generate_file_accessor(czifile['path']) - trace = segment.segment_pipeline(acc, self.model, channel=2, smooth=3) + trace = segment.segment_pipeline({'accessor': acc}, {'model': self.model}, channel=2, smooth=3) outfp = output_path / 'classify_pixels.tif' write_accessor_data_to_file(outfp, trace.last) diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py new file mode 100644 index 00000000..62fb87f5 --- /dev/null +++ b/tests/test_ilastik/test_roiset_workflow.py @@ -0,0 +1,317 @@ +import shutil +from pathlib import Path +from shutil import copyfile +import unittest + +import numpy as np + +import model_server.conf.testing as conf + +from model_server.base.accessors import generate_file_accessor +from tests.base.test_model import DummyInstanceSegmentationModel +from model_server.base.roiset import RoiSetMetaParams, 
RoiSetExportParams + +import model_server.conf.testing as conf +# from extensions.chaeo.workflows import ClassifyZStackApiParams, export_zstack_roiset +from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline +from model_server.extensions.ilastik.models import IlastikPixelClassifierModel, IlastikPixelClassifierParams + +data = conf.meta['image_files'] +output_path = conf.meta['output_path'] +test_params = conf.meta['roiset'] +classifiers = conf.meta['ilastik_classifiers'] + +class BaseTestRoiSetMonoProducts(object): + + def setUp(self) -> None: + # set up test raw data and segmentation from file + self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path']) + self.stack_ch_pa = self.stack.get_mono(test_params['patches_channel']) + self.seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path']) + + +def _get_export_params(): + return RoiSetExportParams(**{ + 'pixel_probabilities': True, + 'patches_3d': {}, + 'annotated_patches_2d': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': [3, None, None], + 'rgb_overlay_weights': [0.2, 1.0, 1.0], + 'pad_to': 512, + }, + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + }, + 'patch_masks': { + 'pad_to': 256, + }, + 'annotated_zstacks': {}, + 'object_classes': True, + 'dataframe': True, + }) + +def _get_roi_params(): + return RoiSetMetaParams(**{ + 'mask_type': 'boxes', + 'filters': { + 'area': {'min': 1e3, 'max': 1e8} + }, + 'expand_box_by': [128, 2] + }) + +# def _get_model_params(): +# return { +# 'pixel_classifier_segmentation': { +# # 'project_file': params['pixel_classifier'], +# 'project_file': classifiers['px']['path'], +# }, +# 'object_classifier': { +# 'name': 'dummy', +# } +# } + +def _get_models(): + + return { + 'pixel_classifier_segmentation': { + 'name': 'ilastik_px_mod', + 'project_file': classifiers['px']['path'], + 'model': IlastikPixelClassifierModel( + IlastikPixelClassifierParams( + project_file=classifiers['px']['path'].__str__() + ) + ) + }, + 'object_classifier': { + 'name': 'dummy_ob_mod', + 'model': DummyInstanceSegmentationModel() + }, + + } + + +class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): + + def test_object_map_workflow(self): + acc_in = generate_file_accessor(data['multichannel_zstack_raw']['path']) + params = RoiSetObjectMapParams( + api=False, + accessor_id='acc_id', + pixel_classifier_segmentation_model_id='px_id', + object_classifier_model_id='ob_id', + segmentation_channel=test_params['segmentation_channel'], + patches_channel=test_params['patches_channel'], + roi_params=_get_roi_params(), + export_params=_get_export_params(), + ) + trace = roiset_object_map_pipeline( + {'accessor': acc_in}, + {f'{k}_model': v['model'] for k, v in _get_models().items()}, + **params.dict() + ) + + self.assertTrue('ob_id' in trace.keys()) + self.assertEqual(len(trace['ob_id']._unique()[0]), 2) + + +class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass): + + app_name = 'api:app' + + def _copy_input_file_to_server(self): + resp = self._get('paths') + pa = resp.json()['inbound_images'] + copyfile( + data['multichannel_zstack_raw']['path'], + Path(pa) / data['multichannel_zstack_raw']['name'] + ) + + def setUp(self) -> None: + self.where_out = output_path / 'trec-adaptive-feedback' / 'roiset' + self.where_out.mkdir(parents=True, exist_ok=True) + return super().setUp() + + def test_bounceback(self): + resp = self._put('chaeo/bounce_back', query={'par1': 'hello'}) + self.assertEqual(resp.status_code, 
200, resp.json()) + self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json()) + + def test_load_pixel_classifier(self): + resp = self._put( + 'ilastik/seg/load/', + body={'project_file': _get_model_params()['pixel_classifier_segmentation']['project_file'].__str__()}, + ) + model_id = resp.json()['model_id'] + self.assertTrue(model_id.startswith('IlastikPixelClassifierModel')) + return model_id + + def test_load_object_classifier(self): + resp = self._put(f'models/dummy_instance/load') + model_id = resp.json()['model_id'] + self.assertTrue(model_id.startswith('DummyInstanceSegmentationModel')) + return model_id + + def test_object_map_workflow(self): + mid_px = self.test_load_pixel_classifier() + mid_ob = self.test_load_object_classifier() + resp = self._put( + 'chaeo/classify_zstack/infer', + body=ClassifyZStackApiParams( + model_ids={ + 'pixel_classifier_segmentation': mid_px, + 'object_classifier': mid_ob, + }, + **{ + 'input_filename': data['multichannel_zstack_raw']['path'].__str__(), + 'segmentation_channel': 0, + 'patches_channel': 1, + 'roi_params': _get_roi_params(), + 'export_params': {'object_classes': True}, + }, + ).dict() + ) + self.assertEqual(resp.status_code, 200, resp.json()) + omfp = Path(resp.json()['object_map_filepath']) + self.assertTrue(omfp.exists()) + acc_obmap = generate_file_accessor(omfp) + self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1])) + shutil.copy(omfp, self.where_out / f'normal_{omfp.name}') + + def test_workflow_without_object_classifier(self): + mid_px = self.test_load_pixel_classifier() + resp = self._put( + 'chaeo/classify_zstack/infer', + body=ClassifyZStackApiParams( + model_ids={ + 'pixel_classifier_segmentation': mid_px, + }, + **{ + 'input_filename': data['multichannel_zstack_raw']['path'].__str__(), + 'segmentation_channel': 0, + 'patches_channel': 1, + 'roi_params': _get_roi_params(), + 'export_params': {'object_classes': True}, + }, + ).dict() + ) + self.assertEqual(resp.status_code, 200, resp.json()) + omfp = Path(resp.json()['object_map_filepath']) + self.assertTrue(omfp.exists()) + acc_obmap = generate_file_accessor(omfp) + self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1])) + shutil.copy(omfp, self.where_out / f'normal_{omfp.name}') + + def test_object_map_workflow_boundary_channel(self): + fp_in = data['multichannel_zstack_raw']['path'] + resp = self._put( + 'ilastik/seg/load/', + body={ + 'project_file': classifiers['px']['path'].__str__(), + 'px_class': 1, + 'px_prob_threshold': 0.5 + }, + ) + mid_px = resp.json()['model_id'] + mid_ob = self.test_load_object_classifier() + resp = self._put( + 'chaeo/classify_zstack/infer', + body=ClassifyZStackApiParams( + model_ids={ + 'pixel_classifier_segmentation': mid_px, + 'object_classifier': mid_ob, + }, + **{ + 'input_filename': fp_in.__str__(), + 'segmentation_channel': 0, + 'patches_channel': 1, + 'roi_params': _get_roi_params(), + 'export_params': {'object_classes': True}, + 'label_params': { + 'bbox_filter': {'area': {'min': 1e3, 'max': 1e8}}, + 'px_expand': 0, + 'n_dilate': 2, + 'marker_source': 'distmax', + 'min_marker_dist': 10, + 'background_method': 'threshold', + 'coarse_sig': 15, + 'background_tr': 0.2, + }, + 'export_label_interm': False, + }, + ).dict() + ) + self.assertEqual(resp.status_code, 200, resp.json()) + fp_obmap = Path(resp.json()['object_map_filepath']) + print(fp_obmap) + self.assertTrue(fp_obmap.exists()) + acc_obmap = generate_file_accessor(fp_obmap) + self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1])) + + + +class 
TestRoiSetWithDerivedChannels(conf.TestServerBaseClass): + + app_name = 'api:app' + + def test_object_map_workflow_with_derived_channels(self): + models = _get_models() + models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier + where_out = output_path / 'roiset' / 'workflow' + + p = ClassifyZStackApiParams( + input_filename=str(data['multichannel_zstack_raw']['path']), + model_ids={ + 'pixel_classifier_segmentation': 'id_px', + 'object_classifier': 'id_ob', + 'pixel_classifier_derived': 'id_px', + }, + derived_channels_input_channel=0, + derived_channels_output_channels=[1], + segmentation_channel=test_params['segmentation_channel'], + patches_channel=test_params['patches_channel'], + roi_params=_get_roi_params(), + export_params=_get_export_params(), + ) + + record = export_zstack_roiset( + data['multichannel_zstack_raw']['path'], + where_out, + models, + p + ) + self.assertTrue(all([Path(pa).exists() for pa in record['interm']['derived_channels']])) + + def test_derived_channels_api(self): + resp = self._put( + 'ilastik/seg/load/', + body={'project_file': _get_model_params()['pixel_classifier_segmentation']['project_file'].__str__()}, + ) + self.assertEqual(resp.status_code, 200) + mid_px = resp.json()['model_id'] + resp = self._put(f'models/dummy_instance/load') + mid_ob = resp.json()['model_id'] + + + p = ClassifyZStackApiParams( + input_filename=data['multichannel_zstack_raw']['path'].__str__(), + model_ids={ + 'pixel_classifier_segmentation': mid_px, + 'object_classifier': mid_ob, + 'pixel_classifier_derived': mid_px, + }, + derived_channels_input_channel=0, + derived_channels_output_channels=[1], + segmentation_channel=test_params['segmentation_channel'], + patches_channel=test_params['patches_channel'], + roi_params=_get_roi_params(), + export_params=_get_export_params(), + ) + resp = self._put('chaeo/classify_zstack/infer', body=p.dict()) + self.assertEqual(resp.status_code, 200, resp.json()) + omfp = resp.json()['object_map_filepath'] + self.assertTrue(Path(omfp).exists()) + acc_obmap = generate_file_accessor(omfp) + self.assertTrue(all(np.unique(acc_obmap.data) == [0, 1])) + -- GitLab
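
Note for reviewers: a minimal sketch of exercising the new direct-call path introduced by this patch, i.e. calling roiset_object_map_pipeline() without going through the HTTP layer, with api=False so PipelineParams skips session-side validation of model and accessor IDs. This mirrors TestRoiSetWorkflow.test_object_map_workflow above; the file paths, the ilastik project file, and the channel indices are placeholders, not values shipped with the patch.

    # Sketch only: paths, channels, and the .ilp project file below are assumptions.
    from model_server.base.accessors import generate_file_accessor
    from model_server.base.pipelines.roiset_obmap import (
        RoiSetObjectMapParams, roiset_object_map_pipeline,
    )
    from model_server.extensions.ilastik.models import (
        IlastikPixelClassifierModel, IlastikPixelClassifierParams,
    )
    from tests.base.test_model import DummyInstanceSegmentationModel

    # Input z-stack read directly from disk rather than from a session accessor.
    acc_in = generate_file_accessor('path/to/multichannel_zstack.tif')  # placeholder path

    params = RoiSetObjectMapParams(
        api=False,                                    # bypass session/model ID validation
        accessor_id='acc_id',                         # bookkeeping only in direct calls
        pixel_classifier_segmentation_model_id='px_id',
        object_classifier_model_id='ob_id',
        segmentation_channel=0,                       # placeholder channel indices
        patches_channel=1,
    )

    # Keys follow the '<model_id field minus "_id">' convention used by call_pipeline,
    # i.e. '<name>_model', matching what the pipeline function looks up.
    models = {
        'pixel_classifier_segmentation_model': IlastikPixelClassifierModel(
            IlastikPixelClassifierParams(project_file='path/to/classifier.ilp')  # placeholder
        ),
        'object_classifier_model': DummyInstanceSegmentationModel(),
    }

    # Returns the PipelineTrace; the object-class map is stored under the
    # object classifier's model_id ('ob_id' here).
    trace = roiset_object_map_pipeline({'accessor': acc_in}, models, **params.dict())
    print(list(trace.keys()))
    assert 'ob_id' in trace.keys()

The same keyword arguments drive the HTTP route (PUT /pipelines/roiset_to_obmap/infer); there the session resolves accessor_id and the *_model_id fields into the accessors/models dicts before calling the same function.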