# shared.py — pipeline parameter validation, invocation, and tracing helpers
from collections import OrderedDict
from pathlib import Path
from time import perf_counter
from typing import List, Union

from fastapi import HTTPException
from pydantic import BaseModel, Field, root_validator

from ..accessors import GenericImageDataAccessor
from ..session import session, AccessorIdError


class PipelineParams(BaseModel):
    """
    Base parameter model for pipeline endpoints; validators cross-check
    *_model_id and *_accessor_id fields against the server session.
    """
    keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
    api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True')

    @root_validator(pre=False)
    def models_are_loaded(cls, dd):
        """Raise HTTP 409 for any non-None *model_id field not loaded in the session."""
        if not dd['api']:
            return dd
        for field_name, field_value in dd.items():
            if not field_name.endswith('model_id') or field_value is None:
                continue
            if field_value not in session.describe_loaded_models().keys():
                raise HTTPException(status_code=409, detail=f'Model with {field_name} = {field_value} has not been loaded')
        return dd

    @root_validator(pre=False)
    def accessors_are_loaded(cls, dd):
        """Raise HTTP 409 for any *accessor_id field that is unknown or not loaded."""
        if not dd['api']:
            return dd
        for field_name, field_value in dd.items():
            if not field_name.endswith('accessor_id'):
                continue
            try:
                info = session.get_accessor_info(field_value)
            except AccessorIdError as e:
                raise HTTPException(status_code=409, detail=str(e))
            if not info['loaded']:
                raise HTTPException(status_code=409, detail=f'Accessor with {field_name} = {field_value} has not been loaded')
        return dd


class PipelineRecord(BaseModel):
    """
    Summary of a completed pipeline run, returned by call_pipeline.
    """
    # session ID of the accessor holding the pipeline's final output
    output_accessor_id: str
    # session IDs of intermediate-step accessors; None when keep_interm was not set
    interm_accessor_ids: Union[List[str], None]
    success: bool
    # step name -> elapsed seconds, taken from PipelineTrace.times
    timer: dict


def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
    """
    Resolve accessor and model IDs in p against the session, invoke a pipeline
    function, and register its outputs back into the session.

    :param func: pipeline callable taking (accessors_dict, models_dict, **params);
        must return a PipelineTrace, or a tuple whose first element is one
    :param p: validated parameter model; fields ending in 'accessor_id' and
        'model_id' are resolved via the session
    :return: a PipelineRecord, or (PipelineRecord, *misc) when func returned
        extra objects after the trace
    :raises NoAccessorsFoundError: if p contains no *accessor_id field
    :raises UnexpectedPipelineReturnError: if func's return value is not a
        PipelineTrace (or tuple starting with one)
    """
    # p.dict() rebuilds the dict each call; materialize it once
    params = p.dict()

    # match accessor IDs to loaded accessor objects (popped from the session)
    accessors_in = {}
    for k, v in params.items():
        if k.endswith('accessor_id'):
            accessors_in[k.split('_id')[0]] = session.get_accessor(v, pop=True)
    if len(accessors_in) == 0:
        raise NoAccessorsFoundError('Expecting at least one valid accessor to run pipeline')

    # use first validated accessor ID to name log file entries and derived accessors
    input_description = [params[pk] for pk in params.keys() if pk.endswith('accessor_id')][0]

    # TODO: remove "model" from model.keys()
    models = {}
    for k, v in params.items():
        if k.endswith('model_id') and v is not None:
            models[k.split('_id')[0]] = session.models[v]['object']

    # call the actual pipeline; expect a single PipelineTrace or a tuple where first element is PipelineTrace
    ret = func(
        accessors_in,
        models,
        **params,
    )
    if isinstance(ret, PipelineTrace):
        steps = ret
        misc = None
    elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace):
        steps = ret[0]
        misc = ret[1:]
    else:
        raise UnexpectedPipelineReturnError(
            f'{func.__name__} returned unexpected value of {type(ret)}'
        )
    session.log_info(f'Completed {func.__name__} on {input_description}.')

    # map intermediate data accessors to accessor IDs
    if p.keep_interm:
        interm_ids = []
        acc_interm = steps.accessors(skip_first=True, skip_last=True).items()
        for i, (stk, acc) in enumerate(acc_interm):
            interm_ids.append(
                session.add_accessor(
                    acc,
                    accessor_id=f'{input_description}_{func.__name__}_step{(i + 1):02d}_{stk}'
                )
            )
    else:
        interm_ids = None

    # map final result to an accessor ID
    # fix: was f'{p.accessor_id}_...', but PipelineParams declares no accessor_id
    # field (AttributeError unless a subclass happens to define one); name the
    # result from input_description, consistent with the intermediate IDs above
    result_id = session.add_accessor(
        steps.last,
        accessor_id=f'{input_description}_{func.__name__}_result'
    )

    record = PipelineRecord(
        output_accessor_id=result_id,
        interm_accessor_ids=interm_ids,
        success=True,
        timer=steps.times
    )

    # return miscellaneous objects if pipeline returns these
    if misc:
        return record, *misc
    else:
        return record


