diff --git a/model_server/base/api.py b/model_server/base/api.py index 1a533b92256f5c2d61ac054186aee55dfd807def..d1e3abd3f460919b63e0519b0b5e03d2ecad0293 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -52,7 +52,7 @@ def show_session_status(): 'paths': session.get_paths(), 'rois': session.list_rois(), 'accessors': session.list_accessors(), - 'tasks': session.queue.list_tasks(), + 'tasks': session.tasks.list_tasks(), } @@ -214,24 +214,20 @@ class TaskInfo(BaseModel): error: Union[str, None] result: Union[Dict, None] -# TODO: cover task API with tests, using dummy task resource endpoint - -# TODO: test to cover reporting of exception in running tasks - @app.put('/tasks/{task_id}/run') def run_task(task_id: str) -> PipelineRecord: try: - return session.queue.run_task(task_id) + return session.tasks.run_task(task_id) except RunTaskError as e: raise HTTPException(409, str(e)) @app.get('/tasks/{task_id}') def get_task(task_id: str) -> TaskInfo: - return session.queue.get_task_info(task_id) + return session.tasks.get_task_info(task_id) @app.get('/tasks') def list_tasks() -> Dict[str, TaskInfo]: - res = session.queue.list_tasks() + res = session.tasks.list_tasks() return res # TODO: implement diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py index 1a1409108831e26e893a9db027469754c7177dc0..be12c1f20243b67c269075856f435978923d4375 100644 --- a/model_server/base/pipelines/shared.py +++ b/model_server/base/pipelines/shared.py @@ -53,7 +53,7 @@ def call_pipeline(func, p: PipelineParams) -> Union[PipelineRecord, PipelineQueu # instead of running right away, schedule pipeline as a task if p.schedule: p.schedule = False - task_id = session.queue.add_task( + task_id = session.tasks.add_task( lambda x: call_pipeline(func, x), p ) diff --git a/model_server/base/session.py b/model_server/base/session.py index 9f55935ee3bf02f73a9aacb6b012fcffc5e516a0..9df135f5679d5650144ab003a2bdf3725357e19c 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -33,8 +33,8 @@ class CsvTable(object): self.empty = False return True -# TODO: is there a standard library alternative here? -class Queue(object): + +class TaskCollection(object): status_codes = { 'waiting': 'WAITING', @@ -44,14 +44,13 @@ class Queue(object): } def __init__(self): - self._queue = OrderedDict() + self._tasks = OrderedDict() self._handles = OrderedDict() def add_task(self, func: callable, params: dict) -> str: task_id = str(uuid.uuid4()) - name = func.__name__ - self._queue[task_id] = { + self._tasks[task_id] = { 'module': func.__module__, 'params': params, 'func_str': str(func), @@ -65,13 +64,13 @@ class Queue(object): return str(task_id) def get_task_info(self, task_id: str) -> dict: - return self._queue[task_id] + return self._tasks[task_id] def list_tasks(self) -> OrderedDict: - return self._queue + return self._tasks def run_task(self, task_id: str): - task = self._queue[task_id] + task = self._tasks[task_id] f = self._handles[task_id] p = task['params'] try: @@ -96,7 +95,7 @@ class _Session(object): self.models = {} # model_id : model object self.paths = self.make_paths() self.accessors = OrderedDict() - self.queue = Queue() + self.tasks = TaskCollection() self.rois = OrderedDict() self.logfile = self.paths['logs'] / f'session.log' diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 54d82907aea65743154231f108ee90be2483cc4b..5e6cfb7e002c30adc310929b6a4de23479a4e022 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -72,7 +72,7 @@ def create_dummy_task(params: DummyTaskParams) -> PipelineQueueRecord: d['res'] = d.last.apply(lambda x: 2 * x) return d - task_id = session.queue.add_task( + task_id = session.tasks.add_task( lambda x: call_pipeline(_dummy_pipeline, x), params )