from collections import OrderedDict
from pathlib import Path
from time import perf_counter
from typing import List, Union

from fastapi import HTTPException
from pydantic import BaseModel, Field, root_validator

from ..accessors import GenericImageDataAccessor
from ..session import session, AccessorIdError


class PipelineParams(BaseModel):
    """
    Base parameter model for pipeline endpoints.

    Subclasses add fields whose names end in '*_accessor_id' or '*_model_id';
    the root validators below check those references against the server session.
    """

    keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
    api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True')

    @root_validator(pre=False)
    def models_are_loaded(cls, dd):
        """Raise HTTP 409 if any populated *model_id field does not name a loaded model."""
        # validation against the session only applies when running as an API call
        if dd['api']:
            loaded = session.describe_loaded_models().keys()  # hoisted: invariant across fields
            for k, v in dd.items():
                if k.endswith('model_id') and v is not None and v not in loaded:
                    raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded')
        return dd

    @root_validator(pre=False)
    def accessors_are_loaded(cls, dd):
        """Raise HTTP 409 if any *accessor_id field is unknown to the session or not loaded."""
        if dd['api']:
            for k, v in dd.items():
                if k.endswith('accessor_id'):
                    try:
                        info = session.get_accessor_info(v)
                    except AccessorIdError as e:
                        raise HTTPException(status_code=409, detail=str(e))
                    if not info['loaded']:
                        raise HTTPException(status_code=409, detail=f'Accessor with {k} = {v} has not been loaded')
        return dd


class PipelineRecord(BaseModel):
    """Summary of a completed pipeline run, returned to API callers."""

    output_accessor_id: str
    interm_accessor_ids: Union[List[str], None]
    success: bool
    timer: dict


def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
    """
    Resolve session objects referenced in p, run a pipeline function, and register its results.

    :param func: pipeline callable invoked as func(accessors, models, **params); must return
        a PipelineTrace, or a tuple whose first element is a PipelineTrace
    :param p: validated pipeline parameters; fields ending in 'accessor_id' / 'model_id'
        are resolved against the server session
    :return: PipelineRecord describing the output (and optionally intermediate) accessors;
        if func returned extra objects after the trace, a tuple (record, *extras)
    :raises NoAccessorsFoundError: if p references no accessor IDs
    :raises UnexpectedPipelineReturnError: if func's return value is not trace-shaped
    """
    params = p.dict()

    # match accessor IDs to loaded accessor objects, removing them from the session
    accessors_in = {
        k.split('_id')[0]: session.get_accessor(v, pop=True)
        for k, v in params.items() if k.endswith('accessor_id')
    }
    if len(accessors_in) == 0:
        raise NoAccessorsFoundError('Expecting at least one valid accessor to run pipeline')

    # use first validated accessor ID to name log file entries and derived accessors
    input_description = [params[pk] for pk in params.keys() if pk.endswith('accessor_id')][0]

    # TODO: remove "model" from model.keys()
    models = {
        k.split('_id')[0]: session.models[v]['object']
        for k, v in params.items() if k.endswith('model_id') and v is not None
    }

    # call the actual pipeline; expect a single PipelineTrace or a tuple where first element is PipelineTrace
    ret = func(accessors_in, models, **params)
    if isinstance(ret, PipelineTrace):
        steps = ret
        misc = None
    elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace):
        steps = ret[0]
        misc = ret[1:]
    else:
        raise UnexpectedPipelineReturnError(
            f'{func.__name__} returned unexpected value of {type(ret)}'
        )
    session.log_info(f'Completed {func.__name__} on {input_description}.')

    # map intermediate data accessors to accessor IDs
    if p.keep_interm:
        interm_ids = []
        acc_interm = steps.accessors(skip_first=True, skip_last=True).items()
        for i, (stk, acc) in enumerate(acc_interm):
            interm_ids.append(
                session.add_accessor(
                    acc,
                    accessor_id=f'{input_description}_{func.__name__}_step{(i + 1):02d}_{stk}'
                )
            )
    else:
        interm_ids = None

    # map final result to an accessor ID
    # FIX: was f'{p.accessor_id}_...', but PipelineParams has no 'accessor_id' field
    # (accessor ID fields are named '*_accessor_id'), so that raised AttributeError;
    # use the derived input_description, consistent with the intermediate names above.
    result_id = session.add_accessor(
        steps.last,
        accessor_id=f'{input_description}_{func.__name__}_result'
    )

    record = PipelineRecord(
        output_accessor_id=result_id,
        interm_accessor_ids=interm_ids,
        success=True,
        timer=steps.times
    )

    # return miscellaneous objects if pipeline returns these
    if misc:
        return record, *misc
    else:
        return record


class PipelineTrace(OrderedDict):
    # timing function used to measure per-step durations
    tfunc = perf_counter

    def __init__(self, start_acc: GenericImageDataAccessor = None, enforce_accessors=True, allow_overwrite=False):
        """
        A container and timer for data at each stage of a pipeline.

        :param start_acc: (optional) accessor to initialize as 'input' step
        :param enforce_accessors: if True, only allow accessors to be appended as items
        :param allow_overwrite: if True, allow an item to be overwritten
        """
        self.enforce_accessors = enforce_accessors
        self.allow_overwrite = allow_overwrite
        self.last_time = self.tfunc()
        self.markers = None
        self.timer = OrderedDict()
        super().__init__()
        if start_acc is not None:
            self['input'] = start_acc

    def __setitem__(self, key, value: GenericImageDataAccessor):
        """Append a named step, recording the time elapsed since the previous step."""
        if self.enforce_accessors:
            assert isinstance(value, GenericImageDataAccessor), 'Pipeline trace expects data accessor type'
        if not self.allow_overwrite and key in self.keys():
            raise KeyAlreadyExists(f'key {key} already exists in pipeline trace')
        # sample the clock once so the recorded delta and the new reference agree
        now = self.tfunc()
        self.timer.__setitem__(key, now - self.last_time)
        self.last_time = now
        return super().__setitem__(key, value)

    @property
    def times(self):
        """
        Return a dictionary, in insertion order, of incremental times for each appended item
        """
        return {k: self.timer[k] for k in self.keys()}

    @property
    def last(self):
        """
        Return most recently appended item
        :return: value of the last step in the trace
        """
        return list(self.values())[-1]

    def accessors(self, skip_first=True, skip_last=True) -> dict:
        """
        Return subset ordered dictionary that guarantees items are accessors

        :param skip_first: if True, exclude first item in trace
        :param skip_last: if True, exclude last item in trace
        :return: dictionary of accessors that meet input criteria
        """
        keys = list(self.keys())  # hoisted: first/last keys are invariant across the loop
        res = OrderedDict()
        for k, v in self.items():
            if not isinstance(v, GenericImageDataAccessor):
                continue
            if skip_first and k == keys[0]:
                continue
            if skip_last and k == keys[-1]:
                continue
            res[k] = v
        return res

    def write_interm(
            self,
            where: Path,
            prefix: str = 'interm',
            skip_first=True,
            skip_last=True,
            debug=False
    ) -> List[Path]:
        """
        Write accessor data to TIF files under specified path

        :param where: directory in which to write image files
        :param prefix: (optional) file prefix
        :param skip_first: if True, do not write first item in trace
        :param skip_last: if True, do not write last item in trace
        :param debug: if True, report destination filepaths but do not write files
        :return: list of destination filepaths
        """
        paths = []
        accessors = self.accessors(skip_first=skip_first, skip_last=skip_last)
        for i, (k, v) in enumerate(accessors.items()):
            fp = where / f'{prefix}_{i:02d}_{k}.tif'
            paths.append(fp)
            if not debug:
                v.write(fp)
        return paths


class Error(Exception):
    pass


class IncompatibleModelsError(Error):
    pass


class KeyAlreadyExists(Error):
    pass


class NoAccessorsFoundError(Error):
    pass


class UnexpectedPipelineReturnError(Error):
    pass