From dce34bcf9dbb5cf0878423b6172ccdf690ca0a37 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 15 Aug 2024 14:53:17 +0200
Subject: [PATCH] Implemented hierarchical routing

---
 model_server/base/api.py                     | 13 +++----------
 model_server/base/pipelines/__init__.py      |  1 +
 model_server/base/pipelines/roiset_obmap.py  | 10 +++-------
 model_server/base/pipelines/router.py        |  9 +++++++++
 model_server/base/pipelines/segment.py       |  8 +-------
 model_server/base/pipelines/segment_zproj.py |  9 +--------
 tests/base/test_api.py                       | 17 ++++++++---------
 tests/test_ilastik/test_ilastik.py           | 13 ++++++++++---
 8 files changed, 36 insertions(+), 44 deletions(-)
 create mode 100644 model_server/base/pipelines/router.py

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 1d51cc04..db484d0f 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -4,18 +4,11 @@ from fastapi import FastAPI, HTTPException
 from .accessors import generate_file_accessor
 from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError
 
-app = FastAPI(debug=True)
-
-# TODO: automate all this
 
-import model_server.extensions.ilastik.router
-app.include_router(model_server.extensions.ilastik.router.router)
-
-import model_server.base.pipelines.segment
-app.include_router(model_server.base.pipelines.segment.router)
+app = FastAPI(debug=True)
 
-import model_server.base.pipelines.roiset_obmap
-app.include_router(model_server.base.pipelines.roiset_obmap.router)
+from .pipelines.router import router
+app.include_router(router)
 
 
 @app.on_event("startup")
diff --git a/model_server/base/pipelines/__init__.py b/model_server/base/pipelines/__init__.py
index e69de29b..537a687d 100644
--- a/model_server/base/pipelines/__init__.py
+++ b/model_server/base/pipelines/__init__.py
@@ -0,0 +1 @@
+__all__ = ['roiset_obmap', 'segment', 'segment_zproj']
\ No newline at end of file
diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py
index 24f88601..c4991aa3 100644
--- a/model_server/base/pipelines/roiset_obmap.py
+++ b/model_server/base/pipelines/roiset_obmap.py
@@ -1,11 +1,11 @@
 from typing import Dict, Union
 
-from fastapi import APIRouter
 from pydantic import BaseModel, Field
 
 from ..accessors import GenericImageDataAccessor
-from ..pipelines.segment_zproj import segment_zproj_pipeline
-from ..pipelines.shared import call_pipeline
+from .router import router
+from .segment_zproj import segment_zproj_pipeline
+from .shared import call_pipeline
 from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams
 from ..session import session
 
@@ -13,10 +13,6 @@ from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord
 
 from ..models import Model, InstanceSegmentationModel
 
-router = APIRouter(
-    prefix='/pipelines',
-    tags=['pipelines'],
-)
 
 class RoiSetObjectMapParams(PipelineParams):
     class _SegmentationParams(BaseModel):
diff --git a/model_server/base/pipelines/router.py b/model_server/base/pipelines/router.py
new file mode 100644
index 00000000..3f77fd1d
--- /dev/null
+++ b/model_server/base/pipelines/router.py
@@ -0,0 +1,9 @@
+from fastapi import APIRouter
+
+router = APIRouter(
+    prefix='/pipelines',
+    tags=['pipelines'],
+)
+
+# this completes routing in individual pipeline modules
+from . import *
\ No newline at end of file
diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py
index 2e34c209..fac18351 100644
--- a/model_server/base/pipelines/segment.py
+++ b/model_server/base/pipelines/segment.py
@@ -1,19 +1,13 @@
 from typing import Dict
 
-from fastapi import APIRouter
-
 from .shared import call_pipeline, IncompatibleModelsError, PipelineTrace, PipelineParams, PipelineRecord
 from ..accessors import GenericImageDataAccessor
 from ..models import Model, SemanticSegmentationModel
 from ..process import smooth
+from .router import router
 
 from pydantic import Field
 
-router = APIRouter(
-    prefix='/pipelines',
-    tags=['pipelines'],
-)
-
 
 class SegmentParams(PipelineParams):
     accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
diff --git a/model_server/base/pipelines/segment_zproj.py b/model_server/base/pipelines/segment_zproj.py
index 1314fbb4..5f5ed9ba 100644
--- a/model_server/base/pipelines/segment_zproj.py
+++ b/model_server/base/pipelines/segment_zproj.py
@@ -1,7 +1,6 @@
 from typing import Dict
 
-from fastapi import APIRouter
-
+from .router import router
 from .segment import SegmentParams, SegmentRecord, segment_pipeline
 from .shared import call_pipeline, PipelineTrace
 from ..accessors import GenericImageDataAccessor
