from pathlib import Path
from typing import Dict, List, Union
import uuid

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from .accessors import generate_file_accessor, generate_multiposition_file_accessors
from .models import BinaryThresholdSegmentationModel
from .pipelines.shared import PipelineRecord
from .roiset import IntensityThresholdInstanceMaskSegmentationModel
from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError

# NOTE(review): debug=True makes error responses include tracebacks; presumably
# acceptable for a locally-run server — confirm before exposing it more widely.
app = FastAPI(debug=True)

# Imported only after ``app`` exists; presumably deliberate (e.g. to avoid a
# circular import) — verify before moving it into the import block above.
from .pipelines.router import router  # noqa: E402

app.include_router(router)


@app.on_event("startup")
def startup():
    """Application startup hook; currently a no-op."""
    pass


@app.get('/favicon.ico')
async def favicon():
    """Answer browser favicon requests with an empty (null) body."""
    return None


@app.get('/')
def read_root():
    """Liveness check: always reports success."""
    return {'success': True}


@app.get('/paths/session')
def get_top_session_path():
    """Return the session's top-level directory."""
    return session.get_paths()['session']


@app.get('/paths/inbound')
def get_inbound_path():
    """Return the directory that input images are read from."""
    return session.get_paths()['inbound_images']


@app.get('/paths/outbound')
def get_outbound_path():
    """Return the directory that output images are written to."""
    return session.get_paths()['outbound_images']


@app.get('/paths')
def list_session_paths():
    """Return all session paths."""
    return session.get_paths()


@app.get('/status')
def show_session_status():
    """Summarize the session: memory, loaded models, paths, accessors and tasks."""
    return {
        'status': 'running',
        'memory': session.get_mem(),
        'models': session.describe_loaded_models(),
        'paths': session.get_paths(),
        'accessors': session.list_accessors(),
        'tasks': session.tasks.list_tasks(),
    }


def _change_path(key, path, touch=False) -> Union[str, None]:
    """Point the session data directory ``key`` at ``path``.

    :param key: session path key, e.g. 'inbound_images', 'outbound_images', 'conf'
    :param path: the new directory
    :param touch: if True, additionally write a ``<key>.path`` file containing
        ``path`` into the directory configured for ``key`` before the change, and
        an ``svlt.touch`` file containing a fresh UUID into ``path``
    :return: the UUID written to ``svlt.touch`` when ``touch`` is True, else None
    :raises HTTPException: 404 if the session rejects ``path`` (InvalidPathError)
    """
    # Capture the previously configured directory before it is repointed.
    record_dir = session.paths[key] if touch else None
    try:
        session.set_data_directory(key, path)
    except InvalidPathError as e:
        raise HTTPException(404, f'Did not find valid folder at: {path}') from e
    if not touch:
        return None
    # Written only after the session accepted the path, so a rejected path
    # never leaves a stale record behind.
    with open(record_dir / f'{key}.path', 'w') as fh:
        fh.write(str(path))
    uid = str(uuid.uuid4())
    with open(Path(path) / 'svlt.touch', 'w') as fh:
        fh.write(uid)
    return uid


@app.put('/paths/watch_input')
def watch_input_path(path: str, touch: bool = False) -> Union[str, None]:
    """Set the inbound image directory; see ``_change_path`` for ``touch``."""
    return _change_path('inbound_images', path, touch=touch)


@app.put('/paths/watch_output')
def watch_output_path(path: str, touch: bool = False) -> Union[str, None]:
    """Set the outbound image directory; see ``_change_path`` for ``touch``."""
    return _change_path('outbound_images', path, touch=touch)


@app.put('/paths/watch_conf')
def watch_conf_path(path: str, touch: bool = False) -> Union[str, None]:
    """Set the configuration directory; see ``_change_path`` for ``touch``."""
    # Previously also named ``watch_output_path``, which rebound that module name
    # to this handler; the route itself is unchanged.
    return _change_path('conf', path, touch=touch)


