diff --git a/model_server/base/api.py b/model_server/base/api.py index 1d51cc04fbd964e8f082aa59889dffa10f41811e..db484d0f879d4717ceb1ebe40b6920383c222520 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -4,18 +4,11 @@ from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError -app = FastAPI(debug=True) - -# TODO: automate all this -import model_server.extensions.ilastik.router -app.include_router(model_server.extensions.ilastik.router.router) - -import model_server.base.pipelines.segment -app.include_router(model_server.base.pipelines.segment.router) +app = FastAPI(debug=True) -import model_server.base.pipelines.roiset_obmap -app.include_router(model_server.base.pipelines.roiset_obmap.router) +from .pipelines.router import router +app.include_router(router) @app.on_event("startup") diff --git a/model_server/base/pipelines/__init__.py b/model_server/base/pipelines/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..537a687dd8ff464c5584804da871167ef202993e 100644 --- a/model_server/base/pipelines/__init__.py +++ b/model_server/base/pipelines/__init__.py @@ -0,0 +1 @@ +__all__ = ['roiset_obmap', 'segment', 'segment_zproj'] \ No newline at end of file diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 24f88601afeb3a9a969452a402da762870f7db1d..c4991aa3e6ef27382e450a44aa22b44106a728ef 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -1,11 +1,11 @@ from typing import Dict, Union -from fastapi import APIRouter from pydantic import BaseModel, Field from ..accessors import GenericImageDataAccessor -from ..pipelines.segment_zproj import segment_zproj_pipeline -from ..pipelines.shared import call_pipeline +from .router import router +from .segment_zproj import segment_zproj_pipeline +from .shared import call_pipeline from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams from ..session import session @@ -13,10 +13,6 @@ from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord from ..models import Model, InstanceSegmentationModel -router = APIRouter( - prefix='/pipelines', - tags=['pipelines'], -) class RoiSetObjectMapParams(PipelineParams): class _SegmentationParams(BaseModel): diff --git a/model_server/base/pipelines/router.py b/model_server/base/pipelines/router.py new file mode 100644 index 0000000000000000000000000000000000000000..3f77fd1d86e09bfc2600bedb0452454620f519e5 --- /dev/null +++ b/model_server/base/pipelines/router.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter + +router = APIRouter( + prefix='/pipelines', + tags=['pipelines'], +) + +# this completes routing in individual pipeline modules +from . import * \ No newline at end of file diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py index 2e34c209b857ae6f2643b0a64d951327102c9e3c..fac1835111872ab563b71d6795982e170d52e61f 100644 --- a/model_server/base/pipelines/segment.py +++ b/model_server/base/pipelines/segment.py @@ -1,19 +1,13 @@ from typing import Dict -from fastapi import APIRouter - from .shared import call_pipeline, IncompatibleModelsError, PipelineTrace, PipelineParams, PipelineRecord from ..accessors import GenericImageDataAccessor from ..models import Model, SemanticSegmentationModel from ..process import smooth +from .router import router from pydantic import Field -router = APIRouter( - prefix='/pipelines', - tags=['pipelines'], -) - class SegmentParams(PipelineParams): accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input') diff --git a/model_server/base/pipelines/segment_zproj.py b/model_server/base/pipelines/segment_zproj.py index 1314fbb483abf31c6f0e8d83f39850c9e5d7ca76..5f5ed9ba755433ef8445198b80885412e2161d98 100644 --- a/model_server/base/pipelines/segment_zproj.py +++ b/model_server/base/pipelines/segment_zproj.py @@ -1,7 +1,6 @@ from typing import Dict -from fastapi import APIRouter - +from .router import router from .segment import SegmentParams, SegmentRecord, segment_pipeline from .shared import call_pipeline, PipelineTrace from ..accessors import GenericImageDataAccessor @@ -9,12 +8,6 @@ from ..models import Model from pydantic import Field -router = APIRouter( - prefix='/pipelines', - tags=['pipelines'], -) - - class SegmentZStackParams(SegmentParams): zi: int = Field(None, description='z coordinate to use on input stack; apply MIP if empty') diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 7f692af23e42651fed96b0714a9b4798f7b5c412..a6aad378e4f8a875573afbcc88308f55a0f8573c 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -1,11 +1,11 @@ from pathlib import Path -from fastapi import APIRouter +from fastapi import APIRouter, FastAPI import numpy as np from pydantic import BaseModel import model_server.conf.testing as conf -from model_server.base.accessors import InMemoryDataAccessor, generate_file_accessor +from model_server.base.accessors import InMemoryDataAccessor from model_server.base.api import app from model_server.base.session import session from tests.base.test_model import DummyInstanceSegmentationModel, DummySemanticSegmentationModel @@ -15,18 +15,17 @@ czifile = conf.meta['image_files']['czifile'] """ Configure additional endpoints for testing """ -testing_app = app -router = APIRouter(prefix='/testing', tags=['testing']) +test_router = APIRouter(prefix='/testing', tags=['testing']) class BounceBackParams(BaseModel): par1: str par2: list -@router.put('/bounce_back') +@test_router.put('/bounce_back') def list_bounce_back(params: BounceBackParams): return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}} -@router.put('/accessors/dummy_accessor/load') +@test_router.put('/accessors/dummy_accessor/load') def load_dummy_accessor() -> str: acc = InMemoryDataAccessor( np.random.randint( @@ -38,19 +37,19 @@ def load_dummy_accessor() -> str: ) return session.add_accessor(acc) -@router.put('/models/dummy_semantic/load/') +@test_router.put('/models/dummy_semantic/load/') def load_dummy_model() -> dict: mid = session.load_model(DummySemanticSegmentationModel) session.log_info(f'Loaded model {mid}') return {'model_id': mid} -@router.put('/models/dummy_instance/load/') +@test_router.put('/models/dummy_instance/load/') def load_dummy_model() -> dict: mid = session.load_model(DummyInstanceSegmentationModel) session.log_info(f'Loaded model {mid}') return {'model_id': mid} -app.include_router(router) +app.include_router(test_router) """ Implement unit testing on extended base app diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index 9e2bbd3d42a4a0fdbf65b553852aa4b26cb94e91..e7744acb0c53e02353654602e7916b69de61679e 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -5,9 +5,11 @@ import unittest import numpy as np from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file +from model_server.base.api import app from model_server.extensions.ilastik import models as ilm from model_server.extensions.ilastik.pipelines import px_then_ob -from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams +from model_server.extensions.ilastik.router import router +from model_server.base.roiset import RoiSet, RoiSetMetaParams from model_server.base.pipelines import segment import model_server.conf.testing as conf @@ -17,6 +19,8 @@ params = conf.meta['roiset'] czifile = conf.meta['image_files']['czifile'] ilastik_classifiers = conf.meta['ilastik_classifiers'] +app.include_router(router) + def _random_int(*args): return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') @@ -197,10 +201,13 @@ class TestIlastikPixelClassification(unittest.TestCase): ) self.assertGreater(res.times['inference'], 0.1) -class TestIlastikOverApi(conf.TestServerBaseClass): +class TestServerTestCase(conf.TestServerBaseClass): + app_name = 'tests.test_ilastik.test_ilastik:app' input_data = czifile + +class TestIlastikOverApi(TestServerTestCase): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = self._put( 'ilastik/seg/load/', @@ -327,7 +334,7 @@ class TestIlastikOverApi(conf.TestServerBaseClass): self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) -class TestIlastikOnMultichannelInputs(conf.TestServerBaseClass): +class TestIlastikOnMultichannelInputs(TestServerTestCase): def setUp(self) -> None: super(TestIlastikOnMultichannelInputs, self).setUp() self.pa_px_classifier = ilastik_classifiers['px_color_zstack']['path']