Skip to content
Snippets Groups Projects
Commit 18c0884c authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'dev_task_queue' into 'staging'

Pipeline task management

See merge request !72
parents 7342a878 d206d455
No related branches found
No related tags found
2 merge requests!102Merge staging as release,!72Pipeline task management
from pydantic import BaseModel, Field
from typing import List, Union
from typing import Dict, List, Union
from fastapi import FastAPI, HTTPException
from .accessors import generate_file_accessor
from .models import BinaryThresholdSegmentationModel
from .pipelines.shared import PipelineRecord
from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError
from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError
from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, RunTaskError, WriteAccessorError
app = FastAPI(debug=True)
......@@ -51,6 +52,7 @@ def show_session_status():
'paths': session.get_paths(),
'rois': session.list_rois(),
'accessors': session.list_accessors(),
'tasks': session.tasks.list_tasks(),
}
......@@ -129,7 +131,6 @@ def delete_accessor(accessor_id: str):
else:
return _session_accessor(session.del_accessor, accessor_id)
@app.put('/accessors/read_from_file/{filename}')
def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None):
fp = session.paths['inbound_images'] / filename
......@@ -140,9 +141,9 @@ def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None)
@app.put('/accessors/write_to_file/{accessor_id}')
def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) -> str:
def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None, pop: bool = True) -> str:
try:
return session.write_accessor(accessor_id, filename)
return session.write_accessor(accessor_id, filename, pop=pop)
except AccessorIdError as e:
raise HTTPException(404, f'Did not find accessor with ID {accessor_id}')
except WriteAccessorError as e:
......@@ -202,4 +203,28 @@ def roiset_get_object_map(
raise HTTPException(
404,
f'Did not find object map from classification model {model_id} in RoiSet {roiset_id}'
)
\ No newline at end of file
)
class TaskInfo(BaseModel):
    """
    Description of one task tracked by the session, as reported by the task endpoints.
    """
    # name of the module of the task's callable
    module: str
    # parameters that are passed to the callable when the task is run
    params: dict
    # string representation of the callable
    func_str: str
    # status code of the task, e.g. 'WAITING' or 'FINISHED'
    status: str
    # error message if the task failed; explicit None default so the field is optional regardless of how the
    # installed pydantic version treats a nullable annotation that has no default
    error: Union[str, None] = None
    # result of the task once it has finished, otherwise None
    result: Union[Dict, None] = None
@app.put('/tasks/{task_id}/run')
def run_task(task_id: str) -> PipelineRecord:
    """
    Run a task that is tracked by the session and return its result.

    :param task_id: ID of the task to run
    :return: result of the task
    :raises HTTPException: 409 if the task raised an error while running, 404 if there is no task with this ID
    """
    try:
        return session.tasks.run_task(task_id)
    except RunTaskError as e:
        raise HTTPException(409, str(e)) from e
    except KeyError as e:
        # errors raised by the task itself arrive as RunTaskError, so a KeyError here means the ID is unknown
        raise HTTPException(404, f'Did not find task with ID {task_id}') from e
@app.get('/tasks/{task_id}')
def get_task(task_id: str) -> TaskInfo:
    """
    Return the information that the session holds about a single task.

    :param task_id: ID of the task
    :return: information about the task
    :raises HTTPException: 404 if there is no task with this ID
    """
    try:
        return session.tasks.get_task_info(task_id)
    except KeyError as e:
        # report an unknown ID as 404 instead of letting the lookup failure become a server error
        raise HTTPException(404, f'Did not find task with ID {task_id}') from e
@app.get('/tasks')
def list_tasks() -> Dict[str, TaskInfo]:
    # report every task tracked by the session, keyed by task ID
    return session.tasks.list_tasks()
\ No newline at end of file
......@@ -8,9 +8,10 @@ from .segment_zproj import segment_zproj_pipeline
from .shared import call_pipeline
from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams
from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord
from ..pipelines.shared import PipelineQueueRecord, PipelineTrace, PipelineParams, PipelineRecord
from ..models import Model, InstanceMaskSegmentationModel
from ..session import session
class RoiSetObjectMapParams(PipelineParams):
......@@ -32,7 +33,7 @@ class RoiSetObjectMapParams(PipelineParams):
)
object_classifier_model_id: Union[str, None] = Field(
None,
description='Object classifier used to classify segmented objectss'
description='Object classifier used to classify segmented objects'
)
patches_channel: int = Field(
description='Channel of input image used in patches sent to object classifier'
......@@ -57,8 +58,9 @@ class RoiSetObjectMapParams(PipelineParams):
class RoiSetToObjectMapRecord(PipelineRecord):
    # record type declared for the roiset_to_obmap endpoint; it adds no fields of its own to PipelineRecord
    ...
@router.put('/roiset_to_obmap/infer')
def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
def roiset_object_map(p: RoiSetObjectMapParams) -> Union[RoiSetToObjectMapRecord, PipelineQueueRecord]:
"""
Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification.
"""
......
......@@ -4,6 +4,7 @@ from time import perf_counter
from typing import List, Union
from fastapi import HTTPException
from numba.scripts.generate_lower_listing import description
from pydantic import BaseModel, Field, root_validator
from ..accessors import GenericImageDataAccessor, InMemoryDataAccessor
......@@ -12,6 +13,7 @@ from ..session import session, AccessorIdError
class PipelineParams(BaseModel):
schedule: bool = Field(False, description='Schedule as a task instead of running immediately')
keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True')
......@@ -44,7 +46,19 @@ class PipelineRecord(BaseModel):
roiset_id: Union[str, None] = None
def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
class PipelineQueueRecord(BaseModel):
    # returned in place of a PipelineRecord when a pipeline call is scheduled as a task instead of being run
    task_id: str = Field(description='ID of the task created for the scheduled pipeline call')
def call_pipeline(func, p: PipelineParams) -> Union[PipelineRecord, PipelineQueueRecord]:
# instead of running right away, schedule pipeline as a task
if p.schedule:
p.schedule = False
task_id = session.tasks.add_task(
lambda x: call_pipeline(func, x),
p
)
return PipelineQueueRecord(task_id=task_id)
# match accessor IDs to loaded accessor objects
accessors_in = {}
......
from collections import OrderedDict
import logging
import os
import uuid
from pathlib import Path, PureWindowsPath
from pydantic import BaseModel
......@@ -32,6 +33,71 @@ class CsvTable(object):
self.empty = False
return True
class TaskCollection(object):
    """
    In-memory collection of deferred tasks, kept in the order in which they were added.

    Every task has an information record (module, params, func_str, status, error, result) and, stored
    separately under the same ID, the callable that is invoked with the task's params when the task is run.
    """

    status_codes = {
        'waiting': 'WAITING',
        'in_progress': 'IN_PROGRESS',
        'finished': 'FINISHED',
        'failed': 'FAILED',
    }

    def __init__(self):
        self._tasks = OrderedDict()  # task_id : information record
        self._handles = OrderedDict()  # task_id : callable

    def add_task(self, func: callable, params: dict) -> str:
        """
        Register a callable and its parameters as a new task in waiting status
        :param func: callable that accepts params as its only argument
        :param params: parameters passed to func when the task is run
        :return: ID of the new task
        """
        task_id = str(uuid.uuid4())
        func_str = str(func)
        self._tasks[task_id] = {
            'module': func.__module__,
            'params': params,
            'func_str': func_str,
            'status': self.status_codes['waiting'],
            'error': None,
            'result': None,
        }
        self._handles[task_id] = func
        logger.info(f'Added task {task_id}: {func_str}')
        return task_id

    def get_task_info(self, task_id: str) -> dict:
        """
        Return the information record of a task; raise KeyError if the ID is not in the collection
        """
        return self._tasks[task_id]

    def list_tasks(self) -> OrderedDict:
        """
        Return the information records of all tasks, keyed by task ID
        """
        return self._tasks

    @property
    def next_waiting(self):
        """
        Return the task_id of the first task that is waiting to be run, or else None
        """
        waiting = self.status_codes['waiting']
        for task_id, info in self._tasks.items():
            # info is already the record of this task, so its status is read directly
            if info['status'] == waiting:
                return task_id
        return None

    def run_task(self, task_id: str):
        """
        Run a task now, record its outcome in its information record, and return its result
        :param task_id: ID of the task to run
        :return: whatever the task's callable returns
        :raises KeyError: if the ID is not in the collection
        :raises RunTaskError: if the callable raises any exception; the task is then marked as failed
        """
        task = self._tasks[task_id]
        f = self._handles[task_id]
        p = task['params']
        try:
            logger.info(f'Started running task {task_id}')
            task['status'] = self.status_codes['in_progress']
            result = f(p)
            logger.info(f'Finished running task {task_id}')
            task['status'] = self.status_codes['finished']
            task['result'] = result
            return result
        except Exception as e:
            task['status'] = self.status_codes['failed']
            task['error'] = str(e)
            logger.error(f'Error running task {task_id}: {str(e)}')
            # keep the original exception attached as the cause
            raise RunTaskError(e) from e
class _Session(object):
"""
Singleton class for a server session that persists data between API calls
......@@ -43,6 +109,7 @@ class _Session(object):
self.models = {} # model_id : model object
self.paths = self.make_paths()
self.accessors = OrderedDict()
self.tasks = TaskCollection()
self.rois = OrderedDict()
self.logfile = self.paths['logs'] / f'session.log'
......@@ -97,6 +164,7 @@ class _Session(object):
idx = len(self.accessors)
accessor_id = f'acc_{idx:06d}'
self.accessors[accessor_id] = {'loaded': True, 'object': acc, **acc.info}
self.log_info(f'Added accessor {accessor_id}')
return accessor_id
def del_accessor(self, accessor_id: str) -> str:
......@@ -114,6 +182,7 @@ class _Session(object):
assert isinstance(v['object'], GenericImageDataAccessor)
v['loaded'] = False
v['object'] = None
self.log_info(f'Deleted accessor {accessor_id}')
return accessor_id
def del_all_accessors(self) -> list[str]:
......@@ -127,6 +196,7 @@ class _Session(object):
v['object'] = None
v['loaded'] = False
res.append(k)
self.log_info(f'Deleted accessor {k}')
return res
......@@ -161,11 +231,12 @@ class _Session(object):
self.del_accessor(acc_id)
return acc
def write_accessor(self, acc_id: str, filename: Union[str, None] = None) -> str:
def write_accessor(self, acc_id: str, filename: Union[str, None] = None, pop: bool = True) -> str:
"""
Write an accessor to file and unload it from the session
Write an accessor to file and optionally unload it from the session
:param acc_id: accessor's ID
:param filename: force use of a specific filename, raise InvalidPathError if this already exists
:param pop: unload accessor from the session if True
:return: name of file
"""
if filename is None:
......@@ -174,7 +245,7 @@ class _Session(object):
fp = self.paths['outbound_images'] / filename
if fp.exists():
raise InvalidPathError(f'Cannot overwrite file {filename} when writing accessor')
acc = self.get_accessor(acc_id, pop=True)
acc = self.get_accessor(acc_id, pop=pop)
old_fp = self.accessors[acc_id]['filepath']
if old_fp != '':
......@@ -187,6 +258,7 @@ class _Session(object):
else:
acc.write(fp)
self.accessors[acc_id]['filepath'] = fp.__str__()
self.log_info(f'Wrote accessor {acc_id} to {fp.__str__()}')
return fp.name
def add_roiset(self, roiset: RoiSet, roiset_id: str = None) -> str:
......@@ -418,4 +490,7 @@ class CouldNotAppendToTable(Error):
pass
class InvalidPathError(Error):
    """Raised when a path cannot be used, e.g. when writing an accessor to a file that already exists."""
class RunTaskError(Error):
    """Raised when the callable of a task raises an exception while the task is being run."""
\ No newline at end of file
......@@ -15,6 +15,7 @@ from urllib3 import Retry
from .fastapi import app
from ..base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from ..base.models import SemanticSegmentationModel, InstanceMaskSegmentationModel
from ..base.pipelines.shared import call_pipeline, PipelineParams, PipelineQueueRecord, PipelineTrace
from ..base.session import session
from ..base.accessors import generate_file_accessor
......@@ -45,17 +46,38 @@ def load_dummy_accessor() -> str:
return session.add_accessor(acc)
@test_router.put('/models/dummy_semantic/load/')
def load_dummy_model() -> dict:
def load_dummy_semantic_model() -> dict:
mid = session.load_model(DummySemanticSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
@test_router.put('/models/dummy_instance/load/')
def load_dummy_model() -> dict:
def load_dummy_instance_model() -> dict:
mid = session.load_model(DummyInstanceMaskSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
class DummyTaskParams(PipelineParams):
    # ID of the accessor that the dummy task operates on
    accessor_id: str

    # when True, the dummy task raises an exception instead of producing a result
    break_me: bool = False
@test_router.put('/tasks/create_dummy_task')
def create_dummy_task(params: DummyTaskParams) -> PipelineQueueRecord:
    """
    Add a task to the session that doubles the input accessor, or that fails on purpose if break_me is set.

    :param params: parameters of the dummy task
    :return: record containing the ID of the new task
    """
    def _dummy_pipeline(accessors, models, **k):
        # models is accepted to match the call signature but is not used by this pipeline
        d = PipelineTrace(accessors.get(''))
        if k.get('break_me'):
            raise Exception('I broke')
        d['res'] = d.last.apply(lambda x: 2 * x)
        return d

    task_id = session.tasks.add_task(
        lambda x: call_pipeline(_dummy_pipeline, x),
        params
    )
    return PipelineQueueRecord(task_id=task_id)
app.include_router(test_router)
......@@ -132,7 +154,13 @@ class TestServerBaseClass(unittest.TestCase):
return self.input_data['name']
def get_accessor(self, accessor_id, filename=None, copy_to=None):
r = self.assertPutSuccess(f'/accessors/write_to_file/{accessor_id}', query={'filename': filename})
r = self.assertPutSuccess(
f'/accessors/write_to_file/{accessor_id}',
query={
'filename': filename,
'pop': False,
},
)
fp_out = Path(self.assertGetSuccess('paths')['outbound_images']) / r
self.assertTrue(fp_out.exists())
if copy_to:
......
......@@ -199,4 +199,36 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
body={}
)
def test_run_dummy_task(self):
    # load a dummy accessor and keep a copy of its data for the final comparison
    input_id = self.assertPutSuccess('/testing/accessors/dummy_accessor/load')
    acc_in = self.get_accessor(input_id)

    created = self.assertPutSuccess(
        '/testing/tasks/create_dummy_task',
        body={'accessor_id': input_id},
    )
    task_id = created['task_id']
    task_url = f'/tasks/{task_id}'

    # a task that was only created is still waiting
    info_before = self.assertGetSuccess(task_url)
    self.assertEqual(info_before['status'], 'WAITING')

    # run the task, then check its status and that its output is twice the input
    self.assertPutSuccess(f'{task_url}/run')
    info_after = self.assertGetSuccess(task_url)
    self.assertEqual(info_after['status'], 'FINISHED')
    acc_out = self.get_accessor(info_after['result']['output_accessor_id'])
    self.assertTrue((acc_out.data == acc_in.data * 2).all())
def test_run_failed_dummy_task(self):
    # load a dummy accessor and write it out, as in the passing variant of this test
    input_id = self.assertPutSuccess('/testing/accessors/dummy_accessor/load')
    self.get_accessor(input_id)

    created = self.assertPutSuccess(
        '/testing/tasks/create_dummy_task',
        body={'accessor_id': input_id, 'break_me': True},
    )
    task_id = created['task_id']
    task_url = f'/tasks/{task_id}'

    # a task that was only created is still waiting
    info_before = self.assertGetSuccess(task_url)
    self.assertEqual(info_before['status'], 'WAITING')

    # running the broken task is answered with 409 and leaves the task marked as failed
    self.assertPutFailure(f'{task_url}/run', 409)
    info_after = self.assertGetSuccess(task_url)
    self.assertEqual(info_after['status'], 'FAILED')
import json
from pathlib import Path
import unittest
......@@ -106,7 +107,6 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertEqual(rois.count, 22)
self.assertEqual(len(trace['ob_id'].unique()[0]), 2)
class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
input_data = data['multichannel_zstack_raw']
......@@ -183,3 +183,109 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
res = self._object_map_workflow(None)
acc_obmap = self.get_accessor(res['output_accessor_id'])
self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))
class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
    # Runs the roiset_to_obmap pipeline through the HTTP API of the test server.
    # NOTE(review): a class with this same name is also declared earlier in this file; this later declaration
    # rebinds the name - confirm that the duplicate is intentional.

    input_data = data['multichannel_zstack_raw']

    def setUp(self) -> None:
        out_dir = output_path / 'roiset'
        out_dir.mkdir(parents=True, exist_ok=True)
        self.where_out = out_dir
        return conf.TestServerBaseClass.setUp(self)

    def test_load_input_accessor(self):
        # returns the response of the read request so that other tests can reuse it as the accessor ID
        server_filename = self.copy_input_file_to_server()
        return self.assertPutSuccess(f'accessors/read_from_file/{server_filename}')

    def test_load_pixel_classifier(self):
        response = self.assertPutSuccess('models/seg/threshold/load/', query={'tr': 0.2})
        model_id = response['model_id']
        self.assertTrue(model_id.startswith('BinaryThresholdSegmentationModel'))
        return model_id

    def test_load_object_classifier(self):
        response = self.assertPutSuccess('models/classify/threshold/load/', body={'tr': 0})
        model_id = response['model_id']
        self.assertTrue(model_id.startswith('IntensityThresholdInstanceMaskSegmentation'))
        return model_id

    def _object_map_workflow(self, ob_classifer_id):
        # loading the input accessor and the pixel classifier happens while the request body is assembled
        request_body = {
            'accessor_id': self.test_load_input_accessor(),
            'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
            'object_classifier_model_id': ob_classifer_id,
            'segmentation': {'channel': 0},
            'patches_channel': 1,
            'roi_params': self._get_roi_params(),
            'export_params': self._get_export_params(),
        }
        res = self.assertPutSuccess('pipelines/roiset_to_obmap/infer', body=request_body)

        # check on automatically written RoiSet
        roiset_info = self.assertGetSuccess(f"rois/{res['roiset_id']}")
        self.assertGreater(roiset_info['count'], 0)
        return res

    def test_workflow_with_object_classifier(self):
        classifier_id = self.test_load_object_classifier()
        res = self._object_map_workflow(classifier_id)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))

        # get object map via RoiSet API
        roiset_id = res['roiset_id']
        obmap_id = self.assertPutSuccess(
            f'rois/obmap/{roiset_id}/{classifier_id}',
            query={'object_classes': True},
        )
        acc_obmap_roiset = self.get_accessor(obmap_id)
        self.assertTrue(np.all(acc_obmap_roiset.data == acc_obmap.data))

        # check serialize RoiSet
        self.assertPutSuccess(f'rois/write/{roiset_id}')
        roiset_info = self.assertGetSuccess(f'rois/{roiset_id}')
        self.assertFalse(roiset_info['loaded'])

    def test_workflow_without_object_classifier(self):
        res = self._object_map_workflow(None)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))
class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
    # Same tests as the parent class, except that the pipeline call is scheduled as a task and then run.

    def _object_map_workflow(self, ob_classifer_id):
        request_body = {
            'schedule': True,
            'accessor_id': self.test_load_input_accessor(),
            'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
            'object_classifier_model_id': ob_classifer_id,
            'segmentation': {'channel': 0},
            'patches_channel': 1,
            'roi_params': self._get_roi_params(),
            'export_params': self._get_export_params(),
        }
        queued = self.assertPutSuccess('pipelines/roiset_to_obmap/infer', body=request_body)

        # check that the task is enqueued
        task_id = queued['task_id']
        task_url = f'tasks/{task_id}'
        self.assertEqual(self.assertGetSuccess(task_url)['status'], 'WAITING')

        # run the task
        res_run = self.assertPutSuccess(f'{task_url}/run')
        self.assertTrue(res_run)
        self.assertEqual(self.assertGetSuccess(task_url)['status'], 'FINISHED')
        return res_run
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment