Skip to content
Snippets Groups Projects
Commit 7aa5e80e authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

workflow test passes with nontrivial RoiSet

parent 8ed6d1fb
No related branches found
No related tags found
No related merge requests found
from typing import Dict, Union
from fastapi import APIRouter
import pandas as pd
from pydantic import Field
from pydantic import BaseModel, Field
from ..accessors import GenericImageDataAccessor
from ..pipelines.segment_zproj import segment_zproj_pipeline
from ..pipelines.shared import call_pipeline
from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams
from ..session import session
from model_server.base.accessors import GenericImageDataAccessor, PatchStack
from model_server.base.pipelines.params import PipelineParams, PipelineRecord
from model_server.base.pipelines.segment import segment_pipeline
from model_server.base.pipelines.util import call_pipeline
from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams
from model_server.base.session import session
from model_server.base.util import PipelineTrace
from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord
from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
from ..models import Model, InstanceSegmentationModel
router = APIRouter(
prefix='/pipelines',
......@@ -21,6 +19,16 @@ router = APIRouter(
)
class RoiSetObjectMapParams(PipelineParams):
class _SegmentationParams(BaseModel):
channel: int = Field(
None,
description='Channel of input image to use for solving segmentation; use all channels if empty'
)
zi: Union[int, None] = Field(
None,
description='z coordinate to use on input image when solving segmentation; apply MIP if empty',
)
accessor_id: str = Field(
description='ID(s) of previously loaded accessor(s) to use as pipeline input'
)
......@@ -35,15 +43,12 @@ class RoiSetObjectMapParams(PipelineParams):
None,
description='Pixel classifier used to derive channel(s) as additional inputs to object classification'
)
segmentation_channel: int = Field(
description='Channel of input image to use for solving segmentation'
)
patches_channel: int = Field(
description='Channel of input image used in patches sent to object classifier'
)
zmask_zindex: Union[int, None] = Field(
None,
description='z coordinate to use on input image when solving segmentation; apply MIP if empty',
segmentation: _SegmentationParams = Field(
_SegmentationParams(),
description='Parameters used to solve segmentation'
)
roi_params: RoiSetMetaParams = RoiSetMetaParams(**{
'mask_type': 'boxes',
......@@ -105,22 +110,17 @@ def roiset_object_map_pipeline(
models: Dict[str, Model],
**k
) -> (PipelineTrace, RoiSet):
if not isinstance(models['pixel_classifier_segmentation_model'], SemanticSegmentationModel):
raise IncompatibleModelsError('Expecting a pixel classification model')
d = PipelineTrace(accessors['accessor'])
zmi = k['zmask_zindex']
sch = k['segmentation_channel']
if isinstance(zmi, int):
assert 0 < zmi < d.last.nz
d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi)
else:
d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True))
d['mask'] = segment_zproj_pipeline(
accessors,
{'model': models['pixel_classifier_segmentation_model']},
**k['segmentation'],
).last
d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last)
d['segmentation'] = get_label_ids(d.last)
rois = RoiSet.from_object_ids(d['input'], d['segmentation'], RoiSetMetaParams(**k['roi_params']))
# d['mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last)
d['labeled'] = get_label_ids(d.last)
rois = RoiSet.from_object_ids(d['input'], d['labeled'], RoiSetMetaParams(**k['roi_params']))
# optionally run an object classifier if specified
if obmod := models.get('object_classifier_model'):
......@@ -134,9 +134,3 @@ def roiset_object_map_pipeline(
d[obmod_name] = rois.get_object_class_map(obmod_name)
return d, rois
class Error(Exception):
pass
class IncompatibleModelsError(Error):
pass
\ No newline at end of file
......@@ -65,7 +65,7 @@ class BaseTestRoiSetMonoProducts(object):
return {
'mask_type': 'boxes',
'filters': {
'area': {'min': 1e3, 'max': 1e8}
'area': {'min': 1e0, 'max': 1e8}
},
'expand_box_by': [128, 2]
}
......@@ -97,7 +97,9 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
'accessor_id': 'acc_id',
'pixel_classifier_segmentation_model_id': 'px_id',
'object_classifier_model_id': 'ob_id',
'segmentation_channel': test_params['segmentation_channel'],
'segmentation': {
'channel': test_params['segmentation_channel'],
},
'patches_channel': test_params['patches_channel'],
'roi_params': self._get_roi_params(),
'export_params': self._get_export_params(),
......@@ -108,13 +110,16 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
params = RoiSetObjectMapParams(
**self._pipeline_params(),
)
trace, _ = roiset_object_map_pipeline(
trace, rois = roiset_object_map_pipeline(
{'accessor': acc_in},
{f'{k}_model': v['model'] for k, v in self._get_models().items()},
**params.dict()
)
trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False)
self.assertTrue('ob_id' in trace.keys())
self.assertEqual(len(trace['labeled']._unique()[0]), 14)
self.assertEqual(rois.count, 13)
self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment