Skip to content
Snippets Groups Projects
Commit b9f27095 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Can now load ilastik models by API

parent 2291d080
No related branches found
No related tags found
No related merge requests found
......@@ -27,30 +27,22 @@ def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyImageToImageModel)}
@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(params: str) -> dict:
return {'model_id': session.load_model(IlastikPixelClassifierModel, params)}
def load_ilastik_pixel_classification_model(project_file: str) -> dict:
return {
'model_id': session.load_model(
IlastikPixelClassifierModel,
{'project_file': project_file}
)
}
@app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(params: str) -> dict:
return {'model_id': session.load_model(IlastikObjectClassifierModel, params)}
# @app.put('/models/ilastik/pixel_classification/load/')
# def infer_ilastik_pixel_classification_from_file(input_filename: str, channel: int = None) -> dict:
# inpath = session.inbound.path / input_filename
# if not inpath.exists():
# raise HTTPException(
# status_code=404,
# detail=f'Could not find file:\n{inpath}'
# )
#
# record = infer_image_to_image(
# inpath,
# session.models[model_id]['object'],
# session.outbound.path,
# channel=channel,
# )
# session.record_workflow_run(record)
# return record
def load_ilastik_object_classification_model(project_file: str) -> dict:
return {
'model_id': session.load_model(
IlastikObjectClassifierModel,
{'project_file': project_file}
)
}
@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......
......@@ -47,7 +47,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp.content, b'{}')
def test_load_dummy_model(self):
# model_key = DummyImageToImageModel.__name__ + '_00'
resp_load = requests.put(
self.uri + f'models/dummy/load',
)
......
......@@ -34,9 +34,6 @@ class TestIlastikPixelClassification(unittest.TestCase):
with self.assertRaises(AttributeError):
pxmap , _= model.infer(input_img)
def test_ilastik_subclasses_are_found(self):
self.assertIn(IlastikPixelClassifierModel, Model.get_all_subclasses())
self.assertIn(IlastikObjectClassifierModel, Model.get_all_subclasses())
def test_run_pixel_classifier_on_random_data(self):
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
......@@ -48,6 +45,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
pxmap, _ = model.infer(input_img)
self.assertEqual(pxmap.shape, (w, h, 2, 1))
def test_run_pixel_classifier(self):
channel = 0
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
......@@ -99,17 +97,28 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertGreater(result.timer_results['inference'], 1.0)
class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_model(self):
model_id = IlastikPixelClassifierModel.model_id
def test_load_ilastik_pixel_model(self):
resp_load = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
params={'project_file': str(ilastik['pixel_classifier'])},
)
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel')
def test_load_ilastik_object_model(self):
resp_load = requests.put(
self.uri + f'models/load',
params={'model_id': model_id},
# data={'project_file': str(ilastik['pixel_classifier'])},
data={'project_file': 'hii',},
self.uri + 'models/ilastik/object_classification/load/',
params={'project_file': str(ilastik['object_classifier'])},
)
self.assertEqual(resp_load.status_code, 200, resp_load.content)
# resp_list = requests.get(self.uri + 'models')
# self.assertEqual(resp_list.status_code, 200)
# rj = resp_list.json()
# self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
# return model_id
\ No newline at end of file
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment