Skip to content
Snippets Groups Projects

Updates for TREC pipelines

Merged Christopher Randolph Rhodes requested to merge int_trec into staging
26 files
+ 1193
617
Compare changes
  • Side-by-side
  • Inline
Files
26
from typing import Dict, Union
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field
from ..accessors import GenericImageDataAccessor
from .router import router
from .segment_zproj import segment_zproj_pipeline
from .shared import call_pipeline
from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams
from ..session import session
from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord
from ..models import Model, InstanceSegmentationModel
from ..models import Model, InstanceMaskSegmentationModel
class RoiSetObjectMapParams(PipelineParams):
@@ -35,10 +34,6 @@ class RoiSetObjectMapParams(PipelineParams):
None,
description='Object classifier used to classify segmented objectss'
)
pixel_classifier_derived_model_id: Union[str, None] = Field(
None,
description='Pixel classifier used to derive channel(s) as additional inputs to object classification'
)
patches_channel: int = Field(
description='Channel of input image used in patches sent to object classifier'
)
@@ -55,35 +50,19 @@ class RoiSetObjectMapParams(PipelineParams):
'deproject_channel': None,
})
export_params: RoiSetExportParams = RoiSetExportParams()
derived_channels_input_channel: Union[int, None] = Field(
None,
description='Channel of input image from which to compute derived channels; use all if empty'
)
derived_channels_output_channels: Union[int, list] = Field(
None,
description='Derived channels to send to object classifier; use all if empty'
)
export_label_interm: bool = False
class RoiSetToObjectMapRecord(PipelineRecord):
roiset_table: dict
pass
@router.put('/roiset_to_obmap/infer')
def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
"""
Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification.
"""
record, rois = call_pipeline(roiset_object_map_pipeline, p)
table = rois.get_serializable_dataframe()
session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table)
ret = RoiSetToObjectMapRecord(
roiset_table=table.to_dict(),
**record.dict()
)
return ret
return call_pipeline(roiset_object_map_pipeline, p)
def roiset_object_map_pipeline(
@@ -91,11 +70,11 @@ def roiset_object_map_pipeline(
models: Dict[str, Model],
**k
) -> (PipelineTrace, RoiSet):
d = PipelineTrace(accessors['accessor'])
d = PipelineTrace(accessors[''])
d['mask'] = segment_zproj_pipeline(
accessors,
{'model': models['pixel_classifier_segmentation_model']},
{'': models['pixel_classifier_segmentation_']},
**k['segmentation'],
).last
@@ -107,9 +86,9 @@ def roiset_object_map_pipeline(
d[ki] = vi
# optionally run an object classifier if specified
if obmod := models.get('object_classifier_model'):
if obmod := models.get('object_classifier_'):
obmod_name = k['object_classifier_model_id']
assert isinstance(obmod, InstanceSegmentationModel)
assert isinstance(obmod, InstanceMaskSegmentationModel)
rois.classify_by(
obmod_name,
[k['patches_channel']],
Loading