Skip to content
Snippets Groups Projects
Commit b0b8499d authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Intermediate data can now be kept in session context

parent fb7e1245
No related branches found
No related tags found
No related merge requests found
from typing import List, Union
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
...@@ -7,6 +9,7 @@ from ..session import session, AccessorIdError ...@@ -7,6 +9,7 @@ from ..session import session, AccessorIdError
class SingleModelPipelineParams(BaseModel): class SingleModelPipelineParams(BaseModel):
accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input') accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)') model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
@validator('model_id') @validator('model_id')
def models_are_loaded(cls, model_id): def models_are_loaded(cls, model_id):
...@@ -27,6 +30,7 @@ class SingleModelPipelineParams(BaseModel): ...@@ -27,6 +30,7 @@ class SingleModelPipelineParams(BaseModel):
class PipelineRecord(BaseModel): class PipelineRecord(BaseModel):
output_accessor_id: str output_accessor_id: str
interm_accessor_ids: Union[List[str], None]
model_id: str model_id: str
success: bool success: bool
timer: dict timer: dict
...@@ -13,8 +13,28 @@ def call_pipeline(func, p: SingleModelPipelineParams): ...@@ -13,8 +13,28 @@ def call_pipeline(func, p: SingleModelPipelineParams):
session.log_info(f'Completed {func.__name__} on {p.accessor_id}.') session.log_info(f'Completed {func.__name__} on {p.accessor_id}.')
if p.keep_interm:
interm_ids = []
acc_interm = steps.accessors(skip_first=True, skip_last=True).items()
for i, item in enumerate(acc_interm):
stk, acc = item
interm_ids.append(
session.add_accessor(
acc,
accessor_id=f'{p.accessor_id}_{func.__name__}_step{(i + 1):02d}_{stk}'
)
)
else:
interm_ids = None
result_id = session.add_accessor(
steps.last,
accessor_id=f'{p.accessor_id}_{func.__name__}_result'
)
return PipelineRecord( return PipelineRecord(
output_accessor_id=session.add_accessor(steps.last), output_accessor_id=result_id,
interm_accessor_ids=interm_ids,
model_id=p.model_id, model_id=p.model_id,
success=True, success=True,
timer=steps.times timer=steps.times
......
...@@ -151,6 +151,7 @@ class TestApiFromAutomatedClient(TestServerTestCase): ...@@ -151,6 +151,7 @@ class TestApiFromAutomatedClient(TestServerTestCase):
'accessor_id': in_acc_id, 'accessor_id': in_acc_id,
'model_id': model_id, 'model_id': model_id,
'channel': 2, 'channel': 2,
'keep_interm': True,
}, },
) )
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
...@@ -166,6 +167,10 @@ class TestApiFromAutomatedClient(TestServerTestCase): ...@@ -166,6 +167,10 @@ class TestApiFromAutomatedClient(TestServerTestCase):
acc_out = generate_file_accessor(fp_out) acc_out = generate_file_accessor(fp_out)
self.assertEqual(acc_out.shape_dict['C'], 1) self.assertEqual(acc_out.shape_dict['C'], 1)
# validate intermediate data
resp_list = self._get(f'accessors').json()
self.assertEqual(len([k for k in resp_list.keys() if '_step' in k]), 2)
def test_restarting_session_clears_loaded_models(self): def test_restarting_session_clears_loaded_models(self):
resp_load = self._put(f'testing/models/dummy_semantic/load',) resp_load = self._put(f'testing/models/dummy_semantic/load',)
self.assertEqual(resp_load.status_code, 200, resp_load.json()) self.assertEqual(resp_load.status_code, 200, resp_load.json())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment