Skip to content
Snippets Groups Projects
Commit d02ee357 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Added session method to get accessor object with and without popping

parent 9c735683
No related branches found
No related tags found
No related merge requests found
......@@ -82,7 +82,7 @@ class _Session(object):
raise InvalidPathError(f'Could not find {path}')
self.paths[key] = Path(path)
def add_accessor(self, acc: GenericImageDataAccessor, accessor_id: str = None):
def add_accessor(self, acc: GenericImageDataAccessor, accessor_id: str = None) -> str:
"""
Add an accessor to session context
:param acc: accessor to add
......@@ -97,7 +97,7 @@ class _Session(object):
self.accessors[accessor_id] = {'loaded': True, 'object': acc, **acc.info}
return accessor_id
def del_accessor(self, accessor_id: str):
def del_accessor(self, accessor_id: str) -> str:
"""
Remove accessor object but retain its info dictionary
:param accessor_id: accessor's ID
......@@ -117,11 +117,10 @@ class _Session(object):
def list_accessors(self) -> dict:
"""
List information about all accessors in JSON-readable format
:return:
"""
return pd.DataFrame(self.accessors).drop('object').to_dict()
def get_accessor_info(self, acc_id: str):
def get_accessor_info(self, acc_id: str) -> dict:
"""
Get information about a single accessor
"""
......@@ -129,6 +128,20 @@ class _Session(object):
raise AccessorIdError(f'No accessor with ID {acc_id} is registered')
return self.list_accessors()[acc_id]
def get_accessor(self, acc_id: str, pop: bool = True) -> GenericImageDataAccessor:
"""
Return an accessor object
:param acc_id: accessor's ID
:param pop: remove object from session accessor registry if True
:return: accessor object
"""
if acc_id not in self.accessors.keys():
raise AccessorIdError(f'No accessor with ID {acc_id} is registered')
acc = self.accessors[acc_id]['object']
if pop:
self.del_accessor(acc_id)
return acc
@staticmethod
def make_paths(root: str = None) -> dict:
"""
......
......@@ -89,14 +89,50 @@ class TestGetSessionObject(unittest.TestCase):
def test_add_and_remove_accessor(self):
w = 256
h = 512
nc = 4
nz = 11
sh = (h, w, nc, nz)
acc = InMemoryDataAccessor(np.random.randint(0, 2 ** 8, size=sh, dtype='uint8'))
acc_id = session.add_accessor(acc)
acc = InMemoryDataAccessor(
np.random.randint(
0,
2 ** 8,
size=(512, 256, 3, 7),
dtype='uint8'
)
)
shd = acc.shape_dict
self.assertEqual(session.accessors[acc_id].shape, sh)
# add accessor to session registry
acc_id = session.add_accessor(acc)
self.assertEqual(session.get_accessor_info(acc_id)['shape_dict'], shd)
self.assertTrue(session.get_accessor_info(acc_id)['loaded'])
# remove accessor from session registry
session.del_accessor(acc_id)
self.assertEqual(session.accessors[acc_id]['shape_dict'], shd)
self.assertEqual(session.get_accessor_info(acc_id)['shape_dict'], shd)
self.assertFalse(session.get_accessor_info(acc_id)['loaded'])
def test_add_and_use_accessor(self):
acc = InMemoryDataAccessor(
np.random.randint(
0,
2 ** 8,
size=(512, 256, 3, 7),
dtype='uint8'
)
)
shd = acc.shape_dict
# add accessor to session registry
acc_id = session.add_accessor(acc)
self.assertEqual(session.get_accessor_info(acc_id)['shape_dict'], shd)
self.assertTrue(session.get_accessor_info(acc_id)['loaded'])
# get accessor from session registry without popping
acc_get = session.get_accessor(acc_id, pop=False)
self.assertIsInstance(acc_get, InMemoryDataAccessor)
self.assertEqual(acc_get.shape_dict, shd)
self.assertTrue(session.get_accessor_info(acc_id)['loaded'])
# get accessor from session registry with popping
acc_get = session.get_accessor(acc_id)
self.assertIsInstance(acc_get, InMemoryDataAccessor)
self.assertEqual(acc_get.shape_dict, shd)
self.assertFalse(session.get_accessor_info(acc_id)['loaded'])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment