From dce34bcf9dbb5cf0878423b6172ccdf690ca0a37 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 15 Aug 2024 14:53:17 +0200 Subject: [PATCH] Implemented hierarchical routing --- model_server/base/api.py | 13 +++---------- model_server/base/pipelines/__init__.py | 1 + model_server/base/pipelines/roiset_obmap.py | 10 +++------- model_server/base/pipelines/router.py | 9 +++++++++ model_server/base/pipelines/segment.py | 8 +------- model_server/base/pipelines/segment_zproj.py | 9 +-------- tests/base/test_api.py | 17 ++++++++--------- tests/test_ilastik/test_ilastik.py | 13 ++++++++++--- 8 files changed, 36 insertions(+), 44 deletions(-) create mode 100644 model_server/base/pipelines/router.py diff --git a/model_server/base/api.py b/model_server/base/api.py index 1d51cc04..db484d0f 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -4,18 +4,11 @@ from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError -app = FastAPI(debug=True) - -# TODO: automate all this -import model_server.extensions.ilastik.router -app.include_router(model_server.extensions.ilastik.router.router) - -import model_server.base.pipelines.segment -app.include_router(model_server.base.pipelines.segment.router) +app = FastAPI(debug=True) -import model_server.base.pipelines.roiset_obmap -app.include_router(model_server.base.pipelines.roiset_obmap.router) +from .pipelines.router import router +app.include_router(router) @app.on_event("startup") diff --git a/model_server/base/pipelines/__init__.py b/model_server/base/pipelines/__init__.py index e69de29b..537a687d 100644 --- a/model_server/base/pipelines/__init__.py +++ b/model_server/base/pipelines/__init__.py @@ -0,0 +1 @@ +__all__ = ['roiset_obmap', 'segment', 'segment_zproj'] \ No newline at end of file diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 24f88601..c4991aa3 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -1,11 +1,11 @@ from typing import Dict, Union -from fastapi import APIRouter from pydantic import BaseModel, Field from ..accessors import GenericImageDataAccessor -from ..pipelines.segment_zproj import segment_zproj_pipeline -from ..pipelines.shared import call_pipeline +from .router import router +from .segment_zproj import segment_zproj_pipeline +from .shared import call_pipeline from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams from ..session import session @@ -13,10 +13,6 @@ from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord from ..models import Model, InstanceSegmentationModel -router = APIRouter( - prefix='/pipelines', - tags=['pipelines'], -) class RoiSetObjectMapParams(PipelineParams): class _SegmentationParams(BaseModel): diff --git a/model_server/base/pipelines/router.py b/model_server/base/pipelines/router.py new file mode 100644 index 00000000..3f77fd1d --- /dev/null +++ b/model_server/base/pipelines/router.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter + +router = APIRouter( + prefix='/pipelines', + tags=['pipelines'], +) + +# this completes routing in individual pipeline modules +from . import * \ No newline at end of file diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py index 2e34c209..fac18351 100644 --- a/model_server/base/pipelines/segment.py +++ b/model_server/base/pipelines/segment.py @@ -1,19 +1,13 @@ from typing import Dict -from fastapi import APIRouter - from .shared import call_pipeline, IncompatibleModelsError, PipelineTrace, PipelineParams, PipelineRecord from ..accessors import GenericImageDataAccessor from ..models import Model, SemanticSegmentationModel from ..process import smooth +from .router import router from pydantic import Field -router = APIRouter( - prefix='/pipelines', - tags=['pipelines'], -) - class SegmentParams(PipelineParams): accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input') diff --git a/model_server/base/pipelines/segment_zproj.py b/model_server/base/pipelines/segment_zproj.py index 1314fbb4..5f5ed9ba 100644 --- a/model_server/base/pipelines/segment_zproj.py +++ b/model_server/base/pipelines/segment_zproj.py @@ -1,7 +1,6 @@ from typing import Dict -from fastapi import APIRouter - +from .router import router from .segment import SegmentParams, SegmentRecord, segment_pipeline from .shared import call_pipeline, PipelineTrace from ..accessors import GenericImageDataAccessor @@ -9,12 +8,6 @@ from ..models import Model from pydantic import Field -router = APIRouter( - prefix='/pipelines', - tags=['pipelines'], -) - - class SegmentZStackParams(SegmentParams): zi: int = Field(None, description='z coordinate to use on input stack; apply MIP if empty') diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 7f692af2..a6aad378 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -1,11 +1,11 @@ from pathlib import Path -from fastapi import APIRouter +from fastapi import APIRouter, FastAPI import numpy as np from pydantic import BaseModel import model_server.conf.testing as conf -from model_server.base.accessors import InMemoryDataAccessor, generate_file_accessor +from model_server.base.accessors import InMemoryDataAccessor from model_server.base.api import app from model_server.base.session import session from tests.base.test_model import DummyInstanceSegmentationModel, DummySemanticSegmentationModel @@ -15,18 +15,17 @@ czifile = conf.meta['image_files']['czifile'] """ Configure additional endpoints for testing """ -testing_app = app -router = APIRouter(prefix='/testing', tags=['testing']) +test_router = APIRouter(prefix='/testing', tags=['testing']) class BounceBackParams(BaseModel): par1: str par2: list -@router.put('/bounce_back') +@test_router.put('/bounce_back') def list_bounce_back(params: BounceBackParams): return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}} -@router.put('/accessors/dummy_accessor/load') +@test_router.put('/accessors/dummy_accessor/load') def load_dummy_accessor() -> str: acc = InMemoryDataAccessor( np.random.randint( @@ -38,19 +37,19 @@ def load_dummy_accessor() -> str: ) return session.add_accessor(acc) -@router.put('/models/dummy_semantic/load/') +@test_router.put('/models/dummy_semantic/load/') def load_dummy_model() -> dict: mid = session.load_model(DummySemanticSegmentationModel) session.log_info(f'Loaded model {mid}') return {'model_id': mid} -@router.put('/models/dummy_instance/load/') +@test_router.put('/models/dummy_instance/load/') def load_dummy_model() -> dict: mid = session.load_model(DummyInstanceSegmentationModel) session.log_info(f'Loaded model {mid}') return {'model_id': mid} -app.include_router(router) +app.include_router(test_router) """ Implement unit testing on extended base app diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index 9e2bbd3d..e7744acb 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -5,9 +5,11 @@ import unittest import numpy as np from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file +from model_server.base.api import app from model_server.extensions.ilastik import models as ilm from model_server.extensions.ilastik.pipelines import px_then_ob -from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams +from model_server.extensions.ilastik.router import router +from model_server.base.roiset import RoiSet, RoiSetMetaParams from model_server.base.pipelines import segment import model_server.conf.testing as conf @@ -17,6 +19,8 @@ params = conf.meta['roiset'] czifile = conf.meta['image_files']['czifile'] ilastik_classifiers = conf.meta['ilastik_classifiers'] +app.include_router(router) + def _random_int(*args): return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') @@ -197,10 +201,13 @@ class TestIlastikPixelClassification(unittest.TestCase): ) self.assertGreater(res.times['inference'], 0.1) -class TestIlastikOverApi(conf.TestServerBaseClass): +class TestServerTestCase(conf.TestServerBaseClass): + app_name = 'tests.test_ilastik.test_ilastik:app' input_data = czifile + +class TestIlastikOverApi(TestServerTestCase): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = self._put( 'ilastik/seg/load/', @@ -327,7 +334,7 @@ class TestIlastikOverApi(conf.TestServerBaseClass): self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) -class TestIlastikOnMultichannelInputs(conf.TestServerBaseClass): +class TestIlastikOnMultichannelInputs(TestServerTestCase): def setUp(self) -> None: super(TestIlastikOnMultichannelInputs, self).setUp() self.pa_px_classifier = ilastik_classifiers['px_color_zstack']['path'] -- GitLab