Skip to content
Snippets Groups Projects
Commit 5c797626 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test covers task API

parent 1bbc9f07
No related branches found
No related tags found
2 merge requests!102Merge staging as release,!72Pipeline task management
This commit is part of merge request !72. Comments created here will be created in the context of that merge request.
......@@ -15,6 +15,7 @@ from urllib3 import Retry
from .fastapi import app
from ..base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from ..base.models import SemanticSegmentationModel, InstanceMaskSegmentationModel
from ..base.pipelines.shared import call_pipeline, PipelineParams, PipelineQueueRecord, PipelineTrace
from ..base.session import session
from ..base.accessors import generate_file_accessor
......@@ -45,17 +46,35 @@ def load_dummy_accessor() -> str:
return session.add_accessor(acc)
@test_router.put('/models/dummy_semantic/load/')
def load_dummy_model() -> dict:
def load_dummy_semantic_model() -> dict:
mid = session.load_model(DummySemanticSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
@test_router.put('/models/dummy_instance/load/')
def load_dummy_model() -> dict:
def load_dummy_instance_model() -> dict:
mid = session.load_model(DummyInstanceMaskSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
class DummyTaskParams(PipelineParams):
accessor_id: str
@test_router.put('/tasks/create_dummy_task')
def create_dummy_task(params: DummyTaskParams) -> PipelineQueueRecord:
def _dummy_pipeline(accessors, models, **k):
d = PipelineTrace(accessors.get(''))
model = models.get('')
d['res'] = d.last.apply(lambda x: 2 * x)
return d
task_id = session.queue.add_task(
lambda x: call_pipeline(_dummy_pipeline, x),
params
)
return PipelineQueueRecord(task_id=task_id)
app.include_router(test_router)
......@@ -132,7 +151,13 @@ class TestServerBaseClass(unittest.TestCase):
return self.input_data['name']
def get_accessor(self, accessor_id, filename=None, copy_to=None):
r = self.assertPutSuccess(f'/accessors/write_to_file/{accessor_id}', query={'filename': filename})
r = self.assertPutSuccess(
f'/accessors/write_to_file/{accessor_id}',
query={
'filename': filename,
'pop': False,
},
)
fp_out = Path(self.assertGetSuccess('paths')['outbound_images']) / r
self.assertTrue(fp_out.exists())
if copy_to:
......
......@@ -197,4 +197,22 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertPutSuccess(
'/models/classify/threshold/load',
body={}
)
\ No newline at end of file
)
def test_run_dummy_task(self):
acc_id = self.assertPutSuccess('/testing/accessors/dummy_accessor/load')
acc_in = self.get_accessor(acc_id)
task_id = self.assertPutSuccess(
'/testing/tasks/create_dummy_task',
body={'accessor_id': acc_id}
)['task_id']
task_info1 = self.assertGetSuccess(f'/tasks/{task_id}')
self.assertEqual(task_info1['status'], 'WAITING')
rec = self.assertPutSuccess(f'/tasks/{task_id}/run')
task_info2 = self.assertGetSuccess(f'/tasks/{task_id}')
self.assertEqual(task_info2['status'], 'FINISHED')
acc_out = self.get_accessor(task_info2['result']['output_accessor_id'])
self.assertTrue((acc_out.data == acc_in.data * 2).all())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment