From 6735a2f784f9baf0f12c61e0711aa4fdab9e6f2a Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 12 Aug 2024 10:40:36 +0200
Subject: [PATCH] Parameters auto-validates anything ending in accessor_id and
 model_id; call_pipeline() maps these to objects and sends to pipeline
 function as dicts of respective objects

---
 model_server/base/pipelines/params.py  | 38 ++++++++++++++------------
 model_server/base/pipelines/segment.py | 22 +++++++++++----
 model_server/base/pipelines/util.py    | 19 +++++++++----
 tests/base/test_api.py                 | 31 +++++++++++++++++----
 tests/base/test_pipelines.py           |  2 ++
 5 files changed, 78 insertions(+), 34 deletions(-)

diff --git a/model_server/base/pipelines/params.py b/model_server/base/pipelines/params.py
index 9f1bda03..d07d68fe 100644
--- a/model_server/base/pipelines/params.py
+++ b/model_server/base/pipelines/params.py
@@ -1,31 +1,33 @@
 from typing import List, Union
 
 from fastapi import HTTPException
-from pydantic import BaseModel, Field, validator
-
+from pydantic import BaseModel, Field, root_validator
 from ..session import session, AccessorIdError
 
 
-class SingleModelPipelineParams(BaseModel):
+class PipelineParams(BaseModel):
     accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
     model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
     keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
 
-    @validator('model_id')
-    def models_are_loaded(cls, model_id):
-        if model_id not in session.describe_loaded_models().keys():
-            raise HTTPException(status_code=409, detail=f'Model with ID {model_id} has not been loaded')
-        return model_id
-
-    @validator('accessor_id')
-    def accessors_are_loaded(cls, accessor_id):
-        try:
-            info = session.get_accessor_info(accessor_id)
-        except AccessorIdError as e:
-            raise HTTPException(status_code=409, detail=str(e))
-        if not info['loaded']:
-            raise HTTPException(status_code=409, detail=f'Accessor with ID {accessor_id} is not loaded')
-        return accessor_id
+    @root_validator(pre=False)
+    def models_are_loaded(cls, dd):
+        for k, v in dd.items():
+            if k.endswith('model_id') and v not in session.describe_loaded_models().keys():
+                raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded')
+        return dd
+
+    @root_validator(pre=False)
+    def accessors_are_loaded(cls, dd):
+        for k, v in dd.items():
+            if k.endswith('accessor_id'):
+                try:
+                    info = session.get_accessor_info(v)
+                except AccessorIdError as e:
+                    raise HTTPException(status_code=409, detail=str(e))
+                if not info['loaded']:
+                    raise HTTPException(status_code=409, detail=f'Accessor with {k} = {v} has not been loaded')
+        return dd
 
 
 class PipelineRecord(BaseModel):
diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py
index 3dfa86d3..2e6fde6c 100644
--- a/model_server/base/pipelines/segment.py
+++ b/model_server/base/pipelines/segment.py
@@ -1,9 +1,11 @@
+from typing import Dict
+
 from fastapi import APIRouter
 
 from .util import call_pipeline
 from ..accessors import GenericImageDataAccessor
-from ..models import SemanticSegmentationModel
-from .params import SingleModelPipelineParams, PipelineRecord
+from ..models import Model
+from .params import PipelineParams, PipelineRecord
 from ..process import smooth
 from ..util import PipelineTrace
 
@@ -14,7 +16,9 @@ router = APIRouter(
     tags=['pipelines'],
 )
 
-class SegmentParams(SingleModelPipelineParams):
+class SegmentParams(PipelineParams):
+    accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
+    model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
     channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
 
 
@@ -25,11 +29,17 @@ def segment(p: SegmentParams) -> PipelineRecord:
     """
     return call_pipeline(segment_pipeline, p)
 
-def segment_pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
+def segment_pipeline(
+        accessors: Dict[str, GenericImageDataAccessor],
+        models: Dict[str, Model],
+        **k
+) -> PipelineTrace:
+    input_accessor = accessors.get('accessor')
+    model = models.get('model')
     d = PipelineTrace()
-    d['input'] = acc_in
+    d['input'] = input_accessor
     if ch := k.get('channel') is not None:
-        d['mono'] = acc_in.get_mono(ch)
+        d['mono'] = input_accessor.get_mono(ch)
     d['inference'] = model.label_pixel_class(d.last)
     if sm := k.get('smooth') is not None:
         d['smooth'] = d.last.apply(lambda x: smooth(x, sm))
diff --git a/model_server/base/pipelines/util.py b/model_server/base/pipelines/util.py
index a4a37feb..50459da4 100644
--- a/model_server/base/pipelines/util.py
+++ b/model_server/base/pipelines/util.py
@@ -1,13 +1,22 @@
-from .params import SingleModelPipelineParams, PipelineRecord
+from .params import PipelineParams, PipelineRecord
 from ..session import session
 
 
-def call_pipeline(func, p: SingleModelPipelineParams):
-    acc_in = session.get_accessor(p.accessor_id, pop=True)
+def call_pipeline(func, p: PipelineParams):
+
+    accessors_in = {}
+    for k, v in p.dict().items():
+        if k.endswith('accessor_id'):
+            accessors_in[k.split('_id')[0]] = session.get_accessor(v, pop=True)
+
+    models = {}
+    for k, v in p.dict().items():
+        if k.endswith('model_id'):
+            models[k.split('_id')[0]] = session.models[v]['object']
 
     steps = func(
-        acc_in,
-        session.models[p.model_id]['object'],
+        accessors_in,
+        models,
         **p.dict(),
     )
 
diff --git a/tests/base/test_api.py b/tests/base/test_api.py
index 35476d2b..bf5bc87e 100644
--- a/tests/base/test_api.py
+++ b/tests/base/test_api.py
@@ -127,12 +127,33 @@ class TestApiFromAutomatedClient(TestServerTestCase):
 
 
     def test_pipeline_errors_when_ids_not_found(self):
-        model_id = 'not_a_real_model'
-        resp = self._put(
-            f'pipelines/segment',
-           body={'model_id': model_id, 'accessor_id': 'fake'}
+        self.copy_input_file_to_server()
+        model_id = self._put(f'testing/models/dummy_semantic/load').json()['model_id']
+        in_acc_id = self._put(
+            f'accessors/read_from_file',
+            query={
+                'filename': czifile['name'],
+            },
+        ).json()
+
+        # respond with 409 for invalid accessor_id
+        self.assertEqual(
+            self._put(
+                f'pipelines/segment',
+                body={'model_id': model_id, 'accessor_id': 'fake'}
+            ).status_code,
+            409
         )
-        self.assertEqual(resp.status_code, 409, resp.content.decode())
+
+        # respond with 409 for invalid model_id
+        self.assertEqual(
+            self._put(
+                f'pipelines/segment',
+                body={'model_id': 'fake', 'accessor_id': in_acc_id}
+            ).status_code,
+            409
+        )
+
 
     def test_i2i_dummy_inference_by_api(self):
         self.copy_input_file_to_server()
diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py
index f900f941..bbcd2305 100644
--- a/tests/base/test_pipelines.py
+++ b/tests/base/test_pipelines.py
@@ -3,12 +3,14 @@ import unittest
 
 from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
 from model_server.base.pipelines import segment
+
 import model_server.conf.testing as conf
 from tests.base.test_model import DummySemanticSegmentationModel
 
 czifile = conf.meta['image_files']['czifile']
 output_path = conf.meta['output_path']
 
+
 class TestSegmentationPipeline(unittest.TestCase):
     def setUp(self) -> None:
         self.model = DummySemanticSegmentationModel()
-- 
GitLab