@app.get('/session/restart')
def restart_session() -> dict:
    """Restart the session and return a description of its loaded models."""
    session.restart()
    return session.describe_loaded_models()


@app.get('/session/logs')
def list_session_log() -> list:
    """Return all session log entries."""
    return session.get_log_data()


@app.get('/session/errors')
def get_errors() -> list:
    """Return only the session log entries whose level is 'ERROR'."""
    return [li for li in session.get_log_data() if li['level'] == 'ERROR']


@app.get('/session/mem')
def get_mem() -> dict:
    """Return the session's memory report."""
    return session.get_mem()


@app.get('/models')
def list_active_models():
    """Describe the models currently loaded in the session."""
    return session.describe_loaded_models()


class BinaryThresholdSegmentationParams(BaseModel):
    """Request body for the threshold-based model loaders."""
    tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')


@app.put('/models/seg/threshold/load/')
def load_binary_threshold_model(
        p: BinaryThresholdSegmentationParams,
        model_id: Union[str, None] = None,
) -> dict:
    """Load a binary threshold segmentation model, optionally under ``model_id``.

    :return: ``{'model_id': <key the session assigned>}``
    """
    result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
    session.log_info(f'Loaded binary threshold segmentation model {result}')
    return {'model_id': result}


@app.put('/models/classify/threshold/load')
def load_intensity_threshold_instance_segmentation_model(
        p: BinaryThresholdSegmentationParams,
        model_id: Union[str, None] = None,
) -> dict:
    """Load an intensity-threshold instance segmentation model, optionally under ``model_id``.

    :return: ``{'model_id': <key the session assigned>}``
    """
    result = session.load_model(IntensityThresholdInstanceMaskSegmentationModel, key=model_id, params=p)
    session.log_info(f'Loaded permissive instance segmentation model {result}')
    return {'model_id': result}


@app.get('/accessors')
def list_accessors() -> Dict:
    """Return info on every accessor known to the session."""
    return session.list_accessors()


@app.get('/accessors/loaded')
def list_loaded_accessors() -> Dict:
    """Return info only on accessors whose 'loaded' flag is truthy."""
    return {k: d for k, d in session.list_accessors().items() if d['loaded']}


def _session_accessor(func, acc_id):
    """Call ``func(acc_id)``, translating AccessorIdError into an HTTP 404."""
    try:
        return func(acc_id)
    except AccessorIdError as e:
        raise HTTPException(404, f'Did not find accessor with ID {acc_id}') from e


@app.get('/accessors/get/{accessor_id}')
def get_accessor(accessor_id: str):
    """Return info on one accessor; 404 if the ID is unknown."""
    return _session_accessor(session.get_accessor_info, accessor_id)


@app.get('/accessors/delete/{accessor_id}')
def delete_accessor(accessor_id: str):
    """Delete one accessor, or all of them when ``accessor_id`` is '*'."""
    if accessor_id == '*':
        return session.del_all_accessors()
    return _session_accessor(session.del_accessor, accessor_id)


def _inbound_file(filename: str):
    """Return the path of ``filename`` inside the inbound image directory.

    :raises HTTPException: 404 if nothing exists at that path
    """
    fp = session.paths['inbound_images'] / filename
    if not fp.exists():
        raise HTTPException(status_code=404, detail=f'Could not find file:\n{filename}')
    return fp


@app.put('/accessors/read_from_file/{filename}')
def read_accessor_from_file(filename: str, lazy: bool = False) -> str:
    """Create an accessor for an inbound file and register it; return its ID."""
    acc = generate_file_accessor(_inbound_file(filename), lazy=lazy)
    return session.add_accessor(acc)


@app.put('/accessors/write_to_file/{accessor_id}')
def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None, pop: bool = True) -> str:
    """Write an accessor's data to a file via the session.

    :raises HTTPException: 404 for an unknown accessor ID, 409 if the write is refused
    """
    try:
        return session.write_accessor(accessor_id, filename, pop=pop)
    except AccessorIdError as e:
        raise HTTPException(404, f'Did not find accessor with ID {accessor_id}') from e
    except WriteAccessorError as e:
        raise HTTPException(409, str(e)) from e


@app.put('/accessors/read_multiposition_file/{filename}')
def read_multiposition_from_file(filename: str, lazy: bool = False) -> List[str]:
    """Register one accessor per position of an inbound multiposition file; return their IDs."""
    multipos = generate_multiposition_file_accessors(_inbound_file(filename), lazy=lazy)
    return [session.add_accessor(acc) for acc in multipos.get_accessors()]


class TaskInfo(BaseModel):
    """Response schema describing a single task."""
    module: str
    params: dict
    func_str: str
    status: str
    target: Union[str, None]
    error: Union[str, None]
    result: Union[Dict, None]


@app.put('/tasks/run_all')
def run_all_tasks(write_all: bool = False) -> List[PipelineRecord]:
    """Run waiting tasks one at a time until none remain; return each run's record."""
    # NOTE: terminates only because run_task moves each task out of the waiting
    # state; otherwise ``next_waiting`` would keep returning the same task.
    records = []
    while task_id := session.tasks.next_waiting:
        records.append(session.tasks.run_task(task_id, write_all=write_all))
    return records


@app.get('/tasks/ready')
def list_ready_tasks() -> Dict[str, TaskInfo]:
    """Return tasks whose status is 'READY'."""
    # Previously named ``list_waiting_tasks_by_status``, which was shadowed by
    # the '/tasks/waiting' handler of the same name; the route is unchanged.
    return session.tasks.list_tasks(status='READY')
@app.get('/tasks/waiting')
def list_waiting_tasks_by_status() -> Dict[str, TaskInfo]:
    """Return tasks whose status is 'WAITING'."""
    return session.tasks.list_tasks(status='WAITING')


@app.get('/tasks/by_status')
def list_tasks_by_status() -> Dict[str, Dict[str, TaskInfo]]:
    """Return tasks grouped under each status code known to the task queue."""
    return {st: session.tasks.list_tasks(status=st) for st in session.tasks.status_codes.values()}


@app.get('/tasks/next')
def get_next_waiting_text() -> Union[str, None]:
    """Return the ID of the next waiting task, or None if there is none."""
    # Name kept as-is for compatibility; it refers to the next waiting *task*.
    return session.tasks.next_waiting


@app.put('/tasks/delete_accessors/{task_id}')
def delete_accessors_by_parent_task(task_id: str) -> List[str]:
    """Delete all accessors whose parent task is ``task_id``."""
    return session.del_all_accessors(parent_task_id=task_id)


@app.put('/tasks/delete_accessors')
def delete_accessors() -> List[str]:
    """Delete all accessors in the session."""
    return session.del_all_accessors()


@app.put('/tasks/run/{task_id}')
def run_task(
        task_id: str,
        write_all: bool = False,
        prefix: Union[str, None] = None,
        output_subdirectory: Union[str, None] = None,
        allow_overwrite: bool = False,
) -> PipelineRecord:
    """Run a single task by ID, forwarding the output options to the task queue."""
    return session.tasks.run_task(
        task_id,
        write_all=write_all,
        prefix=prefix,
        output_subdirectory=output_subdirectory,
        allow_overwrite=allow_overwrite,
    )


@app.get('/tasks/get/{task_id}')
def get_task(task_id: str) -> TaskInfo:
    """Return info on a single task."""
    return session.tasks.get_task_info(task_id)


@app.get('/tasks/get_output_accessor/{task_id}')
def get_output_accessor_id_for_task(task_id: str) -> str:
    """Return the ID of the output accessor produced by a task."""
    return session.tasks.get_output_accessor_id(task_id)


@app.get('/tasks')
def list_tasks() -> Dict[str, TaskInfo]:
    """Return all tasks regardless of status."""
    return session.tasks.list_tasks()


@app.get('/phenobase/bounding_box')
def get_phenobase_bounding_boxes() -> list:
    """Return the bounding boxes listed by the session's phenobase."""
    return session.phenobase.list_bounding_boxes()