From fa7ebaab8eb6ef1ea0330fa04f0dbe47d5084563 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 1 Jul 2024 08:15:11 +0200 Subject: [PATCH] Playing with importing session object only --- model_server/base/api.py | 3 +-- model_server/base/session.py | 6 +++++- model_server/base/validators.py | 4 +--- model_server/extensions/ilastik/router.py | 4 +--- tests/test_session.py | 18 +++++++++--------- 5 files changed, 17 insertions(+), 18 deletions(-) diff --git a/model_server/base/api.py b/model_server/base/api.py index 22554a06..d1411749 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -2,13 +2,12 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel -from base.session import Session, InvalidPathError +from base.session import session, InvalidPathError from base.validators import validate_workflow_inputs from base.workflows import classify_pixels from extensions.ilastik.workflows import infer_px_then_ob_model app = FastAPI(debug=True) -session = Session() import model_server.extensions.ilastik.router app.include_router(model_server.extensions.ilastik.router.router) diff --git a/model_server/base/session.py b/model_server/base/session.py index 20987ba5..f44138bd 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -13,7 +13,6 @@ from base.models import Model logger = logging.getLogger(__name__) - class Singleton(type): _instances = {} @@ -193,6 +192,11 @@ class Session(object, metaclass=Singleton): def restart(self, **kwargs): self.__init__(**kwargs) + +# create singleton instance +session = Session() + + class Error(Exception): pass diff --git a/model_server/base/validators.py b/model_server/base/validators.py index c55adb19..3755497a 100644 --- a/model_server/base/validators.py +++ b/model_server/base/validators.py @@ -1,8 +1,6 @@ from fastapi import HTTPException -from base.session import Session - -session = Session() +from base.session import session def validate_workflow_inputs(model_ids, inpaths): for mid in model_ids: diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 491e7e40..9a636c3e 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException -from model_server.base.session import Session +from model_server.base.session import session from model_server.base.validators import validate_workflow_inputs from model_server.extensions.ilastik import models as ilm @@ -12,8 +12,6 @@ router = APIRouter( tags=['ilastik'], ) -session = Session() - def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict: """ diff --git a/tests/test_session.py b/tests/test_session.py index ce49dbce..50be8565 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -4,23 +4,23 @@ from pydantic import BaseModel import unittest from base.models import DummySemanticSegmentationModel -from base.session import Session +from base.session import session class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.sesh = Session() + self.sesh = session def tearDown(self) -> None: print('Tearing down...') - Session._instances = {} + session.__class__._instances = {} def test_session_is_singleton(self): - Session._instances = {} - self.assertEqual(len(Session._instances), 0) - s = Session() - self.assertEqual(len(Session._instances), 1) - self.assertIs(s, Session()) - self.assertEqual(len(Session._instances), 1) + session.__class__._instances = {} + self.assertEqual(len(session.__class__._instances), 0) + s = session.__class__() + self.assertEqual(len(session.__class__._instances), 1) + self.assertIs(s, session.__class__()) + self.assertEqual(len(session.__class__._instances), 1) def test_session_logfile_is_valid(self): self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place') -- GitLab