From 423afe65fa31d0ac20d356889d7949960f921df9 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 31 Oct 2024 12:05:56 +0100
Subject: [PATCH] Can append stacks for tracking pipelines

---
 model_server/base/pipelines/shared.py | 21 +++++++++++++++++++++
 tests/base/test_pipelines.py          | 16 ++++++++++++++++
 2 files changed, 37 insertions(+)

diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py
index 13d1b825..44ec0c2c 100644
--- a/model_server/base/pipelines/shared.py
+++ b/model_server/base/pipelines/shared.py
@@ -140,6 +140,19 @@ class PipelineTrace(OrderedDict):
         self.last_time = self.tfunc()
         return super().__setitem__(key, value)
 
+    def append(self, tr):
+        new_tr = self.copy()
+        for k, v in tr.items():
+            dt = tr.timer[k]
+            if k == 'input':
+                k = 'appended_input'
+            if not self.allow_overwrite and k in self.keys():
+                raise KeyAlreadyExists(f'Trying to append trace with key {k} that already exists')
+            new_tr.__setitem__(k, v)
+            new_tr.timer.__setitem__(k, dt)
+        new_tr.last_time = self.tfunc()
+        return new_tr
+
     @property
     def times(self):
         """
@@ -147,6 +160,14 @@ class PipelineTrace(OrderedDict):
         """
         return {k: self.timer[k] for k in self.keys()}
 
+    @property
+    def first(self):
+        """
+        Return first item
+        :return:
+        """
+        return list(self.values())[0]
+
     @property
     def last(self):
         """
diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py
index efc0349a..ee7d6507 100644
--- a/tests/base/test_pipelines.py
+++ b/tests/base/test_pipelines.py
@@ -1,7 +1,10 @@
 import unittest
 
+import numpy as np
+
 from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
 from model_server.base.pipelines import router, segment, segment_zproj
+from model_server.base.pipelines.shared import PipelineTrace
 
 import model_server.conf.testing as conf
 from model_server.conf.testing import DummySemanticSegmentationModel
@@ -64,3 +67,16 @@ class TestSegmentationPipelines(unittest.TestCase):
         trace3 = segment_zproj.segment_zproj_pipeline({'': acc}, {'': self.model})
         self.assertEqual(trace3.last.chroma, 1)  # still == 1: model returns a single channel regardless of input
         self.assertEqual(trace3.last.nz, 1)
+
+    def test_append_traces(self):
+        acc = generate_file_accessor(zstack['path'])
+        trace1 = PipelineTrace(acc)
+        trace1['halve'] = trace1.last.apply(lambda x: 0.5 * x)
+
+        trace2 = PipelineTrace(trace1.last)
+        trace2['double'] = trace2.last.apply(lambda x: 2 * x)
+        trace3 = trace1.append(trace2)
+
+        self.assertEqual(len(trace3), len(trace1) + len(trace2))
+        self.assertEqual(trace3['halve'], trace3['appended_input'])
+        self.assertTrue(np.all(trace3['input'].data == trace3['double'].data))
\ No newline at end of file
-- 
GitLab