From 240193386c41bc2402e5695059659ca9e7a0e861 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 15 Aug 2024 10:50:47 +0200 Subject: [PATCH] Added second segmentation pipeline with specific z-stack processing --- model_server/base/pipelines/segment_zproj.py | 47 ++++++++++++++++++++ tests/base/test_pipelines.py | 26 ++++++++--- 2 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 model_server/base/pipelines/segment_zproj.py diff --git a/model_server/base/pipelines/segment_zproj.py b/model_server/base/pipelines/segment_zproj.py new file mode 100644 index 00000000..1314fbb4 --- /dev/null +++ b/model_server/base/pipelines/segment_zproj.py @@ -0,0 +1,47 @@ +from typing import Dict + +from fastapi import APIRouter + +from .segment import SegmentParams, SegmentRecord, segment_pipeline +from .shared import call_pipeline, PipelineTrace +from ..accessors import GenericImageDataAccessor +from ..models import Model + +from pydantic import Field + +router = APIRouter( + prefix='/pipelines', + tags=['pipelines'], +) + + +class SegmentZStackParams(SegmentParams): + zi: int = Field(None, description='z coordinate to use on input stack; apply MIP if empty') + + +class SegmentZStackRecord(SegmentRecord): + pass + + +@router.put('/segment_zproj') +def segment_zproj(p: SegmentZStackParams) -> SegmentZStackRecord: + """ + Run a semantic segmentation model to compute a binary mask from a projected input zstack + """ + return call_pipeline(segment_zproj(), p) + + +def segment_zproj_pipeline( + accessors: Dict[str, GenericImageDataAccessor], + models: Dict[str, Model], + **k +) -> PipelineTrace: + d = PipelineTrace(accessors.get('accessor')) + + if isinstance(k.get('zi'), int): + assert 0 < k['zi'] < d.last.nz + d['mip'] = d.last.get_zi(k['zi']) + else: + d['mip'] = d.last.get_mip() + return segment_pipeline({'accessor': d.last}, models, **k) + diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py index 84563127..3c643c0a 100644 --- a/tests/base/test_pipelines.py +++ b/tests/base/test_pipelines.py @@ -2,23 +2,24 @@ from pathlib import Path import unittest from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file -from model_server.base.pipelines import segment +from model_server.base.pipelines import segment, segment_zproj import model_server.conf.testing as conf from tests.base.test_model import DummySemanticSegmentationModel czifile = conf.meta['image_files']['czifile'] +zstack = conf.meta['image_files']['tifffile'] output_path = conf.meta['output_path'] -class TestSegmentationPipeline(unittest.TestCase): +class TestSegmentationPipelines(unittest.TestCase): def setUp(self) -> None: self.model = DummySemanticSegmentationModel() - def test_call_pipeline_function(self): + def test_call_segment_pipeline(self): acc = generate_file_accessor(czifile['path']) trace = segment.segment_pipeline({'accessor': acc}, {'model': self.model}, channel=2, smooth=3) - outfp = output_path / 'classify_pixels.tif' + outfp = output_path / 'pipelines' / 'segment_binary_mask.tif' write_accessor_data_to_file(outfp, trace.last) import tifffile @@ -48,4 +49,19 @@ class TestSegmentationPipeline(unittest.TestCase): output_path / 'pipelines' / 'segment_interm', prefix=czifile['name'] ) - self.assertTrue([ofp.stem.split('_')[-1] for ofp in interm_fps] == ['mono', 'inference', 'smooth']) \ No newline at end of file + self.assertTrue([ofp.stem.split('_')[-1] for ofp in interm_fps] == ['mono', 'inference', 'smooth']) + + def test_call_segment_zproj_pipeline(self): + acc = generate_file_accessor(zstack['path']) + + trace1 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}, channel=0, smooth=3, zi=4) + self.assertEqual(trace1.last.chroma, 1) + self.assertEqual(trace1.last.nz, 1) + + trace2 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}, channel=0, smooth=3) + self.assertEqual(trace2.last.chroma, 1) + self.assertEqual(trace2.last.nz, 1) + + trace3 = segment_zproj.segment_zproj_pipeline({'accessor': acc}, {'model': self.model}) + self.assertEqual(trace3.last.chroma, 1) # still == 1: model returns a single channel regardless of input + self.assertEqual(trace3.last.nz, 1) -- GitLab