Skip to content
Snippets Groups Projects
Commit 5865f0aa authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Mapped dummy model for instance segmentation to API as well, for testing purposes

parent 1cea684d
No related branches found
No related tags found
No related merge requests found
from fastapi import FastAPI, HTTPException
from model_server.base.models import DummySemanticSegmentationModel
from model_server.base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
from model_server.base.session import Session, InvalidPathError
from model_server.base.validators import validate_workflow_inputs
from model_server.base.workflows import classify_pixels
......@@ -65,10 +65,14 @@ def restart_session(root: str = None) -> dict:
def list_active_models():
return session.describe_loaded_models()
@app.put('/models/dummy/load/')
@app.put('/models/dummy_semantic/load/')
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummySemanticSegmentationModel)}
@app.put('/models/dummy_instance/load/')
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyInstanceSegmentationModel)}
@app.put('/workflows/segment')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
inpath = session.paths['inbound_images'] / input_filename
......
......@@ -78,8 +78,8 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.content, b'{}')
def test_load_dummy_model(self):
resp_load = self._put(f'models/dummy/load')
def test_load_dummy_semantic_model(self):
resp_load = self._put(f'models/dummy_semantic/load')
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = self._get('models')
......@@ -88,8 +88,18 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(rj[model_id]['class'], 'DummySemanticSegmentationModel')
return model_id
def test_load_dummy_instance_model(self):
resp_load = self._put(f'models/dummy_instance/load')
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = self._get('models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'DummyInstanceSegmentationModel')
return model_id
def test_respond_with_error_when_invalid_filepath_requested(self):
model_id = self.test_load_dummy_model()
model_id = self.test_load_dummy_semantic_model()
resp = self._put(
f'infer/from_image_file',
......@@ -107,7 +117,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp.status_code, 409, resp.content.decode())
def test_i2i_dummy_inference_by_api(self):
model_id = self.test_load_dummy_model()
model_id = self.test_load_dummy_semantic_model()
self.copy_input_file_to_server()
resp_infer = self._put(
f'workflows/segment',
......@@ -120,7 +130,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
def test_restarting_session_clears_loaded_models(self):
resp_load = self._put(f'models/dummy/load',)
resp_load = self._put(f'models/dummy_semantic/load',)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list_0 = self._get('models')
self.assertEqual(resp_list_0.status_code, 200)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment