From b0b8499d092c1935da0b2ae26fb32c913710fd1e Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 9 Aug 2024 15:40:00 +0200
Subject: [PATCH] Intermediate data can now be kept in session context

---
 model_server/base/pipelines/params.py |  4 ++++
 model_server/base/pipelines/util.py   | 22 +++++++++++++++++++++-
 tests/base/test_api.py                |  5 +++++
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/model_server/base/pipelines/params.py b/model_server/base/pipelines/params.py
index e8fe92fb..9f1bda03 100644
--- a/model_server/base/pipelines/params.py
+++ b/model_server/base/pipelines/params.py
@@ -1,3 +1,5 @@
+from typing import List, Union
+
 from fastapi import HTTPException
 from pydantic import BaseModel, Field, validator
 
@@ -7,6 +9,7 @@ from ..session import session, AccessorIdError
 class SingleModelPipelineParams(BaseModel):
     accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
     model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
+    keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
 
     @validator('model_id')
     def models_are_loaded(cls, model_id):
@@ -27,6 +30,7 @@ class SingleModelPipelineParams(BaseModel):
 
 class PipelineRecord(BaseModel):
     output_accessor_id: str
+    interm_accessor_ids: Union[List[str], None]  # session IDs of kept intermediate accessors; None when keep_interm is False
     model_id: str
     success: bool
     timer: dict
diff --git a/model_server/base/pipelines/util.py b/model_server/base/pipelines/util.py
index 64598a43..a4a37feb 100644
--- a/model_server/base/pipelines/util.py
+++ b/model_server/base/pipelines/util.py
@@ -13,8 +13,28 @@ def call_pipeline(func, p: SingleModelPipelineParams):
 
     session.log_info(f'Completed {func.__name__} on {p.accessor_id}.')
 
+    if p.keep_interm:  # caller asked to keep intermediate images: register each one in the session
+        interm_ids = []
+        acc_interm = steps.accessors(skip_first=True, skip_last=True).items()  # first and last steps are excluded here
+        for i, item in enumerate(acc_interm):
+            stk, acc = item  # stk: key from .items() (presumably the step's name); acc: accessor registered below
+            interm_ids.append(
+                session.add_accessor(  # the value returned by add_accessor is collected as this step's ID
+                    acc,
+                    accessor_id=f'{p.accessor_id}_{func.__name__}_step{(i + 1):02d}_{stk}'
+                )  # step numbers are 1-based and zero-padded to two digits
+            )
+    else:
+        interm_ids = None  # None (not an empty list) signals that intermediates were not kept
+
+    result_id = session.add_accessor(  # the final step's accessor is registered regardless of keep_interm
+        steps.last,
+        accessor_id=f'{p.accessor_id}_{func.__name__}_result'
+    )  # NOTE(review): IDs are deterministic per (input accessor, func) - confirm add_accessor tolerates a repeat run
+
     return PipelineRecord(
-        output_accessor_id=session.add_accessor(steps.last),
+        output_accessor_id=result_id,
+        interm_accessor_ids=interm_ids,
         model_id=p.model_id,
         success=True,
         timer=steps.times
diff --git a/tests/base/test_api.py b/tests/base/test_api.py
index 3cb0ebf1..35476d2b 100644
--- a/tests/base/test_api.py
+++ b/tests/base/test_api.py
@@ -151,6 +151,7 @@ class TestApiFromAutomatedClient(TestServerTestCase):
                 'accessor_id': in_acc_id,
                 'model_id': model_id,
                 'channel': 2,
+                'keep_interm': True,
             },
         )
         self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
@@ -166,6 +167,10 @@ class TestApiFromAutomatedClient(TestServerTestCase):
         acc_out = generate_file_accessor(fp_out)
         self.assertEqual(acc_out.shape_dict['C'], 1)
 
+        # validate intermediate data: with keep_interm=True, exactly two '_step' accessors should be listed in the session
+        resp_list = self._get(f'accessors').json()
+        self.assertEqual(len([k for k in resp_list.keys() if '_step' in k]), 2)
+
     def test_restarting_session_clears_loaded_models(self):
         resp_load = self._put(f'testing/models/dummy_semantic/load',)
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
-- 
GitLab