Skip to content
Snippets Groups Projects
Commit 037231bc authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Pipeline workflow detects optional RoiSet return and maps to session roiset_id

parent 9daeeb97
No related branches found
No related tags found
No related merge requests found
......@@ -56,23 +56,23 @@ class RoiSetObjectMapParams(PipelineParams):
class RoiSetToObjectMapRecord(PipelineRecord):
roiset_table: dict
pass
@router.put('/roiset_to_obmap/infer')
def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
"""
Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification.
"""
record, rois = call_pipeline(roiset_object_map_pipeline, p)
table = rois.get_serializable_dataframe()
session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table)
ret = RoiSetToObjectMapRecord(
roiset_table=table.to_dict(),
**record.dict()
)
return ret
return call_pipeline(roiset_object_map_pipeline, p)
# table = rois.get_serializable_dataframe()
#
# session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table)
# ret = RoiSetToObjectMapRecord(
# roiset_table=table.to_dict(),
# **record.dict()
# )
# return ret
def roiset_object_map_pipeline(
......
......@@ -7,6 +7,7 @@ from fastapi import HTTPException
from pydantic import BaseModel, Field, root_validator
from ..accessors import GenericImageDataAccessor
from ..roiset import RoiSet
from ..session import session, AccessorIdError
......@@ -40,6 +41,7 @@ class PipelineRecord(BaseModel):
interm_accessor_ids: Union[List[str], None]
success: bool
timer: dict
roiset_id: Union[str, None] = None
def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
......@@ -67,11 +69,11 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
**p.dict(),
)
if isinstance(ret, PipelineTrace):
steps = ret
misc = None
elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace):
steps = ret[0]
misc = ret[1:]
trace = ret
roiset_id = None
elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace) and isinstance(ret[1], RoiSet):
trace = ret[0]
roiset_id = session.add_roiset(ret[1])
else:
raise UnexpectedPipelineReturnError(
f'{func.__name__} returned unexpected value of {type(ret)}'
......@@ -81,7 +83,7 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
# map intermediate data accessors to accessor IDs
if p.keep_interm:
interm_ids = []
acc_interm = steps.accessors(skip_first=True, skip_last=True).items()
acc_interm = trace.accessors(skip_first=True, skip_last=True).items()
for i, item in enumerate(acc_interm):
stk, acc = item
interm_ids.append(
......@@ -95,7 +97,7 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
# map final result to an accessor ID
result_id = session.add_accessor(
steps.last,
trace.last,
accessor_id=f'{p.accessor_id}_{func.__name__}_result'
)
......@@ -103,14 +105,11 @@ def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
output_accessor_id=result_id,
interm_accessor_ids=interm_ids,
success=True,
timer=steps.times
timer=trace.times,
roiset_id=roiset_id,
)
# return miscellaneous objects if pipeline returns these
if misc:
return record, *misc
else:
return record
return record
class PipelineTrace(OrderedDict):
......
......@@ -95,7 +95,7 @@ class _Session(object):
raise AccessorIdError(f'Access with ID {accessor_id} already exists')
if accessor_id is None:
idx = len(self.accessors)
accessor_id = f'auto_{idx:06d}'
accessor_id = f'acc_{idx:06d}'
self.accessors[accessor_id] = {'loaded': True, 'object': acc, **acc.info}
return accessor_id
......@@ -200,7 +200,7 @@ class _Session(object):
raise AccessorIdError(f'RoiSet with ID {roiset_id} already exists')
if roiset_id is None:
idx = len(self.rois)
roiset_id = f'auto_{idx:06d}'
roiset_id = f'roiset_{idx:06d}'
self.rois[roiset_id] = {'loaded': True, 'object': roiset, **roiset.info}
return roiset_id
......
......@@ -159,7 +159,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(len(self.assertGetSuccess(f'accessors/delete/*')), 1)
self.assertEqual(sum([v['loaded'] for v in self.assertGetSuccess('accessors').values()]), 0)
def test_empty_accessor_list(self):
r_list = self.assertGetSuccess(
f'accessors',
......
......@@ -159,7 +159,7 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
return mid
def _object_map_workflow(self, ob_classifer_id):
oid = self.assertPutSuccess(
res = self.assertPutSuccess(
'pipelines/roiset_to_obmap/infer',
body={
'accessor_id': self.test_load_input_accessor(),
......@@ -170,7 +170,8 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
'roi_params': self._get_roi_params(),
'export_params': self._get_export_params(),
},
)['output_accessor_id']
)
oid = res['output_accessor_id']
obmap_fn = self.assertPutSuccess(f'/accessors/write_to_file/{oid}')
where_out = self.assertGetSuccess('paths')['outbound_images']
obmap_fp = Path(where_out) / obmap_fn
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment