Skip to content
Snippets Groups Projects
Commit bcc66f86 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'issue_singleton_namespace' into dev_conda_build

# Conflicts:
#	model_server/base/api.py
#	model_server/extensions/ilastik/router.py
parents daa63b51 6ad29bb8
No related branches found
No related tags found
No related merge requests found
...@@ -2,16 +2,15 @@ from fastapi import FastAPI, HTTPException ...@@ -2,16 +2,15 @@ from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
from base.session import InvalidPathError, Session from base.session import session, InvalidPathError
from base.validators import validate_workflow_inputs from base.validators import validate_workflow_inputs
from base.workflows import classify_pixels from base.workflows import classify_pixels
from extensions.ilastik.workflows import infer_px_then_ob_model from extensions.ilastik.workflows import infer_px_then_ob_model
app = FastAPI(debug=True) app = FastAPI(debug=True)
session = Session()
import extensions.ilastik.router import model_server.extensions.ilastik.router
app.include_router(extensions.ilastik.router.router) app.include_router(model_server.extensions.ilastik.router.router)
@app.on_event("startup") @app.on_event("startup")
def startup(): def startup():
......
...@@ -14,14 +14,6 @@ from base.models import Model ...@@ -14,14 +14,6 @@ from base.models import Model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class CsvTable(object): class CsvTable(object):
def __init__(self, fpath: Path): def __init__(self, fpath: Path):
self.path = fpath self.path = fpath
...@@ -38,7 +30,7 @@ class CsvTable(object): ...@@ -38,7 +30,7 @@ class CsvTable(object):
self.empty = False self.empty = False
return True return True
class Session(object, metaclass=Singleton): class _Session(object):
""" """
Singleton class for a server session that persists data between API calls Singleton class for a server session that persists data between API calls
""" """
...@@ -100,7 +92,7 @@ class Session(object, metaclass=Singleton): ...@@ -100,7 +92,7 @@ class Session(object, metaclass=Singleton):
root_path = Path(conf.defaults.root) root_path = Path(conf.defaults.root)
else: else:
root_path = Path(root) root_path = Path(root)
sid = Session.create_session_id(root_path) sid = _Session.create_session_id(root_path)
paths = {'root': root_path} paths = {'root': root_path}
for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']: for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']:
pa = root_path / sid / conf.defaults.subdirectories[pk] pa = root_path / sid / conf.defaults.subdirectories[pk]
...@@ -193,6 +185,11 @@ class Session(object, metaclass=Singleton): ...@@ -193,6 +185,11 @@ class Session(object, metaclass=Singleton):
def restart(self, **kwargs): def restart(self, **kwargs):
self.__init__(**kwargs) self.__init__(**kwargs)
# create singleton instance
session = _Session()
class Error(Exception): class Error(Exception):
pass pass
......
from fastapi import HTTPException from fastapi import HTTPException
from base.session import Session from base.session import session
session = Session()
def validate_workflow_inputs(model_ids, inpaths): def validate_workflow_inputs(model_ids, inpaths):
for mid in model_ids: for mid in model_ids:
......
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from base.session import Session from model_server.base.session import session
from base.validators import validate_workflow_inputs from model_server.base.validators import validate_workflow_inputs
from extensions.ilastik import models as ilm from model_server.extensions.ilastik import models as ilm
from extensions.ilastik.workflows import infer_px_then_ob_model from model_server.base.models import ParameterExpectedError
from model_server.extensions.ilastik.workflows import infer_px_then_ob_model
router = APIRouter( router = APIRouter(
prefix='/ilastik', prefix='/ilastik',
tags=['ilastik'], tags=['ilastik'],
) )
session = Session()
def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict: def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict:
""" """
......
...@@ -3,31 +3,20 @@ import pathlib ...@@ -3,31 +3,20 @@ import pathlib
from pydantic import BaseModel from pydantic import BaseModel
import unittest import unittest
import base.session
from base.models import DummySemanticSegmentationModel from base.models import DummySemanticSegmentationModel
from base.session import Session from base.session import session
class TestGetSessionObject(unittest.TestCase): class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.sesh = Session() session.restart()
self.sesh = session
def tearDown(self) -> None:
print('Tearing down...')
Session._instances = {}
def test_session_is_singleton(self):
Session._instances = {}
self.assertEqual(len(Session._instances), 0)
s = Session()
self.assertEqual(len(Session._instances), 1)
self.assertIs(s, Session())
self.assertEqual(len(Session._instances), 1)
def test_session_logfile_is_valid(self): def test_session_logfile_is_valid(self):
self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place') self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
def test_changing_session_root_creates_new_directory(self): def test_changing_session_root_creates_new_directory(self):
from model_server.conf.defaults import root from model_server.conf.defaults import root
from shutil import rmtree
old_paths = self.sesh.get_paths() old_paths = self.sesh.get_paths()
newroot = root / 'subdir' newroot = root / 'subdir'
...@@ -36,12 +25,6 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -36,12 +25,6 @@ class TestGetSessionObject(unittest.TestCase):
for k in old_paths.keys(): for k in old_paths.keys():
self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__())) self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
# this is necessary because logger itself is a singleton class
self.tearDown()
self.setUp()
rmtree(newroot)
self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
def test_change_session_subdirectory(self): def test_change_session_subdirectory(self):
old_paths = self.sesh.get_paths() old_paths = self.sesh.get_paths()
print(old_paths) print(old_paths)
...@@ -56,6 +39,23 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -56,6 +39,23 @@ class TestGetSessionObject(unittest.TestCase):
self.assertTrue(logfile2.exists()) self.assertTrue(logfile2.exists())
self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile') self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
def test_reimporting_session_uses_same_logfile(self):
session1 = self.sesh
logfile1 = session1.logfile
self.assertTrue(logfile1.exists())
from base.session import session as session2
self.assertEqual(session1, session2)
logfile2 = session2.logfile
self.assertTrue(logfile2.exists())
self.assertEqual(logfile1, logfile2, 'Reimporting session incorrectly creates new logfile')
from model_server.base.session import session as session3
self.assertEqual(session1, session3)
logfile3 = session3.logfile
self.assertTrue(logfile3.exists())
self.assertEqual(logfile1, logfile3, 'Reimporting session incorrectly creates new logfile')
def test_log_warning(self): def test_log_warning(self):
msg = 'A test warning' msg = 'A test warning'
self.sesh.log_info(msg) self.sesh.log_info(msg)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment