diff --git a/model_server/base/api.py b/model_server/base/api.py index 22554a062c65bb468b570522b6b96a17ca716b2d..d1411749b2896f521fc131e9dfc535358e52d36d 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -2,13 +2,12 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel -from base.session import Session, InvalidPathError +from base.session import session, InvalidPathError from base.validators import validate_workflow_inputs from base.workflows import classify_pixels from extensions.ilastik.workflows import infer_px_then_ob_model app = FastAPI(debug=True) -session = Session() import model_server.extensions.ilastik.router app.include_router(model_server.extensions.ilastik.router.router) diff --git a/model_server/base/session.py b/model_server/base/session.py index 20987ba556138ff65850b1f0f73b9ce2f8bc68cd..f44138bded00ac221ad2afdbbc83e2faa085bdab 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -13,7 +13,6 @@ from base.models import Model logger = logging.getLogger(__name__) - class Singleton(type): _instances = {} @@ -193,6 +192,11 @@ class Session(object, metaclass=Singleton): def restart(self, **kwargs): self.__init__(**kwargs) + +# create singleton instance +session = Session() + + class Error(Exception): pass diff --git a/model_server/base/validators.py b/model_server/base/validators.py index c55adb192826bba255de2d86b2e6195c57b25083..3755497a4b4e0070d4bd4dc56214c91b09c52287 100644 --- a/model_server/base/validators.py +++ b/model_server/base/validators.py @@ -1,8 +1,6 @@ from fastapi import HTTPException -from base.session import Session - -session = Session() +from base.session import session def validate_workflow_inputs(model_ids, inpaths): for mid in model_ids: diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 491e7e40ab678c34426ef0ae8a846c2fe065ea2f..9a636c3e0399c9622b576d1496c3950346c567f5 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException -from model_server.base.session import Session +from model_server.base.session import session from model_server.base.validators import validate_workflow_inputs from model_server.extensions.ilastik import models as ilm @@ -12,8 +12,6 @@ router = APIRouter( tags=['ilastik'], ) -session = Session() - def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict: """ diff --git a/tests/test_session.py b/tests/test_session.py index ce49dbceef0d43620cf94f41a9fbed05fb2cf6e0..50be856547a4ee68f1a44cfa3e043dcd0b1b8447 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -4,23 +4,23 @@ from pydantic import BaseModel import unittest from base.models import DummySemanticSegmentationModel -from base.session import Session +from base.session import session class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.sesh = Session() + self.sesh = session def tearDown(self) -> None: print('Tearing down...') - Session._instances = {} + session.__class__._instances = {} def test_session_is_singleton(self): - Session._instances = {} - self.assertEqual(len(Session._instances), 0) - s = Session() - self.assertEqual(len(Session._instances), 1) - self.assertIs(s, Session()) - self.assertEqual(len(Session._instances), 1) + session.__class__._instances = {} + self.assertEqual(len(session.__class__._instances), 0) + s = session.__class__() + self.assertEqual(len(session.__class__._instances), 1) + self.assertIs(s, session.__class__()) + self.assertEqual(len(session.__class__._instances), 1) def test_session_logfile_is_valid(self): self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')