diff --git a/model_server/base/api.py b/model_server/base/api.py index 1715d9ab73c2d3e73ec87d342b01a44f3a1a78f9..96620aefb05e30fedcf06dadbca9b76ffddbdf2a 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -87,7 +87,10 @@ def get_accessor(accessor_id: str): @app.get('/accessors/delete/{accessor_id}') def delete_accessor(accessor_id: str): - return _session_accessor(session.del_accessor, accessor_id) + if accessor_id == '*': + return session.del_all_accessors() + else: + return _session_accessor(session.del_accessor, accessor_id) @app.put('/accessors/read_from_file/{filename}') @@ -106,6 +109,4 @@ def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) except AccessorIdError as e: raise HTTPException(404, f'Did not find accessor with ID {accessor_id}') except WriteAccessorError as e: - raise HTTPException(409, str(e)) - -# TODO: endpoint to unload all accessors \ No newline at end of file + raise HTTPException(409, str(e)) \ No newline at end of file diff --git a/model_server/base/session.py b/model_server/base/session.py index f22655c6202014e3bc5fb9d0e18133869f22770c..7e06003f8be9407f6a8ee4f1816b70fac12f1900 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -114,6 +114,20 @@ class _Session(object): v['object'] = None return accessor_id + def del_all_accessors(self) -> list[str]: + """ + Remove (unload) all accessors but keep their info in dictionary + :return: list of removed accessor IDs + """ + res = [] + for k, v in self.accessors.items(): + if v['loaded']: + v['object'] = None + v['loaded'] = False + res.append(k) + return res + + def list_accessors(self) -> dict: """ List information about all accessors in JSON-readable format diff --git a/tests/base/test_api.py b/tests/base/test_api.py index a6aad378e4f8a875573afbcc88308f55a0f8573c..f368dc05212c59f1b2ece208a719b4e249cedf70 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -240,6 +240,13 @@ class TestApiFromAutomatedClient(TestServerTestCase): resp_wrong_acc = self._get('accessors/auto_123456') self.assertEqual(resp_wrong_acc.status_code, 404) + # load another... then remove all + self._put(f'accessors/read_from_file/{fname}') + self.assertEqual(sum([v['loaded'] for v in self._get('accessors').json().values()]), 1) + self.assertEqual(len(self._get(f'accessors/delete/*').json()), 1) + self.assertEqual(sum([v['loaded'] for v in self._get('accessors').json().values()]), 0) + + def test_empty_accessor_list(self): resp_list_acc = self._get( f'accessors',