Skip to content
Snippets Groups Projects
Commit 976206d7 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Put simplest functional RoiSet pipeline model_server for now; derived channel...

Put simplest functional RoiSet pipeline model_server for now; derived channel and boundary seg alternatives in trec-adaptive-feedback
parent cba0f358
No related branches found
No related tags found
No related merge requests found
......@@ -7,8 +7,9 @@ from pydantic import Field
from model_server.base.accessors import GenericImageDataAccessor, PatchStack
from model_server.base.pipelines.params import PipelineParams, PipelineRecord
from model_server.base.pipelines.segment import segment_pipeline
from model_server.base.pipelines.util import call_pipeline
from model_server.base.roiset import RoiSet, RoiSetMetaParams, RoiSetExportParams
from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams
from model_server.base.session import session
from model_server.base.util import PipelineTrace
......@@ -73,8 +74,6 @@ class RoiSetObjectMapParams(PipelineParams):
None,
description='Derived channels to send to object classifier; use all if empty'
)
# TODO: move to subclassed pipeline
# label_params: Union[LabelFromBoundarySegParams, None] = None
export_label_interm: bool = False
......@@ -92,7 +91,6 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
assert isinstance(rois, RoiSet)
table = rois.get_serializable_dataframe()
# TODO: instead, explicitly let trace include RoiSets
session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table)
ret = RoiSetToObjectMapRecord(
......@@ -101,20 +99,6 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
)
return ret
# def _segment(d: PipelineTrace, models: Dict[str, Model], **k) -> PipelineTrace:
# # MIP if no zmask z-index is given, then classify pixels
# zmi = k['zmask_zindex']
# sch = k['segmentation_channel']
#
# if isinstance(zmi, int):
# assert 0 < zmi < d.last.nz
# d['mip'] = d.last.get_mono(channel=sch).get_zi(zmi)
# else:
# d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True))
#
# d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last)
# return d
def roiset_object_map_pipeline(
accessors: Dict[str, GenericImageDataAccessor],
......@@ -124,10 +108,7 @@ def roiset_object_map_pipeline(
if not isinstance(models['pixel_classifier_segmentation_model'], SemanticSegmentationModel):
raise IncompatibleModelsError('Expecting a pixel classification model')
d = PipelineTrace()
d['raw'] = accessors['accessor']
# d = _segment(d, models, **k)
d = PipelineTrace(accessors['accessor'])
zmi = k['zmask_zindex']
sch = k['segmentation_channel']
......@@ -138,28 +119,8 @@ def roiset_object_map_pipeline(
d['mip'] = d.last.get_mono(channel=sch).apply(lambda x: x.max(axis=-1, keepdims=True))
d['mip_mask'] = models['pixel_classifier_segmentation_model'].label_pixel_class(d.last)
rois = RoiSet.from_binary_mask(d['raw'], d.last, RoiSetMetaParams(**k['roi_params']))
# TODO: split out derived channel logic
# # optionally derive additional inputs for object classification
# if dpmod := models.get('pixel_classifier_derived_model'):
# ic = k['derived_channels_input_channel']
# ocs = k['derived_channels_output_channels']
# assert ic < d['raw'].chroma
# if not isinstance(dpmod, SemanticSegmentationModel):
# raise IncompatibleModelsError('Expecting pixel_classifier_derived to be a pixel classification model')
#
# def _derive(acc: GenericImageDataAccessor, oc: int):
# acc_mono = acc.get_channels([ic])
# pxmap = dpmod.infer_patch_stack(acc_mono)[0]
# assert oc < pxmap.chroma
# return PatchStack((pxmap.get_channels([oc]).data * 255).astype('uint8'))
# derived_channel_handles = [
# lambda a: _derive(a, oc) for oc in ocs
# ]
# else:
# derived_channel_handles = None
d['segmentation'] = get_label_ids(d.last)
rois = RoiSet.from_object_ids(d['input'], d['segmentation'], RoiSetMetaParams(**k['roi_params']))
# optionally run an object classifier if specified
if obmod := models.get('object_classifier_model'):
......@@ -169,16 +130,9 @@ def roiset_object_map_pipeline(
obmod_name,
[k['patches_channel']],
obmod,
# derived_channel_functions=derived_channel_handles
)
# TODO: split out derived channel logic
# # add derived channels to intermediate data
# for di, dacc in enumerate(rois.accs_derived):
# d[f'derived_{di:02d}'] = dacc
d[obmod_name] = rois.get_object_class_map(obmod_name)
return d, rois
class Error(Exception):
......
......@@ -118,27 +118,27 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
def test_object_map_workflow_with_derived_channels(self):
acc_in = generate_file_accessor(self.fpi)
params = RoiSetObjectMapParams(
**self._pipeline_params(),
pixel_classifier_derived_model_id='id_px',
derived_channels_input_channel=0,
derived_channels_output_channels=[1],
)
models = self._get_models()
models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier
trace, _ = roiset_object_map_pipeline(
{'accessor': acc_in},
{f'{k}_model': v['model'] for k, v in models.items()},
**params.dict()
)
self.assertTrue('ob_id' in trace.keys())
self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
self.assertTrue('derived_00' in trace.keys())
self.assertEqual(trace['derived_00'].chroma, 1)
# def test_object_map_workflow_with_derived_channels(self):
# acc_in = generate_file_accessor(self.fpi)
# params = RoiSetObjectMapParams(
# **self._pipeline_params(),
# pixel_classifier_derived_model_id='id_px',
# derived_channels_input_channel=0,
# derived_channels_output_channels=[1],
# )
# models = self._get_models()
# models['pixel_classifier_derived'] = models['pixel_classifier_segmentation'] # re-use same classifier
#
# trace, _ = roiset_object_map_pipeline(
# {'accessor': acc_in},
# {f'{k}_model': v['model'] for k, v in models.items()},
# **params.dict()
# )
#
# self.assertTrue('ob_id' in trace.keys())
# self.assertEqual(len(trace['ob_id']._unique()[0]), 2)
# self.assertTrue('derived_00' in trace.keys())
# self.assertEqual(trace['derived_00'].chroma, 1)
class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment