from collections import OrderedDict
from pathlib import Path
from time import perf_counter
from typing import List, Union

from fastapi import HTTPException
from pydantic import BaseModel, Field, root_validator

from ..accessors import GenericImageDataAccessor, InMemoryDataAccessor
from ..roiset import RoiSet
from ..session import session, AccessorIdError

# NOTE(review): removed `from numba.scripts.generate_lower_listing import description`
# — an accidental IDE auto-import; `description` only ever appears below as a
# Field(...) keyword argument, never as a module-level name.


class PipelineParams(BaseModel):
    """
    Base parameters common to all pipeline endpoints.

    Subclasses are expected to add fields whose names end in ``model_id`` or
    ``accessor_id``; the root validators below verify those IDs against the
    server session when ``api`` is True.
    """
    schedule: bool = Field(False, description='Schedule as a task instead of running immediately')
    keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
    api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True')

    @root_validator(pre=False)
    def models_are_loaded(cls, dd):
        """Raise HTTP 409 for any non-None *model_id field not loaded in the session."""
        for k, v in dd.items():
            if dd['api'] and k.endswith('model_id') and v is not None:
                if v not in session.describe_loaded_models().keys():
                    raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded')
        return dd

    @root_validator(pre=False)
    def accessors_are_loaded(cls, dd):
        """Raise HTTP 409 for any *accessor_id field that is unknown or not loaded."""
        for k, v in dd.items():
            if dd['api'] and k.endswith('accessor_id'):
                try:
                    info = session.get_accessor_info(v)
                except AccessorIdError as e:
                    raise HTTPException(status_code=409, detail=str(e))
                if not info['loaded']:
                    raise HTTPException(status_code=409, detail=f'Accessor with {k} = {v} has not been loaded')
        return dd


class PipelineRecord(BaseModel):
    """Summary of a completed pipeline run, returned to the API caller."""
    output_accessor_id: str
    interm_accessor_ids: Union[List[str], None]
    success: bool
    timer: dict
    roiset_id: Union[str, None] = None


class PipelineQueueRecord(BaseModel):
    """Returned instead of PipelineRecord when the run is scheduled as a task."""
    task_id: str


def call_pipeline(func, p: PipelineParams) -> Union[PipelineRecord, PipelineQueueRecord]:
    """
    Resolve accessor and model IDs in p, invoke the pipeline function, and
    register its outputs with the server session.

    :param func: pipeline callable taking (accessors dict, models dict, **params)
        and returning either a PipelineTrace or a (PipelineTrace, RoiSet) tuple
    :param p: validated pipeline parameters
    :return: PipelineQueueRecord if the run was scheduled as a task, otherwise a
        PipelineRecord describing the completed run
    :raises NoAccessorsFoundError: if p contains no *accessor_id field
    :raises UnexpectedPipelineReturnError: if func returns an unsupported type
    """
    # instead of running right away, schedule pipeline as a task
    if p.schedule:
        p.schedule = False
        task_id = session.queue.add_task(
            lambda x: call_pipeline(func, x),
            p
        )
        return PipelineQueueRecord(task_id=task_id)

    # match accessor IDs to loaded accessor objects
    accessors_in = {}
    for k, v in p.dict().items():
        if k.endswith('accessor_id'):
            accessors_in[k.split('accessor_id')[0]] = session.get_accessor(v, pop=True)
    if len(accessors_in) == 0:
        # BUG FIX: message previously read 'as least'
        raise NoAccessorsFoundError('Expecting at least one valid accessor to run pipeline')

    # use first validated accessor ID to name log file entries and derived accessors
    input_description = [p.dict()[pk] for pk in p.dict().keys() if pk.endswith('accessor_id')][0]

    models = {}
    for k, v in p.dict().items():
        if k.endswith('model_id') and v is not None:
            models[k.split('model_id')[0]] = session.models[v]['object']

    # call the actual pipeline; expect a single PipelineTrace or a tuple where first element is PipelineTrace
    ret = func(
        accessors_in,
        models,
        **p.dict(),
    )
    if isinstance(ret, PipelineTrace):
        trace = ret
        roiset_id = None
    elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace) and isinstance(ret[1], RoiSet):
        trace = ret[0]
        roiset_id = session.add_roiset(ret[1])
    else:
        raise UnexpectedPipelineReturnError(
            f'{func.__name__} returned unexpected value of {type(ret)}'
        )
    session.log_info(f'Completed {func.__name__} on {input_description}.')

    # map intermediate data accessors to accessor IDs
    if p.keep_interm:
        interm_ids = []
        acc_interm = trace.accessors(skip_first=True, skip_last=True).items()
        for i, item in enumerate(acc_interm):
            stk, acc = item
            interm_ids.append(
                session.add_accessor(
                    acc,
                    accessor_id=f'{input_description}_{func.__name__}_step{(i + 1):02d}_{stk}'
                )
            )
    else:
        interm_ids = None

    # map final result to an accessor ID
    # BUG FIX: was f'{p.accessor_id}_...' — PipelineParams has no field literally
    # named 'accessor_id' (only fields *ending* in it), so that raised
    # AttributeError on every successful run; use the first accessor ID that
    # input_description already captured for naming purposes.
    result_id = session.add_accessor(
        trace.last,
        accessor_id=f'{input_description}_{func.__name__}_result'
    )

    record = PipelineRecord(
        output_accessor_id=result_id,
        interm_accessor_ids=interm_ids,
        success=True,
        timer=trace.times,
        roiset_id=roiset_id,
    )
    return record