class PipelineTrace(OrderedDict):
    """
    An ordered step-name -> data container for pipeline stages, timing each
    insertion relative to the previous one.
    """
    # timer function; class attribute so it can be swapped out (e.g. in tests)
    tfunc = perf_counter

    def __init__(self, start_acc: GenericImageDataAccessor = None, enforce_accessors=True, allow_overwrite=False):
        """
        A container and timer for data at each stage of a pipeline.
        :param start_acc: (optional) accessor to initialize as 'input' step
        :param enforce_accessors: if True, only allow accessors to be appended as items
        :param allow_overwrite: if True, allow an item to be overwritten
        """
        self.enforce_accessors = enforce_accessors
        self.allow_overwrite = allow_overwrite
        self.last_time = self.tfunc()
        self.markers = None
        self.timer = OrderedDict()
        super().__init__()
        if start_acc is not None:
            self['input'] = start_acc

    def __setitem__(self, key, value: GenericImageDataAccessor):
        """
        Append an item, recording the elapsed time since the last insertion.
        :raises KeyAlreadyExists: if key exists and allow_overwrite is False
        """
        # NOTE(review): assert is stripped under `python -O`; kept as-is since
        # callers may rely on AssertionError here
        if self.enforce_accessors:
            assert isinstance(value, GenericImageDataAccessor), f'Pipeline trace expects data accessor type'
        if not self.allow_overwrite and key in self.keys():
            raise KeyAlreadyExists(f'key {key} already exists in pipeline trace')
        self.timer.__setitem__(key, self.tfunc() - self.last_time)
        self.last_time = self.tfunc()
        return super().__setitem__(key, value)

    @property
    def times(self):
        """
        Return an ordered dictionary of incremental times for each item that is appended
        """
        return {k: self.timer[k] for k in self.keys()}

    @property
    def last(self):
        """
        Return most recently appended item
        :raises IndexError: if the trace is empty
        :return: the last value inserted
        """
        return list(self.values())[-1]

    def accessors(self, skip_first=True, skip_last=True) -> dict:
        """
        Return subset ordered dictionary that guarantees items are accessors
        :param skip_first: if True, exclude first item in trace
        :param skip_last: if True, exclude last item in trace
        :return: dictionary of accessors that meet input criteria
        """
        # hoist first/last keys out of the loop; recomputing list(self.keys())
        # per iteration made this accidentally quadratic
        keys = list(self.keys())
        first_key = keys[0] if keys else None
        last_key = keys[-1] if keys else None
        res = OrderedDict()
        for k, v in self.items():
            if not isinstance(v, GenericImageDataAccessor):
                continue
            if skip_first and k == first_key:
                continue
            if skip_last and k == last_key:
                continue
            res[k] = v
        return res

    def write_interm(
            self,
            where: Path,
            prefix: str = 'interm',
            skip_first=True,
            skip_last=True,
            debug=False
    ) -> List[Path]:
        """
        Write accessor data to TIF files under specified path
        :param where: directory in which to write image files
        :param prefix: (optional) file prefix
        :param skip_first: if True, do not write first item in trace
        :param skip_last: if True, do not write last item in trace
        :param debug: if True, report destination filepaths but do not write files
        :return: list of destination filepaths
        """
        paths = []
        accessors = self.accessors(skip_first=skip_first, skip_last=skip_last)
        for i, (k, v) in enumerate(accessors.items()):
            fp = where / f'{prefix}_{i:02d}_{k}.tif'
            paths.append(fp)
            if not debug:
                v.write(fp)
        return paths


class Error(Exception):
    """Base class for exceptions raised by this module."""
    pass

class IncompatibleModelsError(Error):
    """Model-compatibility error; not raised in this module (presumably raised by callers)."""
    pass

class KeyAlreadyExists(Error):
    """Raised by PipelineTrace.__setitem__ on duplicate keys when overwrite is disabled."""
    pass

class NoAccessorsFoundError(Error):
    """Raised by call_pipeline when parameters contain no *accessor_id field."""
    pass

class UnexpectedPipelineReturnError(Error):
    """Raised by call_pipeline when a pipeline function returns neither a PipelineTrace nor a tuple starting with one."""
    pass