From b0b8499d092c1935da0b2ae26fb32c913710fd1e Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 9 Aug 2024 15:40:00 +0200 Subject: [PATCH] Intermediate data can now be kept in session context --- model_server/base/pipelines/params.py | 4 ++++ model_server/base/pipelines/util.py | 22 +++++++++++++++++++++- tests/base/test_api.py | 5 +++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/model_server/base/pipelines/params.py b/model_server/base/pipelines/params.py index e8fe92fb..9f1bda03 100644 --- a/model_server/base/pipelines/params.py +++ b/model_server/base/pipelines/params.py @@ -1,3 +1,5 @@ +from typing import List, Union + from fastapi import HTTPException from pydantic import BaseModel, Field, validator @@ -7,6 +9,7 @@ from ..session import session, AccessorIdError class SingleModelPipelineParams(BaseModel): accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input') model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)') + keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session') @validator('model_id') def models_are_loaded(cls, model_id): @@ -27,6 +30,7 @@ class SingleModelPipelineParams(BaseModel): class PipelineRecord(BaseModel): output_accessor_id: str + interm_accessor_ids: Union[List[str], None] model_id: str success: bool timer: dict diff --git a/model_server/base/pipelines/util.py b/model_server/base/pipelines/util.py index 64598a43..a4a37feb 100644 --- a/model_server/base/pipelines/util.py +++ b/model_server/base/pipelines/util.py @@ -13,8 +13,28 @@ def call_pipeline(func, p: SingleModelPipelineParams): session.log_info(f'Completed {func.__name__} on {p.accessor_id}.') + if p.keep_interm: + interm_ids = [] + acc_interm = steps.accessors(skip_first=True, skip_last=True).items() + for i, item in enumerate(acc_interm): + stk, acc = item + interm_ids.append( + session.add_accessor( + acc, + accessor_id=f'{p.accessor_id}_{func.__name__}_step{(i + 1):02d}_{stk}' + ) + ) + else: + interm_ids = None + + result_id = session.add_accessor( + steps.last, + accessor_id=f'{p.accessor_id}_{func.__name__}_result' + ) + return PipelineRecord( - output_accessor_id=session.add_accessor(steps.last), + output_accessor_id=result_id, + interm_accessor_ids=interm_ids, model_id=p.model_id, success=True, timer=steps.times diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 3cb0ebf1..35476d2b 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -151,6 +151,7 @@ class TestApiFromAutomatedClient(TestServerTestCase): 'accessor_id': in_acc_id, 'model_id': model_id, 'channel': 2, + 'keep_interm': True, }, ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) @@ -166,6 +167,10 @@ class TestApiFromAutomatedClient(TestServerTestCase): acc_out = generate_file_accessor(fp_out) self.assertEqual(acc_out.shape_dict['C'], 1) + # validate intermediate data + resp_list = self._get(f'accessors').json() + self.assertEqual(len([k for k in resp_list.keys() if '_step' in k]), 2) + def test_restarting_session_clears_loaded_models(self): resp_load = self._put(f'testing/models/dummy_semantic/load',) self.assertEqual(resp_load.status_code, 200, resp_load.json()) -- GitLab