Skip to content
Snippets Groups Projects
Commit bcc66f86 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'issue_singleton_namespace' into dev_conda_build

# Conflicts:
#	model_server/base/api.py
#	model_server/extensions/ilastik/router.py
parents daa63b51 6ad29bb8
No related branches found
No related tags found
No related merge requests found
......@@ -2,16 +2,15 @@ from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
from base.session import InvalidPathError, Session
from base.session import session, InvalidPathError
from base.validators import validate_workflow_inputs
from base.workflows import classify_pixels
from extensions.ilastik.workflows import infer_px_then_ob_model
app = FastAPI(debug=True)
session = Session()
import extensions.ilastik.router
app.include_router(extensions.ilastik.router.router)
import model_server.extensions.ilastik.router
app.include_router(model_server.extensions.ilastik.router.router)
@app.on_event("startup")
def startup():
......
......@@ -14,14 +14,6 @@ from base.models import Model
logger = logging.getLogger(__name__)
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class CsvTable(object):
def __init__(self, fpath: Path):
self.path = fpath
......@@ -38,7 +30,7 @@ class CsvTable(object):
self.empty = False
return True
class Session(object, metaclass=Singleton):
class _Session(object):
"""
Singleton class for a server session that persists data between API calls
"""
......@@ -100,7 +92,7 @@ class Session(object, metaclass=Singleton):
root_path = Path(conf.defaults.root)
else:
root_path = Path(root)
sid = Session.create_session_id(root_path)
sid = _Session.create_session_id(root_path)
paths = {'root': root_path}
for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']:
pa = root_path / sid / conf.defaults.subdirectories[pk]
......@@ -193,6 +185,11 @@ class Session(object, metaclass=Singleton):
def restart(self, **kwargs):
self.__init__(**kwargs)
# create singleton instance
session = _Session()
class Error(Exception):
pass
......
from fastapi import HTTPException
from base.session import Session
session = Session()
from base.session import session
def validate_workflow_inputs(model_ids, inpaths):
for mid in model_ids:
......
from fastapi import APIRouter, HTTPException
from base.session import Session
from base.validators import validate_workflow_inputs
from model_server.base.session import session
from model_server.base.validators import validate_workflow_inputs
from extensions.ilastik import models as ilm
from extensions.ilastik.workflows import infer_px_then_ob_model
from model_server.extensions.ilastik import models as ilm
from model_server.base.models import ParameterExpectedError
from model_server.extensions.ilastik.workflows import infer_px_then_ob_model
router = APIRouter(
prefix='/ilastik',
tags=['ilastik'],
)
session = Session()
def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict:
"""
......
......@@ -3,31 +3,20 @@ import pathlib
from pydantic import BaseModel
import unittest
import base.session
from base.models import DummySemanticSegmentationModel
from base.session import Session
from base.session import session
class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None:
self.sesh = Session()
def tearDown(self) -> None:
print('Tearing down...')
Session._instances = {}
def test_session_is_singleton(self):
Session._instances = {}
self.assertEqual(len(Session._instances), 0)
s = Session()
self.assertEqual(len(Session._instances), 1)
self.assertIs(s, Session())
self.assertEqual(len(Session._instances), 1)
session.restart()
self.sesh = session
def test_session_logfile_is_valid(self):
self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
def test_changing_session_root_creates_new_directory(self):
from model_server.conf.defaults import root
from shutil import rmtree
old_paths = self.sesh.get_paths()
newroot = root / 'subdir'
......@@ -36,12 +25,6 @@ class TestGetSessionObject(unittest.TestCase):
for k in old_paths.keys():
self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
# this is necessary because logger itself is a singleton class
self.tearDown()
self.setUp()
rmtree(newroot)
self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
def test_change_session_subdirectory(self):
old_paths = self.sesh.get_paths()
print(old_paths)
......@@ -56,6 +39,23 @@ class TestGetSessionObject(unittest.TestCase):
self.assertTrue(logfile2.exists())
self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
def test_reimporting_session_uses_same_logfile(self):
session1 = self.sesh
logfile1 = session1.logfile
self.assertTrue(logfile1.exists())
from base.session import session as session2
self.assertEqual(session1, session2)
logfile2 = session2.logfile
self.assertTrue(logfile2.exists())
self.assertEqual(logfile1, logfile2, 'Reimporting session incorrectly creates new logfile')
from model_server.base.session import session as session3
self.assertEqual(session1, session3)
logfile3 = session3.logfile
self.assertTrue(logfile3.exists())
self.assertEqual(logfile1, logfile3, 'Reimporting session incorrectly creates new logfile')
def test_log_warning(self):
msg = 'A test warning'
self.sesh.log_info(msg)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment