Skip to content
Snippets Groups Projects
Commit 2291d080 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Inference API now works again with dummy model

parent e147a54f
No related branches found
No related tags found
No related merge requests found
......@@ -24,17 +24,35 @@ def list_active_models():
@app.put('/models/dummy/load/')
def load_dummy_model() -> dict:
return session.load_model(DummyImageToImageModel)
return {'model_id': session.load_model(DummyImageToImageModel)}
@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(params: str) -> dict:
return session.load_model(IlastikPixelClassifierModel, params)
return {'model_id': session.load_model(IlastikPixelClassifierModel, params)}
@app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(params: str) -> dict:
return session.load_model(IlastikObjectClassifierModel, params)
return {'model_id': session.load_model(IlastikObjectClassifierModel, params)}
@app.put('/i2i/infer/')
# @app.put('/models/ilastik/pixel_classification/load/')
# def infer_ilastik_pixel_classification_from_file(input_filename: str, channel: int = None) -> dict:
# inpath = session.inbound.path / input_filename
# if not inpath.exists():
# raise HTTPException(
# status_code=404,
# detail=f'Could not find file:\n{inpath}'
# )
#
# record = infer_image_to_image(
# inpath,
# session.models[model_id]['object'],
# session.outbound.path,
# channel=channel,
# )
# session.record_workflow_run(record)
# return record
@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
if model_id not in session.describe_loaded_models().keys():
raise HTTPException(
......
......@@ -74,15 +74,20 @@ class Session(object):
mi = ModelClass(params=params)
assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
ii = 0
def mid(i): return f'{ModelClass.__name__}_{ii:02d}'
def mid(i):
return f'{ModelClass.__name__}_{i:02d}'
while mid(ii) in self.models.keys():
ii += 1
self.models[mid(ii)] = {
key = mid(ii)
self.models[key] = {
'object': mi,
'params': params
}
self.log_event(f'Loaded model {mid}')
return self.describe_loaded_models()
self.log_event(f'Loaded model {key}')
return key
def describe_loaded_models(self) -> dict:
return {
......
......@@ -47,29 +47,24 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp.content, b'{}')
def test_load_dummy_model(self):
model_key = DummyImageToImageModel.__name__ + '_00'
# model_key = DummyImageToImageModel.__name__ + '_00'
resp_load = requests.put(
self.uri + f'models/dummy/load',
)
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_key]['class'], 'DummyImageToImageModel')
self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
return model_id
def test_respond_with_error_when_invalid_model_loaded(self):
model_id = 'not_a_real_model'
resp = requests.put(
self.uri + f'models/load',
params={'model_id': model_id}
)
self.assertEqual(resp.status_code, 404)
print(resp.content)
def test_respond_with_error_when_invalid_filepath_requested(self):
# model_id = self.test_load_dummy_model()
model_id = self.test_load_dummy_model()
resp = requests.put(
self.uri + f'i2i/infer/',
self.uri + f'infer/from_image_file',
params={
'model_id': model_id,
'input_filename': 'not_a_real_file.name'
......@@ -81,7 +76,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
def test_i2i_inference_errors_when_model_not_found(self):
model_id = 'not_a_real_model'
resp = requests.put(
self.uri + f'i2i/infer/',
self.uri + f'infer/from_image_file',
params={
'model_id': model_id,
'input_filename': 'not_a_real_file.name'
......@@ -93,7 +88,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
model_id = self.test_load_dummy_model()
self.copy_input_file_to_server()
resp_infer = requests.put(
self.uri + f'i2i/infer/',
self.uri + f'infer/from_image_file',
params={
'model_id': model_id,
'input_filename': czifile['filename'],
......
......@@ -46,9 +46,4 @@ class TestCziImageFileAccess(unittest.TestCase):
img.data[0, 0],
0,
'First pixel is not black as expected'
)
def test_find_subclasses_recursively(self):
sc = DummyImageToImageModel
scs = Model.get_all_subclasses()
self.assertIn(DummyImageToImageModel, scs)
\ No newline at end of file
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment