From 2291d080b6c269e953464d367b180ae4ae2f8856 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 31 Aug 2023 16:12:20 +0200 Subject: [PATCH] Inference API now works again with dummy model --- api.py | 26 ++++++++++++++++++++++---- model_server/session.py | 13 +++++++++---- tests/test_api.py | 23 +++++++++-------------- tests/test_model.py | 7 +------ 4 files changed, 41 insertions(+), 28 deletions(-) diff --git a/api.py b/api.py index ad40f5ef..1c168b44 100644 --- a/api.py +++ b/api.py @@ -24,17 +24,35 @@ def list_active_models(): @app.put('/models/dummy/load/') def load_dummy_model() -> dict: - return session.load_model(DummyImageToImageModel) + return {'model_id': session.load_model(DummyImageToImageModel)} @app.put('/models/ilastik/pixel_classification/load/') def load_ilastik_pixel_classification_model(params: str) -> dict: - return session.load_model(IlastikPixelClassifierModel, params) + return {'model_id': session.load_model(IlastikPixelClassifierModel, params)} @app.put('/models/ilastik/object_classification/load/') def load_ilastik_object_classification_model(params: str) -> dict: - return session.load_model(IlastikObjectClassifierModel, params) + return {'model_id': session.load_model(IlastikObjectClassifierModel, params)} -@app.put('/i2i/infer/') +# @app.put('/models/ilastik/pixel_classification/load/') +# def infer_ilastik_pixel_classification_from_file(input_filename: str, channel: int = None) -> dict: +# inpath = session.inbound.path / input_filename +# if not inpath.exists(): +# raise HTTPException( +# status_code=404, +# detail=f'Could not find file:\n{inpath}' +# ) +# +# record = infer_image_to_image( +# inpath, +# session.models[model_id]['object'], +# session.outbound.path, +# channel=channel, +# ) +# session.record_workflow_run(record) +# return record + +@app.put('/infer/from_image_file') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: if model_id not in session.describe_loaded_models().keys(): raise HTTPException( diff --git a/model_server/session.py b/model_server/session.py index 6db22003..1a9b38b9 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -74,15 +74,20 @@ class Session(object): mi = ModelClass(params=params) assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' ii = 0 - def mid(i): return f'{ModelClass.__name__}_{ii:02d}' + + def mid(i): + return f'{ModelClass.__name__}_{i:02d}' + while mid(ii) in self.models.keys(): ii += 1 - self.models[mid(ii)] = { + + key = mid(ii) + self.models[key] = { 'object': mi, 'params': params } - self.log_event(f'Loaded model {mid}') - return self.describe_loaded_models() + self.log_event(f'Loaded model {key}') + return key def describe_loaded_models(self) -> dict: return { diff --git a/tests/test_api.py b/tests/test_api.py index d4f7d266..c61e4c2b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -47,29 +47,24 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.content, b'{}') def test_load_dummy_model(self): - model_key = DummyImageToImageModel.__name__ + '_00' + # model_key = DummyImageToImageModel.__name__ + '_00' resp_load = requests.put( self.uri + f'models/dummy/load', ) + model_id = resp_load.json()['model_id'] self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_key]['class'], 'DummyImageToImageModel') + self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') + return model_id - def test_respond_with_error_when_invalid_model_loaded(self): - model_id = 'not_a_real_model' - resp = requests.put( - self.uri + f'models/load', - params={'model_id': model_id} - ) - self.assertEqual(resp.status_code, 404) - print(resp.content) def test_respond_with_error_when_invalid_filepath_requested(self): - # model_id = self.test_load_dummy_model() + model_id = self.test_load_dummy_model() + resp = requests.put( - self.uri + f'i2i/infer/', + self.uri + f'infer/from_image_file', params={ 'model_id': model_id, 'input_filename': 'not_a_real_file.name' @@ -81,7 +76,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_i2i_inference_errors_when_model_not_found(self): model_id = 'not_a_real_model' resp = requests.put( - self.uri + f'i2i/infer/', + self.uri + f'infer/from_image_file', params={ 'model_id': model_id, 'input_filename': 'not_a_real_file.name' @@ -93,7 +88,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): model_id = self.test_load_dummy_model() self.copy_input_file_to_server() resp_infer = requests.put( - self.uri + f'i2i/infer/', + self.uri + f'infer/from_image_file', params={ 'model_id': model_id, 'input_filename': czifile['filename'], diff --git a/tests/test_model.py b/tests/test_model.py index 911b9f74..ca0324a9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -46,9 +46,4 @@ class TestCziImageFileAccess(unittest.TestCase): img.data[0, 0], 0, 'First pixel is not black as expected' - ) - - def test_find_subclasses_recursively(self): - sc = DummyImageToImageModel - scs = Model.get_all_subclasses() - self.assertIn(DummyImageToImageModel, scs) \ No newline at end of file + ) \ No newline at end of file -- GitLab