From 18449743f9fa64f1987cbd42006711838966de6b Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 15 Aug 2024 17:07:32 +0200 Subject: [PATCH] Added option to remove all accessors --- model_server/base/api.py | 9 +++++---- model_server/base/session.py | 14 ++++++++++++++ tests/base/test_api.py | 7 +++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/model_server/base/api.py b/model_server/base/api.py index 1715d9ab..96620aef 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -87,7 +87,10 @@ def get_accessor(accessor_id: str): @app.get('/accessors/delete/{accessor_id}') def delete_accessor(accessor_id: str): - return _session_accessor(session.del_accessor, accessor_id) + if accessor_id == '*': + return session.del_all_accessors() + else: + return _session_accessor(session.del_accessor, accessor_id) @app.put('/accessors/read_from_file/{filename}') @@ -106,6 +109,4 @@ def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) except AccessorIdError as e: raise HTTPException(404, f'Did not find accessor with ID {accessor_id}') except WriteAccessorError as e: - raise HTTPException(409, str(e)) - -# TODO: endpoint to unload all accessors \ No newline at end of file + raise HTTPException(409, str(e)) \ No newline at end of file diff --git a/model_server/base/session.py b/model_server/base/session.py index f22655c6..7e06003f 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -114,6 +114,20 @@ class _Session(object): v['object'] = None return accessor_id + def del_all_accessors(self) -> list[str]: + """ + Remove (unload) all accessors but keep their info in dictionary + :return: list of removed accessor IDs + """ + res = [] + for k, v in self.accessors.items(): + if v['loaded']: + v['object'] = None + v['loaded'] = False + res.append(k) + return res + + def list_accessors(self) -> dict: """ List information about all accessors in JSON-readable format diff --git a/tests/base/test_api.py b/tests/base/test_api.py index a6aad378..f368dc05 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -240,6 +240,13 @@ class TestApiFromAutomatedClient(TestServerTestCase): resp_wrong_acc = self._get('accessors/auto_123456') self.assertEqual(resp_wrong_acc.status_code, 404) + # load another... then remove all + self._put(f'accessors/read_from_file/{fname}') + self.assertEqual(sum([v['loaded'] for v in self._get('accessors').json().values()]), 1) + self.assertEqual(len(self._get(f'accessors/delete/*').json()), 1) + self.assertEqual(sum([v['loaded'] for v in self._get('accessors').json().values()]), 0) + + def test_empty_accessor_list(self): resp_list_acc = self._get( f'accessors', -- GitLab