diff --git a/model_server/base/api.py b/model_server/base/api.py
index f375894d92dea4156de46786ce499916e551d399..d1411749b2896f521fc131e9dfc535358e52d36d 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -2,16 +2,15 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 
 from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
-from base.session import InvalidPathError, Session
+from base.session import session, InvalidPathError
 from base.validators import validate_workflow_inputs
 from base.workflows import classify_pixels
 from extensions.ilastik.workflows import infer_px_then_ob_model
 
 app = FastAPI(debug=True)
-session = Session()
 
-import extensions.ilastik.router
-app.include_router(extensions.ilastik.router.router)
+import model_server.extensions.ilastik.router
+app.include_router(model_server.extensions.ilastik.router.router)
 
 @app.on_event("startup")
 def startup():
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 20987ba556138ff65850b1f0f73b9ce2f8bc68cd..1308e2aec3d203ad7a33b55d9211c68d38d92c77 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -14,14 +14,6 @@ from base.models import Model
 logger = logging.getLogger(__name__)
 
 
-class Singleton(type):
-    _instances = {}
-
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
-        return cls._instances[cls]
-
 class CsvTable(object):
     def __init__(self, fpath: Path):
         self.path = fpath
@@ -38,7 +30,7 @@ class CsvTable(object):
         self.empty = False
         return True
 
-class Session(object, metaclass=Singleton):
+class _Session(object):
     """
     Singleton class for a server session that persists data between API calls
     """
@@ -100,7 +92,7 @@ class Session(object, metaclass=Singleton):
             root_path = Path(conf.defaults.root)
         else:
             root_path = Path(root)
-        sid = Session.create_session_id(root_path)
+        sid = _Session.create_session_id(root_path)
         paths = {'root': root_path}
         for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']:
             pa = root_path / sid / conf.defaults.subdirectories[pk]
@@ -193,6 +185,11 @@ class Session(object, metaclass=Singleton):
     def restart(self, **kwargs):
         self.__init__(**kwargs)
 
+
+# create singleton instance
+session = _Session()
+
+
 class Error(Exception):
     pass
 
diff --git a/model_server/base/validators.py b/model_server/base/validators.py
index c55adb192826bba255de2d86b2e6195c57b25083..3755497a4b4e0070d4bd4dc56214c91b09c52287 100644
--- a/model_server/base/validators.py
+++ b/model_server/base/validators.py
@@ -1,8 +1,6 @@
 from fastapi import HTTPException
 
-from base.session import Session
-
-session = Session()
+from base.session import session
 
 def validate_workflow_inputs(model_ids, inpaths):
     for mid in model_ids:
diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py
index 76e1364ee1fe9380b43b6b5f0ef272dd9a17e7fb..9a636c3e0399c9622b576d1496c3950346c567f5 100644
--- a/model_server/extensions/ilastik/router.py
+++ b/model_server/extensions/ilastik/router.py
@@ -1,18 +1,17 @@
 from fastapi import APIRouter, HTTPException
 
-from base.session import Session
-from base.validators import  validate_workflow_inputs
+from model_server.base.session import session
+from model_server.base.validators import  validate_workflow_inputs
 
-from extensions.ilastik import models as ilm
-from extensions.ilastik.workflows import infer_px_then_ob_model
+from model_server.extensions.ilastik import models as ilm
+from model_server.base.models import ParameterExpectedError
+from model_server.extensions.ilastik.workflows import infer_px_then_ob_model
 
 router = APIRouter(
     prefix='/ilastik',
     tags=['ilastik'],
 )
 
-session = Session()
-
 
 def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict:
     """
diff --git a/tests/test_session.py b/tests/test_session.py
index ce49dbceef0d43620cf94f41a9fbed05fb2cf6e0..929bdc44afd21709c8226778b936029bb41620bf 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -3,31 +3,20 @@ import pathlib
 from pydantic import BaseModel
 import unittest
 
+import base.session
 from base.models import DummySemanticSegmentationModel
-from base.session import Session
+from base.session import session
 
 class TestGetSessionObject(unittest.TestCase):
     def setUp(self) -> None:
-        self.sesh = Session()
-
-    def tearDown(self) -> None:
-        print('Tearing down...')
-        Session._instances = {}
-
-    def test_session_is_singleton(self):
-        Session._instances = {}
-        self.assertEqual(len(Session._instances), 0)
-        s = Session()
-        self.assertEqual(len(Session._instances), 1)
-        self.assertIs(s, Session())
-        self.assertEqual(len(Session._instances), 1)
+        session.restart()
+        self.sesh = session
 
     def test_session_logfile_is_valid(self):
         self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
 
     def test_changing_session_root_creates_new_directory(self):
         from model_server.conf.defaults import root
-        from shutil import rmtree
 
         old_paths = self.sesh.get_paths()
         newroot = root / 'subdir'
@@ -36,12 +25,6 @@ class TestGetSessionObject(unittest.TestCase):
         for k in old_paths.keys():
             self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
 
-        # this is necessary because logger itself is a singleton class
-        self.tearDown()
-        self.setUp()
-        rmtree(newroot)
-        self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
-
     def test_change_session_subdirectory(self):
         old_paths = self.sesh.get_paths()
         print(old_paths)
@@ -56,6 +39,23 @@ class TestGetSessionObject(unittest.TestCase):
         self.assertTrue(logfile2.exists())
         self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
 
+    def test_reimporting_session_uses_same_logfile(self):
+        session1 = self.sesh
+        logfile1 = session1.logfile
+        self.assertTrue(logfile1.exists())
+
+        from base.session import session as session2
+        self.assertEqual(session1, session2)
+        logfile2 = session2.logfile
+        self.assertTrue(logfile2.exists())
+        self.assertEqual(logfile1, logfile2, 'Reimporting session incorrectly creates new logfile')
+
+        from model_server.base.session import session as session3
+        self.assertEqual(session1, session3)
+        logfile3 = session3.logfile
+        self.assertTrue(logfile3.exists())
+        self.assertEqual(logfile1, logfile3, 'Reimporting session incorrectly creates new logfile')
+
     def test_log_warning(self):
         msg = 'A test warning'
         self.sesh.log_info(msg)