diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 383136755d886be4e2ff31e3f48aff96a65fc763..2b533f44be3c16104e23e6fa3c0658fae01426ec 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -1,19 +1,17 @@ from typing import Dict, Union from fastapi import APIRouter -import pandas as pd -from pydantic import Field +from pydantic import BaseModel, Field +from ..accessors import GenericImageDataAccessor +from ..pipelines.segment_zproj import segment_zproj_pipeline +from ..pipelines.shared import call_pipeline +from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams +from ..session import session -from model_server.base.accessors import GenericImageDataAccessor, PatchStack -from model_server.base.pipelines.params import PipelineParams, PipelineRecord -from model_server.base.pipelines.segment import segment_pipeline -from model_server.base.pipelines.util import call_pipeline -from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams -from model_server.base.session import session -from model_server.base.util import PipelineTrace +from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord -from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel +from ..models import Model, InstanceSegmentationModel router = APIRouter( prefix='/pipelines', @@ -21,6 +19,16 @@ router = APIRouter( ) class RoiSetObjectMapParams(PipelineParams): + class _SegmentationParams(BaseModel): + channel: int = Field( + None, + description='Channel of input image to use for solving segmentation; use all channels if empty' + ) + zi: Union[int, None] = Field( + None, + description='z coordinate to use on input image when solving segmentation; apply MIP if empty', + ) + accessor_id: str = Field( description='ID(s) of previously loaded accessor(s) to use as pipeline input' ) @@ -35,15 +43,12 @@ class RoiSetObjectMapParams(PipelineParams): None, description='Pixel classifier used to derive channel(s) as additional inputs to object classification' ) - segmentation_channel: int = Field( - description='Channel of input image to use for solving segmentation' - ) patches_channel: int = Field( description='Channel of input image used in patches sent to object classifier' ) - zmask_zindex: Union[int, None] = Field( - None, - description='z coordinate to use on input image when solving segmentation; apply MIP if empty', + segmentation: _SegmentationParams = Field( + _SegmentationParams(), + description='Parameters used to solve segmentation' ) roi_params: RoiSetMetaParams = RoiSetMetaParams(**{ 'mask_type': 'boxes', @@ -105,22 +110,17 @@ def roiset_object_map_pipeline( models: Dict[str, Model], **k ) -> (PipelineTrace, RoiSet): - if not isinstance(models['pixel_classifier_segmentation_model'], SemanticSegmentationModel): - raise IncompatibleModelsError('Expecting a pixel classification model') - d = PipelineTrace(accessors['accessor']) - zmi = k['zmask_zindex'] - sch = k['segmentation_channel'] - if isinstance(zmi, int): - assert 0 < zmi < d.last.nz - d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi) - else: - d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True)) + d['mask'] = segment_zproj_pipeline( + accessors, + {'model': models['pixel_classifier_segmentation_model']}, + **k['segmentation'], + ).last - d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last) - d['segmentation'] = get_label_ids(d.last) - rois = RoiSet.from_object_ids(d['input'], d['segmentation'], RoiSetMetaParams(**k['roi_params'])) + # d['mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last) + d['labeled'] = get_label_ids(d.last) + rois = RoiSet.from_object_ids(d['input'], d['labeled'], RoiSetMetaParams(**k['roi_params'])) # optionally run an object classifier if specified if obmod := models.get('object_classifier_model'): @@ -134,9 +134,3 @@ def roiset_object_map_pipeline( d[obmod_name] = rois.get_object_class_map(obmod_name) return d, rois - -class Error(Exception): - pass - -class IncompatibleModelsError(Error): - pass \ No newline at end of file diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py index bfe51ed0dd166eac173ffbb2319e07604b87c7f2..31b0c0d6e4a3b590c23f629120c0b981020e478b 100644 --- a/tests/test_ilastik/test_roiset_workflow.py +++ b/tests/test_ilastik/test_roiset_workflow.py @@ -65,7 +65,7 @@ class BaseTestRoiSetMonoProducts(object): return { 'mask_type': 'boxes', 'filters': { - 'area': {'min': 1e3, 'max': 1e8} + 'area': {'min': 1e0, 'max': 1e8} }, 'expand_box_by': [128, 2] } @@ -97,7 +97,9 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): 'accessor_id': 'acc_id', 'pixel_classifier_segmentation_model_id': 'px_id', 'object_classifier_model_id': 'ob_id', - 'segmentation_channel': test_params['segmentation_channel'], + 'segmentation': { + 'channel': test_params['segmentation_channel'], + }, 'patches_channel': test_params['patches_channel'], 'roi_params': self._get_roi_params(), 'export_params': self._get_export_params(), @@ -108,13 +110,16 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): params = RoiSetObjectMapParams( **self._pipeline_params(), ) - trace, _ = roiset_object_map_pipeline( + trace, rois = roiset_object_map_pipeline( {'accessor': acc_in}, {f'{k}_model': v['model'] for k, v in self._get_models().items()}, **params.dict() ) + trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False) self.assertTrue('ob_id' in trace.keys()) + self.assertEqual(len(trace['labeled']._unique()[0]), 14) + self.assertEqual(rois.count, 13) self.assertEqual(len(trace['ob_id']._unique()[0]), 2)