From 2464806cd18ab4be5187d019d7568cb690b951d2 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 8 Aug 2024 11:44:22 +0200
Subject: [PATCH] Implemented and tested method to write intermediate pipeline
 products to file

---
 model_server/base/util.py    | 49 +++++++++++++++++++++++++++++++++---
 tests/base/test_pipelines.py |  8 +++++-
 2 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/model_server/base/util.py b/model_server/base/util.py
index 97e01ab3..f4299979 100644
--- a/model_server/base/util.py
+++ b/model_server/base/util.py
@@ -171,6 +171,12 @@ class PipelineTrace(OrderedDict):
     tfunc = perf_counter
 
     def __init__(self, enforce_accessors=True, allow_overwrite=False):
+        """
+        A container and timer for data at each stage of a pipeline.
+
+        :param enforce_accessors: if True, only allow accessors to be appended as items
+        :param allow_overwrite: if True, allow an item to be overwritten
+        """
         self.enforce_accessors = enforce_accessors
         self.allow_overwrite = allow_overwrite
         self.last_time = self.tfunc()
@@ -187,18 +193,53 @@ class PipelineTrace(OrderedDict):
         self.last_time = self.tfunc()
         return super().__setitem__(key, value)
 
-    @property
-    def accessors(self):
-        return self.items()
-
     @property
     def times(self):
+        """
+        Return an ordered dictionary of incremental times for each item that is appended
+        """
         return {k: self.timer[k] for k in self.keys()}
 
     @property
     def last(self):
+        """
+        Return most recently appended item
+        :return:
+        """
         return list(self.values())[-1]
 
+    def write_interm(
+            self,
+            where: Path,
+            prefix: str = 'interm',
+            skip_first=True,
+            skip_last=True,
+            debug=False
+    ) -> List[Path]:
+        """
+        Write accessor data to TIF files under specified path
+        :param where: directory in which to write image files
+        :param prefix: (optional) file prefix
+        :param skip_first: if True, do not write first item in trace
+        :param skip_last: if False, do not write last item in trace
+        :param debug: if True, report destination filepaths but do not write files
+        :return: list of destination filepaths
+        """
+        paths = []
+        for i, item in enumerate(self.items()):
+            k, v = item
+            if not isinstance(v, GenericImageDataAccessor):
+                continue
+            if skip_first and k == list(self.keys())[0]:
+                continue
+            if skip_last and k == list(self.keys())[-1]:
+                continue
+            fp = where / f'{prefix}_{i:02d}_{k}.tif'
+            paths.append(fp)
+            if not debug:
+                v.write(fp)
+        return paths
+
 
 class Error(Exception):
     pass
diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py
index 774b789f..580c5e3d 100644
--- a/tests/base/test_pipelines.py
+++ b/tests/base/test_pipelines.py
@@ -39,4 +39,10 @@ class TestGetSessionObject(unittest.TestCase):
             img[0, 0],
             0,
             'First pixel is not black as expected'
-        )
\ No newline at end of file
+        )
+
+        interm_fps = trace.write_interm(
+            output_path / 'pipelines' / 'segment_interm',
+            prefix=czifile['name']
+        )
+        self.assertTrue([ofp.stem.split('_')[-1] for ofp in interm_fps] == ['mono', 'inference', 'smooth'])
\ No newline at end of file
-- 
GitLab