-
Christopher Randolph Rhodes authoredChristopher Randolph Rhodes authored
shared.py 7.45 KiB
from collections import OrderedDict
from pathlib import Path
from time import perf_counter
from typing import List, Union
from fastapi import HTTPException
from pydantic import BaseModel, Field, root_validator
from ..accessors import GenericImageDataAccessor
from ..session import session, AccessorIdError
class PipelineParams(BaseModel):
keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True')
@root_validator(pre=False)
def models_are_loaded(cls, dd):
for k, v in dd.items():
if dd['api'] and k.endswith('model_id') and v is not None:
if v not in session.describe_loaded_models().keys():
raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded')
return dd
@root_validator(pre=False)
def accessors_are_loaded(cls, dd):
for k, v in dd.items():
if dd['api'] and k.endswith('accessor_id'):
try:
info = session.get_accessor_info(v)
except AccessorIdError as e:
raise HTTPException(status_code=409, detail=str(e))
if not info['loaded']:
raise HTTPException(status_code=409, detail=f'Accessor with {k} = {v} has not been loaded')
return dd
class PipelineRecord(BaseModel):
output_accessor_id: str
interm_accessor_ids: Union[List[str], None]
success: bool
timer: dict
def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
# match accessor IDs to loaded accessor objects
accessors_in = {}
# TODO: remove "model" from model.keys()
for k, v in p.dict().items():
if k.endswith('accessor_id'):
accessors_in[k.split('_id')[0]] = session.get_accessor(v, pop=True)
if len(accessors_in) == 0:
raise NoAccessorsFoundError('Expecting as least one valid accessor to run pipeline')
# use first validated accessor ID to name log file entries and derived accessors
input_description = [p.dict()[pk] for pk in p.dict().keys() if pk.endswith('accessor_id')][0]
models = {}
for k, v in p.dict().items():
if k.endswith('model_id') and v is not None:
models[k.split('_id')[0]] = session.models[v]['object']
# call the actual pipeline; expect a single PipelineTrace or a tuple where first element is PipelineTrace
ret = func(
accessors_in,
models,
**p.dict(),
)
if isinstance(ret, PipelineTrace):
steps = ret
misc = None
elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace):
steps = ret[0]
misc = ret[1:]
else:
raise UnexpectedPipelineReturnError(
f'{func.__name__} returned unexpected value of {type(ret)}'
)
session.log_info(f'Completed {func.__name__} on {input_description}.')
# map intermediate data accessors to accessor IDs
if p.keep_interm:
interm_ids = []
acc_interm = steps.accessors(skip_first=True, skip_last=True).items()
for i, item in enumerate(acc_interm):
stk, acc = item
interm_ids.append(
session.add_accessor(
acc,
accessor_id=f'{input_description}_{func.__name__}_step{(i + 1):02d}_{stk}'
)
)
else:
interm_ids = None
# map final result to an accessor ID
result_id = session.add_accessor(
steps.last,
accessor_id=f'{p.accessor_id}_{func.__name__}_result'
)
record = PipelineRecord(
output_accessor_id=result_id,
interm_accessor_ids=interm_ids,
success=True,
timer=steps.times
)
# return miscellaneous objects if pipeline returns these
if misc:
return record, *misc
else:
return record
class PipelineTrace(OrderedDict):
tfunc = perf_counter
def __init__(self, start_acc: GenericImageDataAccessor = None, enforce_accessors=True, allow_overwrite=False):
"""
A container and timer for data at each stage of a pipeline.
:param start_acc: (optional) accessor to initialize as 'input' step
:param enforce_accessors: if True, only allow accessors to be appended as items
:param allow_overwrite: if True, allow an item to be overwritten
"""
self.enforce_accessors = enforce_accessors
self.allow_overwrite = allow_overwrite
self.last_time = self.tfunc()
self.markers = None
self.timer = OrderedDict()
super().__init__()
if start_acc is not None:
self['input'] = start_acc
def __setitem__(self, key, value: GenericImageDataAccessor):
if self.enforce_accessors:
assert isinstance(value, GenericImageDataAccessor), f'Pipeline trace expects data accessor type'
if not self.allow_overwrite and key in self.keys():
raise KeyAlreadyExists(f'key {key} already exists in pipeline trace')
self.timer.__setitem__(key, self.tfunc() - self.last_time)
self.last_time = self.tfunc()
return super().__setitem__(key, value)
@property
def times(self):
"""
Return an ordered dictionary of incremental times for each item that is appended
"""
return {k: self.timer[k] for k in self.keys()}
@property
def last(self):
"""
Return most recently appended item
:return:
"""
return list(self.values())[-1]
def accessors(self, skip_first=True, skip_last=True) -> dict:
"""
Return subset ordered dictionary that guarantees items are accessors
:param skip_first: if True, exclude first item in trace
:param skip_last: if False, exclude last item in trace
:return: dictionary of accessors that meet input criteria
"""
res = OrderedDict()
for i, item in enumerate(self.items()):
k, v = item
if not isinstance(v, GenericImageDataAccessor):
continue
if skip_first and k == list(self.keys())[0]:
continue
if skip_last and k == list(self.keys())[-1]:
continue
res[k] = v
return res
def write_interm(
self,
where: Path,
prefix: str = 'interm',
skip_first=True,
skip_last=True,
debug=False
) -> List[Path]:
"""
Write accessor data to TIF files under specified path
:param where: directory in which to write image files
:param prefix: (optional) file prefix
:param skip_first: if True, do not write first item in trace
:param skip_last: if False, do not write last item in trace
:param debug: if True, report destination filepaths but do not write files
:return: list of destination filepaths
"""
paths = []
accessors = self.accessors(skip_first=skip_first, skip_last=skip_last)
for i, item in enumerate(accessors.items()):
k, v = item
fp = where / f'{prefix}_{i:02d}_{k}.tif'
paths.append(fp)
if not debug:
v.write(fp)
return paths
class Error(Exception):
pass
class IncompatibleModelsError(Error):
pass
class KeyAlreadyExists(Error):
pass
class NoAccessorsFoundError(Error):
pass
class UnexpectedPipelineReturnError(Error):
pass