diff --git a/model_server/base/api.py b/model_server/base/api.py index f375894d92dea4156de46786ce499916e551d399..d1411749b2896f521fc131e9dfc535358e52d36d 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -2,16 +2,15 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel -from base.session import InvalidPathError, Session +from base.session import session, InvalidPathError from base.validators import validate_workflow_inputs from base.workflows import classify_pixels from extensions.ilastik.workflows import infer_px_then_ob_model app = FastAPI(debug=True) -session = Session() -import extensions.ilastik.router -app.include_router(extensions.ilastik.router.router) +import model_server.extensions.ilastik.router +app.include_router(model_server.extensions.ilastik.router.router) @app.on_event("startup") def startup(): diff --git a/model_server/base/session.py b/model_server/base/session.py index 20987ba556138ff65850b1f0f73b9ce2f8bc68cd..1308e2aec3d203ad7a33b55d9211c68d38d92c77 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -14,14 +14,6 @@ from base.models import Model logger = logging.getLogger(__name__) -class Singleton(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] - class CsvTable(object): def __init__(self, fpath: Path): self.path = fpath @@ -38,7 +30,7 @@ class CsvTable(object): self.empty = False return True -class Session(object, metaclass=Singleton): +class _Session(object): """ Singleton class for a server session that persists data between API calls """ @@ -100,7 +92,7 @@ class Session(object, metaclass=Singleton): root_path = Path(conf.defaults.root) else: root_path = Path(root) - sid = Session.create_session_id(root_path) + sid = _Session.create_session_id(root_path) paths = {'root': root_path} for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']: pa = root_path / sid / conf.defaults.subdirectories[pk] @@ -193,6 +185,11 @@ class Session(object, metaclass=Singleton): def restart(self, **kwargs): self.__init__(**kwargs) + +# create singleton instance +session = _Session() + + class Error(Exception): pass diff --git a/model_server/base/validators.py b/model_server/base/validators.py index c55adb192826bba255de2d86b2e6195c57b25083..3755497a4b4e0070d4bd4dc56214c91b09c52287 100644 --- a/model_server/base/validators.py +++ b/model_server/base/validators.py @@ -1,8 +1,6 @@ from fastapi import HTTPException -from base.session import Session - -session = Session() +from base.session import session def validate_workflow_inputs(model_ids, inpaths): for mid in model_ids: diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 76e1364ee1fe9380b43b6b5f0ef272dd9a17e7fb..9a636c3e0399c9622b576d1496c3950346c567f5 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -1,18 +1,17 @@ from fastapi import APIRouter, HTTPException -from base.session import Session -from base.validators import validate_workflow_inputs +from model_server.base.session import session +from model_server.base.validators import validate_workflow_inputs -from extensions.ilastik import models as ilm -from extensions.ilastik.workflows import infer_px_then_ob_model +from model_server.extensions.ilastik import models as ilm +from model_server.base.models import ParameterExpectedError +from model_server.extensions.ilastik.workflows import infer_px_then_ob_model router = APIRouter( prefix='/ilastik', tags=['ilastik'], ) -session = Session() - def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict: """ diff --git a/tests/test_session.py b/tests/test_session.py index ce49dbceef0d43620cf94f41a9fbed05fb2cf6e0..929bdc44afd21709c8226778b936029bb41620bf 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -3,31 +3,20 @@ import pathlib from pydantic import BaseModel import unittest +import base.session from base.models import DummySemanticSegmentationModel -from base.session import Session +from base.session import session class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.sesh = Session() - - def tearDown(self) -> None: - print('Tearing down...') - Session._instances = {} - - def test_session_is_singleton(self): - Session._instances = {} - self.assertEqual(len(Session._instances), 0) - s = Session() - self.assertEqual(len(Session._instances), 1) - self.assertIs(s, Session()) - self.assertEqual(len(Session._instances), 1) + session.restart() + self.sesh = session def test_session_logfile_is_valid(self): self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place') def test_changing_session_root_creates_new_directory(self): from model_server.conf.defaults import root - from shutil import rmtree old_paths = self.sesh.get_paths() newroot = root / 'subdir' @@ -36,12 +25,6 @@ class TestGetSessionObject(unittest.TestCase): for k in old_paths.keys(): self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__())) - # this is necessary because logger itself is a singleton class - self.tearDown() - self.setUp() - rmtree(newroot) - self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory') - def test_change_session_subdirectory(self): old_paths = self.sesh.get_paths() print(old_paths) @@ -56,6 +39,23 @@ class TestGetSessionObject(unittest.TestCase): self.assertTrue(logfile2.exists()) self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile') + def test_reimporting_session_uses_same_logfile(self): + session1 = self.sesh + logfile1 = session1.logfile + self.assertTrue(logfile1.exists()) + + from base.session import session as session2 + self.assertEqual(session1, session2) + logfile2 = session2.logfile + self.assertTrue(logfile2.exists()) + self.assertEqual(logfile1, logfile2, 'Reimporting session incorrectly creates new logfile') + + from model_server.base.session import session as session3 + self.assertEqual(session1, session3) + logfile3 = session3.logfile + self.assertTrue(logfile3.exists()) + self.assertEqual(logfile1, logfile3, 'Reimporting session incorrectly creates new logfile') + def test_log_warning(self): msg = 'A test warning' self.sesh.log_info(msg)