@@ -9,12 +8,6 @@ from ..models import Model
 
 from pydantic import Field
 
-router = APIRouter(
-    prefix='/pipelines',
-    tags=['pipelines'],
-)
-
-
 class SegmentZStackParams(SegmentParams):
     zi: int = Field(None, description='z coordinate to use on input stack; apply MIP if empty')
 
diff --git a/tests/base/test_api.py b/tests/base/test_api.py
index 7f692af2..a6aad378 100644
--- a/tests/base/test_api.py
+++ b/tests/base/test_api.py
@@ -1,11 +1,11 @@
 from pathlib import Path
 
-from fastapi import APIRouter
+from fastapi import APIRouter, FastAPI
 import numpy as np
 from pydantic import BaseModel
 
 import model_server.conf.testing as conf
-from model_server.base.accessors import InMemoryDataAccessor, generate_file_accessor
+from model_server.base.accessors import InMemoryDataAccessor
 from model_server.base.api import app
 from model_server.base.session import session
 from tests.base.test_model import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
@@ -15,18 +15,17 @@ czifile = conf.meta['image_files']['czifile']
 """
 Configure additional endpoints for testing
 """
-testing_app = app
-router = APIRouter(prefix='/testing', tags=['testing'])
+test_router = APIRouter(prefix='/testing', tags=['testing'])
 
 class BounceBackParams(BaseModel):
     par1: str
     par2: list
 
-@router.put('/bounce_back')
+@test_router.put('/bounce_back')
 def list_bounce_back(params: BounceBackParams):
     return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}}
 
-@router.put('/accessors/dummy_accessor/load')
+@test_router.put('/accessors/dummy_accessor/load')
 def load_dummy_accessor() -> str:
     acc = InMemoryDataAccessor(
         np.random.randint(
@@ -38,19 +37,19 @@ def load_dummy_accessor() -> str:
     )
     return session.add_accessor(acc)
 
-@router.put('/models/dummy_semantic/load/')
+@test_router.put('/models/dummy_semantic/load/')
 def load_dummy_model() -> dict:
     mid = session.load_model(DummySemanticSegmentationModel)
     session.log_info(f'Loaded model {mid}')
     return {'model_id': mid}
 
-@router.put('/models/dummy_instance/load/')
+@test_router.put('/models/dummy_instance/load/')
 def load_dummy_model() -> dict:
     mid = session.load_model(DummyInstanceSegmentationModel)
     session.log_info(f'Loaded model {mid}')
     return {'model_id': mid}
 
-app.include_router(router)
+app.include_router(test_router)
 
 """
 Implement unit testing on extended base app
diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py
index 9e2bbd3d..e7744acb 100644
--- a/tests/test_ilastik/test_ilastik.py
+++ b/tests/test_ilastik/test_ilastik.py
@@ -5,9 +5,11 @@ import unittest
 import numpy as np
 
 from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
+from model_server.base.api import app
 from model_server.extensions.ilastik import models as ilm
 from model_server.extensions.ilastik.pipelines import px_then_ob
-from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams
+from model_server.extensions.ilastik.router import router
+from model_server.base.roiset import RoiSet, RoiSetMetaParams
 from model_server.base.pipelines import segment
 import model_server.conf.testing as conf
 
@@ -17,6 +19,8 @@ params = conf.meta['roiset']
 czifile = conf.meta['image_files']['czifile']
 ilastik_classifiers = conf.meta['ilastik_classifiers']
 
+app.include_router(router)
+
 def _random_int(*args):
     return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
 
@@ -197,10 +201,13 @@ class TestIlastikPixelClassification(unittest.TestCase):
         )
         self.assertGreater(res.times['inference'], 0.1)
 
-class TestIlastikOverApi(conf.TestServerBaseClass):
 
+class TestServerTestCase(conf.TestServerBaseClass):
+    app_name = 'tests.test_ilastik.test_ilastik:app'
     input_data = czifile
 
+
+class TestIlastikOverApi(TestServerTestCase):
     def test_httpexception_if_incorrect_project_file_loaded(self):
         resp_load = self._put(
             'ilastik/seg/load/',
@@ -327,7 +334,7 @@ class TestIlastikOverApi(conf.TestServerBaseClass):
         self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
 
 
-class TestIlastikOnMultichannelInputs(conf.TestServerBaseClass):
+class TestIlastikOnMultichannelInputs(TestServerTestCase):
     def setUp(self) -> None:
         super(TestIlastikOnMultichannelInputs, self).setUp()
         self.pa_px_classifier = ilastik_classifiers['px_color_zstack']['path']
-- 
GitLab