Skip to content
Snippets Groups Projects
Commit b9f27095 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Can now load ilastik models by API

parent 2291d080
No related branches found
No related tags found
No related merge requests found
...@@ -27,30 +27,22 @@ def load_dummy_model() -> dict: ...@@ -27,30 +27,22 @@ def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyImageToImageModel)} return {'model_id': session.load_model(DummyImageToImageModel)}
@app.put('/models/ilastik/pixel_classification/load/') @app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(params: str) -> dict: def load_ilastik_pixel_classification_model(project_file: str) -> dict:
return {'model_id': session.load_model(IlastikPixelClassifierModel, params)} return {
'model_id': session.load_model(
IlastikPixelClassifierModel,
{'project_file': project_file}
)
}
@app.put('/models/ilastik/object_classification/load/') @app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(params: str) -> dict: def load_ilastik_object_classification_model(project_file: str) -> dict:
return {'model_id': session.load_model(IlastikObjectClassifierModel, params)} return {
'model_id': session.load_model(
# @app.put('/models/ilastik/pixel_classification/load/') IlastikObjectClassifierModel,
# def infer_ilastik_pixel_classification_from_file(input_filename: str, channel: int = None) -> dict: {'project_file': project_file}
# inpath = session.inbound.path / input_filename )
# if not inpath.exists(): }
# raise HTTPException(
# status_code=404,
# detail=f'Could not find file:\n{inpath}'
# )
#
# record = infer_image_to_image(
# inpath,
# session.models[model_id]['object'],
# session.outbound.path,
# channel=channel,
# )
# session.record_workflow_run(record)
# return record
@app.put('/infer/from_image_file') @app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......
...@@ -47,7 +47,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ...@@ -47,7 +47,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp.content, b'{}') self.assertEqual(resp.content, b'{}')
def test_load_dummy_model(self): def test_load_dummy_model(self):
# model_key = DummyImageToImageModel.__name__ + '_00'
resp_load = requests.put( resp_load = requests.put(
self.uri + f'models/dummy/load', self.uri + f'models/dummy/load',
) )
......
...@@ -34,9 +34,6 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -34,9 +34,6 @@ class TestIlastikPixelClassification(unittest.TestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
pxmap , _= model.infer(input_img) pxmap , _= model.infer(input_img)
def test_ilastik_subclasses_are_found(self):
self.assertIn(IlastikPixelClassifierModel, Model.get_all_subclasses())
self.assertIn(IlastikObjectClassifierModel, Model.get_all_subclasses())
def test_run_pixel_classifier_on_random_data(self): def test_run_pixel_classifier_on_random_data(self):
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
...@@ -48,6 +45,7 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -48,6 +45,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
pxmap, _ = model.infer(input_img) pxmap, _ = model.infer(input_img)
self.assertEqual(pxmap.shape, (w, h, 2, 1)) self.assertEqual(pxmap.shape, (w, h, 2, 1))
def test_run_pixel_classifier(self): def test_run_pixel_classifier(self):
channel = 0 channel = 0
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
...@@ -99,17 +97,28 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -99,17 +97,28 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertGreater(result.timer_results['inference'], 1.0) self.assertGreater(result.timer_results['inference'], 1.0)
class TestIlastikOverApi(TestServerBaseClass): class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_model(self): def test_load_ilastik_pixel_model(self):
model_id = IlastikPixelClassifierModel.model_id resp_load = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
params={'project_file': str(ilastik['pixel_classifier'])},
)
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel')
def test_load_ilastik_object_model(self):
resp_load = requests.put( resp_load = requests.put(
self.uri + f'models/load', self.uri + 'models/ilastik/object_classification/load/',
params={'model_id': model_id}, params={'project_file': str(ilastik['object_classifier'])},
# data={'project_file': str(ilastik['pixel_classifier'])},
data={'project_file': 'hii',},
) )
self.assertEqual(resp_load.status_code, 200, resp_load.content) model_id = resp_load.json()['model_id']
# resp_list = requests.get(self.uri + 'models')
# self.assertEqual(resp_list.status_code, 200) self.assertEqual(resp_load.status_code, 200, resp_load.json())
# rj = resp_list.json() resp_list = requests.get(self.uri + 'models')
# self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') self.assertEqual(resp_list.status_code, 200)
# return model_id rj = resp_list.json()
\ No newline at end of file self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment