From 240193386c41bc2402e5695059659ca9e7a0e861 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 15 Aug 2024 10:50:47 +0200
Subject: [PATCH] Added second segmentation pipeline with specific z-stack
 processing

---
 model_server/base/pipelines/segment_zproj.py | 47 ++++++++++++++++++++
 tests/base/test_pipelines.py                 | 26 ++++++++---
 2 files changed, 68 insertions(+), 5 deletions(-)
 create mode 100644 model_server/base/pipelines/segment_zproj.py

diff --git a/model_server/base/pipelines/segment_zproj.py b/model_server/base/pipelines/segment_zproj.py
new file mode 100644
index 00000000..1314fbb4
--- /dev/null
+++ b/model_server/base/pipelines/segment_zproj.py
@@ -0,0 +1,47 @@
+from typing import Dict
+
+from fastapi import APIRouter
+
+from .segment import SegmentParams, SegmentRecord, segment_pipeline
+from .shared import call_pipeline, PipelineTrace
+from ..accessors import GenericImageDataAccessor
+from ..models import Model
+
+from pydantic import Field
+
+router = APIRouter(
+    prefix='/pipelines',
+    tags=['pipelines'],
+)
+
+
+class SegmentZStackParams(SegmentParams):
+    zi: int = Field(None, description='z coordinate to use on input stack; apply MIP if empty')
+
+
+class SegmentZStackRecord(SegmentRecord):
+    pass
+
+
+@router.put('/segment_zproj')
+def segment_zproj(p: SegmentZStackParams) -> SegmentZStackRecord:
+    """
+    Run a semantic segmentation model to compute a binary mask from a projected input zstack
+    """
+    return call_pipeline(segment_zproj(), p)
+
+
+def segment_zproj_pipeline(
+        accessors: Dict[str, GenericImageDataAccessor],
+        models: Dict[str, Model],
+        **k
+) -> PipelineTrace:
+    d = PipelineTrace(accessors.get('accessor'))
+
+    if isinstance(k.get('zi'), int):
+        assert 0 < k['zi'] < d.last.nz
+        d['mip'] = d.last.get_zi(k['zi'])
+    else:
+        d['mip'] = d.last.get_mip()
+    return segment_pipeline({'accessor': d.last}, models, **k)
+
diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py
index 84563127..3c643c0a 100644
--- a/tests/base/test_pipelines.py
+++ b/tests/base/test_pipelines.py
@@ -2,23 +2,24 @@ from pathlib import Path
 import unittest
 
 from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
-from model_server.base.pipelines import segment
+from model_server.base.pipelines import segment, segment_zproj
 
 import model_server.conf.testing as conf
 from tests.base.test_model import DummySemanticSegmentationModel
 
 czifile = conf.meta['image_files']['czifile']
+zstack = conf.meta['image_files']['tifffile']
 output_path = conf.meta['output_path']
 
 
-class TestSegmentationPipeline(unittest.TestCase):
+class TestSegmentationPipelines(unittest.TestCase):
     def setUp(self) -> None:
         self.model = DummySemanticSegmentationModel()
 
-    def test_call_pipeline_function(self):
+    def test_call_segment_pipeline(self):
         acc = generate_file_accessor(czifile['path'])
         trace = segment.segment_pipeline({'accessor': acc}, {'model': self.model}, channel=2, smooth=3)
-        outfp = output_path / 'classify_pixels.tif'
+        outfp = output_path / 'pipelines' / 'segment_binary_mask.tif'
         write_accessor_data_to_file(outfp, trace.last)
 
         import tifffile
@@ -48,4 +49,19 @@ class TestSegmentationPipeline(unittest.TestCase):
             output_path / 'pipelines' / 'segment_interm',
             prefix=czifile['name']
         )
-        self.assertTrue([ofp.stem.split('_')[-1] for ofp in interm_fps] == ['mono', 'inference', 'smooth'])
\ No newline at end of file
+        self.assertTrue([ofp.stem.split('_')[-1] for ofp in interm_fps] == ['mono', 'inference', 'smooth'])
+
+    def test_call_segment_zproj_pipeline(self):
+        acc = generate_file_accessor(zstack['path'])
+
+        trace1 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}, channel=0, smooth=3, zi=4)
+        self.assertEqual(trace1.last.chroma, 1)
+        self.assertEqual(trace1.last.nz, 1)
+
+        trace2 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}, channel=0, smooth=3)
+        self.assertEqual(trace2.last.chroma, 1)
+        self.assertEqual(trace2.last.nz, 1)
+
+        trace3 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model})
+        self.assertEqual(trace3.last.chroma, 1)  # still == 1: model returns a single channel regardless of input
+        self.assertEqual(trace3.last.nz, 1)
-- 
GitLab