class PipelineTrace(OrderedDict):
    # Clock used for per-step timing; class attribute so subclasses/tests can
    # substitute a different timer.
    tfunc = perf_counter

    def __init__(self, start_acc: GenericImageDataAccessor = None, enforce_accessors=True, allow_overwrite=False):
        """
        A container and timer for data at each stage of a pipeline.

        :param start_acc: (optional) accessor to initialize as 'input' step
        :param enforce_accessors: if True, only allow accessors to be appended as items
        :param allow_overwrite: if True, allow an item to be overwritten
        """
        self.enforce_accessors = enforce_accessors
        self.allow_overwrite = allow_overwrite
        self.last_time = self.tfunc()
        self.markers = None
        self.timer = OrderedDict()
        super().__init__()
        if start_acc is not None:
            self['input'] = start_acc

    def __setitem__(self, key, value: GenericImageDataAccessor):
        """
        Record a pipeline step, wrapping non-accessor values in an
        InMemoryDataAccessor when enforce_accessors is False, and log the time
        elapsed since the previous step.

        :raises NoAccessorsFoundError: if value is not an accessor and
            enforce_accessors is True
        :raises KeyAlreadyExists: if key exists and allow_overwrite is False
        """
        if isinstance(value, GenericImageDataAccessor):
            acc = value
        else:
            if self.enforce_accessors:
                raise NoAccessorsFoundError(f'Pipeline trace expects data accessor type')
            else:
                acc = InMemoryDataAccessor(value)
        if not self.allow_overwrite and key in self.keys():
            raise KeyAlreadyExists(f'key {key} already exists in pipeline trace')
        self.timer.__setitem__(key, self.tfunc() - self.last_time)
        self.last_time = self.tfunc()
        return super().__setitem__(key, acc)

    def append(self, tr, skip_first=True):
        """
        Return a new trace combining this trace with the items (and recorded
        times) of another trace tr; tr's 'input' key is renamed to
        'appended_input' to avoid collision.

        :param tr: trace whose items are appended after this trace's items
        :param skip_first: if True, omit tr's first item
        :raises KeyAlreadyExists: on key collision when allow_overwrite is False
        """
        new_tr = self.copy()
        for k, v in tr.items():
            if skip_first and v == tr.first:
                continue
            dt = tr.timer[k]
            if k == 'input':
                k = 'appended_input'
            if not self.allow_overwrite and k in self.keys():
                raise KeyAlreadyExists(f'Trying to append trace with key {k} that already exists')
            new_tr.__setitem__(k, v)
            # preserve tr's original timing instead of the wall-clock delta
            # that __setitem__ just recorded
            new_tr.timer.__setitem__(k, dt)
        new_tr.last_time = self.tfunc()
        return new_tr

    @property
    def times(self):
        """
        Return an ordered dictionary of incremental times for each item that is appended
        """
        return {k: self.timer[k] for k in self.keys()}

    @property
    def first(self):
        """
        Return first item
        :return:
        """
        return list(self.values())[0]

    @property
    def last(self):
        """
        Return most recently appended item
        :return:
        """
        return list(self.values())[-1]

    def accessors(self, skip_first=True, skip_last=True) -> dict:
        """
        Return subset ordered dictionary that guarantees items are accessors
        :param skip_first: if True, exclude first item in trace
        :param skip_last: if True, exclude last item in trace
        :return: dictionary of accessors that meet input criteria
        """
        res = OrderedDict()
        keys = list(self.keys())  # hoisted: invariant across loop iterations
        for i, item in enumerate(self.items()):
            k, v = item
            if not isinstance(v, GenericImageDataAccessor):
                continue
            if skip_first and k == keys[0]:
                continue
            if skip_last and k == keys[-1]:
                continue
            res[k] = v
        return res

    def write_interm(
            self,
            where: Path,
            prefix: str = 'interm',
            skip_first=True,
            skip_last=True,
            debug=False
    ) -> List[Path]:
        """
        Write accessor data to TIF files under specified path
        :param where: directory in which to write image files
        :param prefix: (optional) file prefix
        :param skip_first: if True, do not write first item in trace
        :param skip_last: if True, do not write last item in trace
        :param debug: if True, report destination filepaths but do not write files
        :return: list of destination filepaths
        """
        paths = []
        accessors = self.accessors(skip_first=skip_first, skip_last=skip_last)
        for i, item in enumerate(accessors.items()):
            k, v = item
            fp = where / f'{prefix}_{i:02d}_{k}.tif'
            paths.append(fp)
            if not debug:
                v.write(fp)
        return paths


class Error(Exception):
    """Base class for all exceptions raised by this module."""
    pass


class IncompatibleModelsError(Error):
    pass


class KeyAlreadyExists(Error):
    pass


class NoAccessorsFoundError(Error):
    pass


class UnexpectedPipelineReturnError(Error):
    pass