From fa7ebaab8eb6ef1ea0330fa04f0dbe47d5084563 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 1 Jul 2024 08:15:11 +0200
Subject: [PATCH] Playing with importing session object only

---
 model_server/base/api.py                  |  3 +--
 model_server/base/session.py              |  6 +++++-
 model_server/base/validators.py           |  4 +---
 model_server/extensions/ilastik/router.py |  4 +---
 tests/test_session.py                     | 18 +++++++++---------
 5 files changed, 17 insertions(+), 18 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 22554a06..d1411749 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -2,13 +2,12 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 
 from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
-from base.session import Session, InvalidPathError
+from base.session import session, InvalidPathError
 from base.validators import validate_workflow_inputs
 from base.workflows import classify_pixels
 from extensions.ilastik.workflows import infer_px_then_ob_model
 
 app = FastAPI(debug=True)
-session = Session()
 
 import model_server.extensions.ilastik.router
 app.include_router(model_server.extensions.ilastik.router.router)
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 20987ba5..f44138bd 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -13,7 +13,6 @@ from base.models import Model
 
 logger = logging.getLogger(__name__)
 
-
 class Singleton(type):
     _instances = {}
 
@@ -193,6 +192,11 @@ class Session(object, metaclass=Singleton):
     def restart(self, **kwargs):
         self.__init__(**kwargs)
 
+
+# create singleton instance
+session = Session()
+
+
 class Error(Exception):
     pass
 
diff --git a/model_server/base/validators.py b/model_server/base/validators.py
index c55adb19..3755497a 100644
--- a/model_server/base/validators.py
+++ b/model_server/base/validators.py
@@ -1,8 +1,6 @@
 from fastapi import HTTPException
 
-from base.session import Session
-
-session = Session()
+from base.session import session
 
 def validate_workflow_inputs(model_ids, inpaths):
     for mid in model_ids:
diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py
index 491e7e40..9a636c3e 100644
--- a/model_server/extensions/ilastik/router.py
+++ b/model_server/extensions/ilastik/router.py
@@ -1,6 +1,6 @@
 from fastapi import APIRouter, HTTPException
 
-from model_server.base.session import Session
+from model_server.base.session import session
 from model_server.base.validators import  validate_workflow_inputs
 
 from model_server.extensions.ilastik import models as ilm
@@ -12,8 +12,6 @@ router = APIRouter(
     tags=['ilastik'],
 )
 
-session = Session()
-
 
 def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict:
     """
diff --git a/tests/test_session.py b/tests/test_session.py
index ce49dbce..50be8565 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -4,23 +4,23 @@ from pydantic import BaseModel
 import unittest
 
 from base.models import DummySemanticSegmentationModel
-from base.session import Session
+from base.session import session
 
 class TestGetSessionObject(unittest.TestCase):
     def setUp(self) -> None:
-        self.sesh = Session()
+        self.sesh = session
 
     def tearDown(self) -> None:
         print('Tearing down...')
-        Session._instances = {}
+        session.__class__._instances = {}
 
     def test_session_is_singleton(self):
-        Session._instances = {}
-        self.assertEqual(len(Session._instances), 0)
-        s = Session()
-        self.assertEqual(len(Session._instances), 1)
-        self.assertIs(s, Session())
-        self.assertEqual(len(Session._instances), 1)
+        session.__class__._instances = {}
+        self.assertEqual(len(session.__class__._instances), 0)
+        s = session.__class__()
+        self.assertEqual(len(session.__class__._instances), 1)
+        self.assertIs(s, session.__class__())
+        self.assertEqual(len(session.__class__._instances), 1)
 
     def test_session_logfile_is_valid(self):
         self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
-- 
GitLab