Skip to content
Snippets Groups Projects
Commit 10beaebd authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test covers accessor output

parent 5824cb5b
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ from typing import Union
from fastapi import FastAPI, HTTPException
from .accessors import generate_file_accessor
from .session import session, AccessorIdError, InvalidPathError
from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError
app = FastAPI(debug=True)
......@@ -86,5 +86,11 @@ def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None)
acc = generate_file_accessor(fp)
return session.add_accessor(acc, accessor_id=accessor_id)
@app.put('/accessors/write_to_file/{accessor_id}')
def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) -> str:
try:
return session.write_accessor(accessor_id, filename)
except AccessorIdError as e:
raise HTTPException(404, f'Did not find accessor with ID {accessor_id}')
except WriteAccessorError as e:
raise HTTPException(409, str(e))
\ No newline at end of file
......@@ -118,7 +118,10 @@ class _Session(object):
"""
List information about all accessors in JSON-readable format
"""
return pd.DataFrame(self.accessors).drop('object').to_dict()
if len(self.accessors):
return pd.DataFrame(self.accessors).drop('object').to_dict()
else:
return {}
def get_accessor_info(self, acc_id: str) -> dict:
"""
......@@ -142,6 +145,26 @@ class _Session(object):
self.del_accessor(acc_id)
return acc
def write_accessor(self, acc_id: str, filename: Union[str, None] = None) -> str:
"""
Write an accessor to file and unload it from the session
"""
if filename is None:
fp = self.paths['outbound_images'] / f'{acc_id}.tif'
else:
fp = self.paths['outbound_images'] / filename
acc = self.get_accessor(acc_id, pop=True)
old_fp = self.accessors[acc_id]['filepath']
if old_fp != '':
raise WriteAccessorError(
f'Cannot overwrite accessor that is already written to {old_fp}'
)
acc.write(fp)
self.accessors[acc_id]['filepath'] = fp.__str__()
return fp.__str__()
@staticmethod
def make_paths(root: str = None) -> dict:
"""
......@@ -273,6 +296,9 @@ class CouldNotInstantiateModelError(Error):
class AccessorIdError(Error):
pass
class WriteAccessorError(Error):
pass
class CouldNotCreateDirectory(Error):
pass
......
from pathlib import Path
from fastapi import APIRouter
import numpy as np
from pydantic import BaseModel
import model_server.conf.testing as conf
from model_server.base.accessors import InMemoryDataAccessor, generate_file_accessor
from model_server.base.api import app
from model_server.base.session import session
from tests.base.test_model import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
......@@ -24,6 +26,18 @@ class BounceBackParams(BaseModel):
def list_bounce_back(params: BounceBackParams):
return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}}
@router.put('/accessors/dummy_accessor/load')
def load_dummy_accessor() -> str:
acc = InMemoryDataAccessor(
np.random.randint(
0,
2 ** 8,
size=(512, 256, 3, 7),
dtype='uint8'
)
)
return session.add_accessor(acc)
@router.put('/models/dummy_semantic/load/')
def load_dummy_model() -> dict:
mid = session.load_model(DummySemanticSegmentationModel)
......@@ -220,3 +234,25 @@ class TestApiFromAutomatedClient(TestServerTestCase):
# and try a non-existent accessor ID
resp_wrong_acc = self._get('accessors/auto_123456')
self.assertEqual(resp_wrong_acc.status_code, 404)
def test_empty_accessor_list(self):
resp_list_acc = self._get(
f'accessors',
)
self.assertEqual(len(resp_list_acc.json()), 0)
def test_write_accessor(self):
acc_id = self._put('/testing/accessors/dummy_accessor/load').json()
self.assertTrue(self._get(f'accessors/{acc_id}').json()['loaded'])
sd = self._get(f'accessors/{acc_id}').json()['shape_dict']
self.assertEqual(self._get(f'accessors/{acc_id}').json()['filepath'], '')
filename = 'test_output.tif'
self._put(f'/accessors/write_to_file/{acc_id}', query={'filename': filename})
where_out = self._get('paths').json()['outbound_images']
fp_out = (Path(where_out) / filename)
self.assertTrue(fp_out.exists())
acc_out = generate_file_accessor(fp_out)
self.assertEqual(sd, acc_out.shape_dict)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment