diff --git a/model_server/base/api.py b/model_server/base/api.py
index 22554a062c65bb468b570522b6b96a17ca716b2d..d1411749b2896f521fc131e9dfc535358e52d36d 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -2,13 +2,12 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 
 from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
-from base.session import Session, InvalidPathError
+from base.session import session, InvalidPathError
 from base.validators import validate_workflow_inputs
 from base.workflows import classify_pixels
 from extensions.ilastik.workflows import infer_px_then_ob_model
 
 app = FastAPI(debug=True)
-session = Session()
 
 import model_server.extensions.ilastik.router
 app.include_router(model_server.extensions.ilastik.router.router)
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 20987ba556138ff65850b1f0f73b9ce2f8bc68cd..f44138bded00ac221ad2afdbbc83e2faa085bdab 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -13,7 +13,6 @@ from base.models import Model
 
 logger = logging.getLogger(__name__)
 
-
 class Singleton(type):
     _instances = {}
 
@@ -193,6 +192,11 @@ class Session(object, metaclass=Singleton):
     def restart(self, **kwargs):
         self.__init__(**kwargs)
 
+
+# create singleton instance
+session = Session()
+
+
 class Error(Exception):
     pass
 
diff --git a/model_server/base/validators.py b/model_server/base/validators.py
index c55adb192826bba255de2d86b2e6195c57b25083..3755497a4b4e0070d4bd4dc56214c91b09c52287 100644
--- a/model_server/base/validators.py
+++ b/model_server/base/validators.py
@@ -1,8 +1,6 @@
 from fastapi import HTTPException
 
-from base.session import Session
-
-session = Session()
+from base.session import session
 
 def validate_workflow_inputs(model_ids, inpaths):
     for mid in model_ids:
diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py
index 491e7e40ab678c34426ef0ae8a846c2fe065ea2f..9a636c3e0399c9622b576d1496c3950346c567f5 100644
--- a/model_server/extensions/ilastik/router.py
+++ b/model_server/extensions/ilastik/router.py
@@ -1,6 +1,6 @@
 from fastapi import APIRouter, HTTPException
 
-from model_server.base.session import Session
+from model_server.base.session import session
 from model_server.base.validators import  validate_workflow_inputs
 
 from model_server.extensions.ilastik import models as ilm
@@ -12,8 +12,6 @@ router = APIRouter(
     tags=['ilastik'],
 )
 
-session = Session()
-
 
 def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict:
     """
diff --git a/tests/test_session.py b/tests/test_session.py
index ce49dbceef0d43620cf94f41a9fbed05fb2cf6e0..50be856547a4ee68f1a44cfa3e043dcd0b1b8447 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -4,23 +4,23 @@ from pydantic import BaseModel
 import unittest
 
 from base.models import DummySemanticSegmentationModel
-from base.session import Session
+from base.session import session
 
 class TestGetSessionObject(unittest.TestCase):
     def setUp(self) -> None:
-        self.sesh = Session()
+        self.sesh = session
 
     def tearDown(self) -> None:
         print('Tearing down...')
-        Session._instances = {}
+        session.__class__._instances = {}
 
     def test_session_is_singleton(self):
-        Session._instances = {}
-        self.assertEqual(len(Session._instances), 0)
-        s = Session()
-        self.assertEqual(len(Session._instances), 1)
-        self.assertIs(s, Session())
-        self.assertEqual(len(Session._instances), 1)
+        session.__class__._instances = {}
+        self.assertEqual(len(session.__class__._instances), 0)
+        s = session.__class__()
+        self.assertEqual(len(session.__class__._instances), 1)
+        self.assertIs(s, session.__class__())
+        self.assertEqual(len(session.__class__._instances), 1)
 
     def test_session_logfile_is_valid(self):
         self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')