From b9f27095a0ca3494043beed30afb927fe66add30 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 31 Aug 2023 16:33:41 +0200 Subject: [PATCH] Can now load ilastik models by API --- api.py | 36 ++++++++++++++---------------------- tests/test_api.py | 1 - tests/test_ilastik.py | 39 ++++++++++++++++++++++++--------------- 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/api.py b/api.py index 1c168b44..ea23b59e 100644 --- a/api.py +++ b/api.py @@ -27,30 +27,22 @@ def load_dummy_model() -> dict: return {'model_id': session.load_model(DummyImageToImageModel)} @app.put('/models/ilastik/pixel_classification/load/') -def load_ilastik_pixel_classification_model(params: str) -> dict: - return {'model_id': session.load_model(IlastikPixelClassifierModel, params)} +def load_ilastik_pixel_classification_model(project_file: str) -> dict: + return { + 'model_id': session.load_model( + IlastikPixelClassifierModel, + {'project_file': project_file} + ) + } @app.put('/models/ilastik/object_classification/load/') -def load_ilastik_object_classification_model(params: str) -> dict: - return {'model_id': session.load_model(IlastikObjectClassifierModel, params)} - -# @app.put('/models/ilastik/pixel_classification/load/') -# def infer_ilastik_pixel_classification_from_file(input_filename: str, channel: int = None) -> dict: -# inpath = session.inbound.path / input_filename -# if not inpath.exists(): -# raise HTTPException( -# status_code=404, -# detail=f'Could not find file:\n{inpath}' -# ) -# -# record = infer_image_to_image( -# inpath, -# session.models[model_id]['object'], -# session.outbound.path, -# channel=channel, -# ) -# session.record_workflow_run(record) -# return record +def load_ilastik_object_classification_model(project_file: str) -> dict: + return { + 'model_id': session.load_model( + IlastikObjectClassifierModel, + {'project_file': project_file} + ) + } @app.put('/infer/from_image_file') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: diff --git a/tests/test_api.py b/tests/test_api.py index c61e4c2b..e4091cb2 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -47,7 +47,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.content, b'{}') def test_load_dummy_model(self): - # model_key = DummyImageToImageModel.__name__ + '_00' resp_load = requests.put( self.uri + f'models/dummy/load', ) diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 643225b2..8b687874 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -34,9 +34,6 @@ class TestIlastikPixelClassification(unittest.TestCase): with self.assertRaises(AttributeError): pxmap , _= model.infer(input_img) - def test_ilastik_subclasses_are_found(self): - self.assertIn(IlastikPixelClassifierModel, Model.get_all_subclasses()) - self.assertIn(IlastikObjectClassifierModel, Model.get_all_subclasses()) def test_run_pixel_classifier_on_random_data(self): model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) @@ -48,6 +45,7 @@ class TestIlastikPixelClassification(unittest.TestCase): pxmap, _ = model.infer(input_img) self.assertEqual(pxmap.shape, (w, h, 2, 1)) + def test_run_pixel_classifier(self): channel = 0 model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) @@ -99,17 +97,28 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertGreater(result.timer_results['inference'], 1.0) class TestIlastikOverApi(TestServerBaseClass): - def test_load_ilastik_model(self): - model_id = IlastikPixelClassifierModel.model_id + def test_load_ilastik_pixel_model(self): + resp_load = requests.put( + self.uri + 'models/ilastik/pixel_classification/load/', + params={'project_file': str(ilastik['pixel_classifier'])}, + ) + model_id = resp_load.json()['model_id'] + + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + resp_list = requests.get(self.uri + 'models') + self.assertEqual(resp_list.status_code, 200) + rj = resp_list.json() + self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel') + + def test_load_ilastik_object_model(self): resp_load = requests.put( - self.uri + f'models/load', - params={'model_id': model_id}, - # data={'project_file': str(ilastik['pixel_classifier'])}, - data={'project_file': 'hii',}, + self.uri + 'models/ilastik/object_classification/load/', + params={'project_file': str(ilastik['object_classifier'])}, ) - self.assertEqual(resp_load.status_code, 200, resp_load.content) - # resp_list = requests.get(self.uri + 'models') - # self.assertEqual(resp_list.status_code, 200) - # rj = resp_list.json() - # self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') - # return model_id \ No newline at end of file + model_id = resp_load.json()['model_id'] + + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + resp_list = requests.get(self.uri + 'models') + self.assertEqual(resp_list.status_code, 200) + rj = resp_list.json() + self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel') \ No newline at end of file -- GitLab