diff --git a/api.py b/api.py index ad40f5ef7edb8bd6c7648b3b62c0ffab3dff732e..1c168b44237d544103008c761fafe300ca04c300 100644 --- a/api.py +++ b/api.py @@ -24,17 +24,35 @@ def list_active_models(): @app.put('/models/dummy/load/') def load_dummy_model() -> dict: - return session.load_model(DummyImageToImageModel) + return {'model_id': session.load_model(DummyImageToImageModel)} @app.put('/models/ilastik/pixel_classification/load/') def load_ilastik_pixel_classification_model(params: str) -> dict: - return session.load_model(IlastikPixelClassifierModel, params) + return {'model_id': session.load_model(IlastikPixelClassifierModel, params)} @app.put('/models/ilastik/object_classification/load/') def load_ilastik_object_classification_model(params: str) -> dict: - return session.load_model(IlastikObjectClassifierModel, params) + return {'model_id': session.load_model(IlastikObjectClassifierModel, params)} -@app.put('/i2i/infer/') +# @app.put('/models/ilastik/pixel_classification/load/') +# def infer_ilastik_pixel_classification_from_file(input_filename: str, channel: int = None) -> dict: +# inpath = session.inbound.path / input_filename +# if not inpath.exists(): +# raise HTTPException( +# status_code=404, +# detail=f'Could not find file:\n{inpath}' +# ) +# +# record = infer_image_to_image( +# inpath, +# session.models[model_id]['object'], +# session.outbound.path, +# channel=channel, +# ) +# session.record_workflow_run(record) +# return record + +@app.put('/infer/from_image_file') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: if model_id not in session.describe_loaded_models().keys(): raise HTTPException( diff --git a/model_server/session.py b/model_server/session.py index 6db22003c7e2bfeb4e45bdde566eb7492525bebb..1a9b38b9d3cf01c6f24ec5a5af3bf9818e2f47f6 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -74,15 +74,20 @@ class Session(object): mi = ModelClass(params=params) assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' ii = 0 - def mid(i): return f'{ModelClass.__name__}_{ii:02d}' + + def mid(i): + return f'{ModelClass.__name__}_{i:02d}' + while mid(ii) in self.models.keys(): ii += 1 - self.models[mid(ii)] = { + + key = mid(ii) + self.models[key] = { 'object': mi, 'params': params } - self.log_event(f'Loaded model {mid}') - return self.describe_loaded_models() + self.log_event(f'Loaded model {key}') + return key def describe_loaded_models(self) -> dict: return { diff --git a/tests/test_api.py b/tests/test_api.py index d4f7d266e68d441f8a7bb566d40e8758055baa2d..c61e4c2bfd86fe6f4780e5fe817e0a07f0fb72e5 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -47,29 +47,24 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.content, b'{}') def test_load_dummy_model(self): - model_key = DummyImageToImageModel.__name__ + '_00' + # model_key = DummyImageToImageModel.__name__ + '_00' resp_load = requests.put( self.uri + f'models/dummy/load', ) + model_id = resp_load.json()['model_id'] self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_key]['class'], 'DummyImageToImageModel') + self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') + return model_id - def test_respond_with_error_when_invalid_model_loaded(self): - model_id = 'not_a_real_model' - resp = requests.put( - self.uri + f'models/load', - params={'model_id': model_id} - ) - self.assertEqual(resp.status_code, 404) - print(resp.content) def test_respond_with_error_when_invalid_filepath_requested(self): - # model_id = self.test_load_dummy_model() + model_id = self.test_load_dummy_model() + resp = requests.put( - self.uri + f'i2i/infer/', + self.uri + f'infer/from_image_file', params={ 'model_id': model_id, 'input_filename': 'not_a_real_file.name' @@ -81,7 +76,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_i2i_inference_errors_when_model_not_found(self): model_id = 'not_a_real_model' resp = requests.put( - self.uri + f'i2i/infer/', + self.uri + f'infer/from_image_file', params={ 'model_id': model_id, 'input_filename': 'not_a_real_file.name' @@ -93,7 +88,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): model_id = self.test_load_dummy_model() self.copy_input_file_to_server() resp_infer = requests.put( - self.uri + f'i2i/infer/', + self.uri + f'infer/from_image_file', params={ 'model_id': model_id, 'input_filename': czifile['filename'], diff --git a/tests/test_model.py b/tests/test_model.py index 911b9f7479bc2da3ff1aae045c824594118aa3f0..ca0324a9e574447cf52d591ac2edc0bafa03829c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -46,9 +46,4 @@ class TestCziImageFileAccess(unittest.TestCase): img.data[0, 0], 0, 'First pixel is not black as expected' - ) - - def test_find_subclasses_recursively(self): - sc = DummyImageToImageModel - scs = Model.get_all_subclasses() - self.assertIn(DummyImageToImageModel, scs) \ No newline at end of file + ) \ No newline at end of file