diff --git a/api.py b/api.py index 32eea18793e676d6852d7d9f4287833c06d5d483..0167a17c3f1a3196b7ba4b5c4c16bd750d6f7cbb 100644 --- a/api.py +++ b/api.py @@ -2,7 +2,7 @@ from typing import Dict from fastapi import FastAPI, HTTPException -from model_server.session import Session +from model_server.session import CouldNotFindModelError, Session from model_server.workflow import infer_image_to_image app = FastAPI(debug=True) @@ -21,13 +21,20 @@ def list_active_models(): return session.describe_loaded_models() @app.put('/models/load/') -def load_model(model_id: str, params: Dict[str, str] = None) -> dict: +# def load_model(model_id: str, misc: Dict[str, str]) -> dict: +def load_model(model_id: str, misc: dict) -> dict: if model_id in session.models.keys(): raise HTTPException( status_code=409, detail=f'Model with id {model_id} has already been loaded' ) - session.load_model(model_id, params=params) + try: + session.load_model(model_id, params=misc) + except CouldNotFindModelError: + raise HTTPException( + status_code=404, + detail=f'Could not find {model_id} in defined models' + ) return session.describe_loaded_models() @app.put('/i2i/infer/') diff --git a/model_server/ilastik.py b/model_server/ilastik.py index 99a812a7e99109e34a62b949d044d165e11ad9f1..606c93d4294e2498502b013f3f32b62bf5d61526 100644 --- a/model_server/ilastik.py +++ b/model_server/ilastik.py @@ -1,11 +1,5 @@ import os -from ilastik import app -from ilastik.applets.dataSelection import DatasetInfo, FilesystemDatasetInfo -from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo -from ilastik.workflows.pixelClassification import PixelClassificationWorkflow -from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflow - import numpy as np import vigra @@ -20,11 +14,15 @@ class IlastikImageToImageModel(ImageToImageModel): raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file') self.project_file = str(params['project_file']) self.shell = None - # self.operator = None super().__init__(autoload, params) def load(self): + from ilastik import app + from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo + + self.PreloadedArrayDatasetInfo = PreloadedArrayDatasetInfo + os.environ["LAZYFLOW_THREADS"] = "8" os.environ["LAZYFLOW_TOTAL_RAM_MB"] = "24000" @@ -33,7 +31,7 @@ class IlastikImageToImageModel(ImageToImageModel): args.project = self.project_file shell = app.main(args) - if not isinstance(shell.workflow, self.workflow): + if not isinstance(shell.workflow, self.get_workflow()): raise ParameterExpectedError( f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' ) @@ -44,13 +42,17 @@ class IlastikImageToImageModel(ImageToImageModel): class IlastikPixelClassifierModel(IlastikImageToImageModel): model_id = 'ilastik_pixel_classification' - workflow = PixelClassificationWorkflow + + @staticmethod + def get_workflow(): + from ilastik.workflows.pixelClassification import PixelClassificationWorkflow + return PixelClassificationWorkflow def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict): tagged_input_data = vigra.taggedView(input_img.data, 'xycz') dsi = [ { - 'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), + 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), } ] pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [1 x w x h x n] @@ -66,7 +68,11 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel): class IlastikObjectClassifierModel(IlastikImageToImageModel): model_id = 'ilastik_object_classification' - workflow = ObjectClassificationWorkflow + + @staticmethod + def get_workflow(): + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflow + return ObjectClassificationWorkflow def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): tagged_input_data = vigra.taggedView(input_img.data, 'xycz') @@ -74,8 +80,8 @@ class IlastikObjectClassifierModel(IlastikImageToImageModel): dsi = [ { - 'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Prediction Maps': PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), + 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), + 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), } ] diff --git a/model_server/model.py b/model_server/model.py index 7ee63af4e906683c6040b994a149472390ed704d..a2d75c7b5453a8ad14092f6ea61a57d6860ea0fb 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -27,17 +27,6 @@ class Model(ABC): raise CouldNotLoadModelError() return None - @classmethod - def get_all_subclasses(cls): - """ - Recursively find all subclasses of Model - :return: set of all subclasses of Model - """ - def get_all_subclasses_of(cc): - return set(cc.__subclasses__()).union( - [s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)]) - return get_all_subclasses_of(cls) - @abstractmethod def load(self): """ diff --git a/model_server/model_registry.py b/model_server/model_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..f6686dc08c0bcd9978bac2f241f7e25507da8ad8 --- /dev/null +++ b/model_server/model_registry.py @@ -0,0 +1,17 @@ +import model_server.ilastik +import model_server.model + +def get_all_model_subclasses(): + """ + Recursively find all subclasses of Model + :return: set of all subclasses of Model + """ + + def get_all_subclasses_of(cc): + return set(cc.__subclasses__()).union( + [s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)]) + + return get_all_subclasses_of(model_server.model.Model) + +if __name__ == '__main__': + print(get_all_model_subclasses()) \ No newline at end of file diff --git a/model_server/session.py b/model_server/session.py index 7508f103f7341cef1ea49797538d5794ab5c4221..69a94fcb60e051faf08564fcde0fce6c862aafac 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -7,6 +7,7 @@ from typing import Dict from conf.server import paths from model_server.model import Model +from model_server.model_registry import get_all_model_subclasses from model_server.share import SharedImageDirectory from model_server.workflow import WorkflowRunRecord @@ -70,7 +71,7 @@ class Session(object): :param params: optional parameters that are passed upon loading a model :return: True if model successfully loaded, False if not """ - models = Model.get_all_subclasses() + models = get_all_model_subclasses() for mc in models: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: try: diff --git a/run_server.py b/run_server.py new file mode 100644 index 0000000000000000000000000000000000000000..1e90251c6650cac3be7b2383cf32fa61a4bdb41f --- /dev/null +++ b/run_server.py @@ -0,0 +1,7 @@ +import uvicorn + +host = '127.0.0.1' +port = 8001 + +if __name__ == '__main__': + uvicorn.run('api:app', **{'host': host, 'port': port, 'log_level': 'debug'}, reload=True) \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 0b2e844bdbea032a426191077036bb255aa14e93..bcb366eca8fa82ab5828e27ae0c1c5a29d1ff595 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,7 +6,7 @@ import unittest from conf.testing import czifile from model_server.model import DummyImageToImageModel -class TestApiFromAutomatedClient(unittest.TestCase): +class TestServerBaseClass(unittest.TestCase): def setUp(self) -> None: import uvicorn host = '127.0.0.1' @@ -37,6 +37,7 @@ class TestApiFromAutomatedClient(unittest.TestCase): def tearDown(self) -> None: self.server_process.terminate() +class TestApiFromAutomatedClient(TestServerBaseClass): def test_trivial_api_response(self): resp = requests.get(self.uri, ) self.assertEqual(resp.status_code, 200) @@ -59,6 +60,15 @@ class TestApiFromAutomatedClient(unittest.TestCase): self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') return model_id + def test_respond_with_error_when_invalid_model_loaded(self): + model_id = 'not_a_real_model' + resp = requests.put( + self.uri + f'models/load', + params={'model_id': model_id} + ) + self.assertEqual(resp.status_code, 404) + print(resp.content) + def test_respond_with_error_when_invalid_filepath_requested(self): model_id = self.test_load_dummy_model() resp = requests.put( diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index a2d204e694dc25c282c9a63e6dcc1891303420eb..643225b2bc5cbb6e29124b897be3bcf22e64f91e 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -1,3 +1,4 @@ +import requests import unittest import numpy as np @@ -5,7 +6,9 @@ import numpy as np from conf.testing import czifile, ilastik, output_path from model_server.image import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.ilastik import IlastikObjectClassifierModel, IlastikPixelClassifierModel +from model_server.model import Model from model_server.workflow import infer_image_to_image +from tests.test_api import TestServerBaseClass class TestIlastikPixelClassification(unittest.TestCase): def setUp(self) -> None: @@ -31,6 +34,10 @@ class TestIlastikPixelClassification(unittest.TestCase): with self.assertRaises(AttributeError): pxmap , _= model.infer(input_img) + def test_ilastik_subclasses_are_found(self): + self.assertIn(IlastikPixelClassifierModel, Model.get_all_subclasses()) + self.assertIn(IlastikObjectClassifierModel, Model.get_all_subclasses()) + def test_run_pixel_classifier_on_random_data(self): model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) w = 512 @@ -89,4 +96,20 @@ class TestIlastikPixelClassification(unittest.TestCase): channel=0, ) self.assertTrue(result.success) - self.assertGreater(result.timer_results['inference'], 1.0) \ No newline at end of file + self.assertGreater(result.timer_results['inference'], 1.0) + +class TestIlastikOverApi(TestServerBaseClass): + def test_load_ilastik_model(self): + model_id = IlastikPixelClassifierModel.model_id + resp_load = requests.put( + self.uri + f'models/load', + params={'model_id': model_id}, + # data={'project_file': str(ilastik['pixel_classifier'])}, + data={'project_file': 'hii',}, + ) + self.assertEqual(resp_load.status_code, 200, resp_load.content) + # resp_list = requests.get(self.uri + 'models') + # self.assertEqual(resp_list.status_code, 200) + # rj = resp_list.json() + # self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') + # return model_id \ No newline at end of file