diff --git a/api.py b/api.py index 03861c2dc48d7f072dd059138b00e2c89ad7094f..317c3d2181cbed89faa383748f1bb6baae464283 100644 --- a/api.py +++ b/api.py @@ -1,9 +1,10 @@ from fastapi import FastAPI, HTTPException from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel -from model_server.model import DummyImageToImageModel +from model_server.model import DummyImageToImageModel, ParameterExpectedError from model_server.session import Session from model_server.workflow import infer_image_to_image +from model_server.workflow_ilastik import infer_px_then_ob_model app = FastAPI(debug=True) session = Session() @@ -17,12 +18,16 @@ def read_root(): return {'success': True} @app.put('/bounce_back') -def read_root(par1=None, par2=None): +def list_bounce_back(par1=None, par2=None): return {'success': True, 'params': {'par1': par1, 'par2': par2}} +@app.get('/paths') +def list_session_paths(): + return session.get_paths() + @app.get('/restart') -def restart_session() -> dict: - session.restart() +def restart_session(root: str = None) -> dict: + session.restart(root=root) return session.describe_loaded_models() @app.get('/models') @@ -33,42 +38,68 @@ def list_active_models(): def load_dummy_model() -> dict: return {'model_id': session.load_model(DummyImageToImageModel)} +def load_ilastik_model(model_class, project_file): + try: + result = { + 'model_id': session.load_model( + model_class, + {'project_file': project_file} + ) + } + except (FileNotFoundError, ParameterExpectedError): + raise HTTPException( + status_code=404, + detail=f'Could not load project file {project_file}', + ) + return result + @app.put('/models/ilastik/pixel_classification/load/') def load_ilastik_pixel_classification_model(project_file: str) -> dict: - return { - 'model_id': session.load_model( - IlastikPixelClassifierModel, - {'project_file': project_file} - ) - } + return load_ilastik_model(IlastikPixelClassifierModel, project_file) @app.put('/models/ilastik/object_classification/load/') def load_ilastik_object_classification_model(project_file: str) -> dict: - return { - 'model_id': session.load_model( - IlastikObjectClassifierModel, - {'project_file': project_file} - ) - } + return load_ilastik_model(IlastikObjectClassifierModel, project_file) + +def validate_workflow_inputs(model_ids, inpaths): + for mid in model_ids: + if mid not in session.describe_loaded_models().keys(): + raise HTTPException( + status_code=409, + detail=f'Model {mid} has not been loaded' + ) + for inpa in inpaths: + if not inpa.exists(): + raise HTTPException( + status_code=404, + detail=f'Could not find file:\n{inpa}' + ) @app.put('/infer/from_image_file') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: - if model_id not in session.describe_loaded_models().keys(): - raise HTTPException( - status_code=409, - detail=f'Model {model_id} has not been loaded' - ) - inpath = session.inbound.path / input_filename - if not inpath.exists(): - raise HTTPException( - status_code=404, - detail=f'Could not find file:\n{inpath}' - ) + inpath = session.paths['inbound_images'] / input_filename + validate_workflow_inputs([model_id], [inpath]) record = infer_image_to_image( inpath, session.models[model_id]['object'], - session.outbound.path, + session.paths['outbound_images'], channel=channel, ) session.record_workflow_run(record) + return record + +@app.put('/models/ilastik/pixel_then_object_classification/infer') +def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict: + inpath = session.paths['inbound_images'] / input_filename + validate_workflow_inputs([px_model_id, ob_model_id], [inpath]) + try: + record = infer_px_then_ob_model( + inpath, + session.models[px_model_id]['object'], + session.models[ob_model_id]['object'], + session.paths['outbound_images'], + channel=channel + ) + except AssertionError: + raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') return record \ No newline at end of file diff --git a/conf/defaults.py b/conf/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..9323b984322cbd82538408bd01cde34e606bfe7a --- /dev/null +++ b/conf/defaults.py @@ -0,0 +1,9 @@ +from pathlib import Path + +root = Path.home() / 'model-server' / 'sessions' + +subdirectories = { + 'logs': 'logs', + 'inbound_images': 'images/inbound', + 'outbound_images': 'images/outbound', +} \ No newline at end of file diff --git a/conf/ilastik.py b/conf/ilastik.py new file mode 100644 index 0000000000000000000000000000000000000000..de14a50c4d55ceebafd3078c457276750c0ffcf9 --- /dev/null +++ b/conf/ilastik.py @@ -0,0 +1,5 @@ +from pathlib import Path + +paths = { + 'project_files': Path.home() / 'model-server' / 'ilastik' +} \ No newline at end of file diff --git a/conf/server.py b/conf/server.py deleted file mode 100644 index 707b60ac28f42e4a0e02a22bb93802ec5f4b7b23..0000000000000000000000000000000000000000 --- a/conf/server.py +++ /dev/null @@ -1,21 +0,0 @@ -from pathlib import Path - -root = Path('c:/Users/rhodes/projects/proj0015-model-server/resources') -paths = { - 'logs': { - 'session': root / 'logs' / 'session', - }, - 'images': { - 'inbound': root / 'images' / 'inbound', - 'outbound': root / 'images' / 'outbound', - }, - 'ilastik': { - 'projects': root / 'ilastik' - } -} - -for gk in paths.keys(): - for pk in paths[gk].keys(): - paths[gk][pk].mkdir(parents=True, exist_ok=True) - -# TODO: consider configuring paths via API, and this just becomes a HTTP client script \ No newline at end of file diff --git a/conf/testing.py b/conf/testing.py index cd6088e53f4365ec05ec445a933fee9b200a2bab..5540838150cac9bd9bcdb79434d8f5e4a3714b54 100644 --- a/conf/testing.py +++ b/conf/testing.py @@ -1,21 +1,17 @@ from pathlib import Path -root = Path('c:/Users/rhodes/projects/proj0015-model-server/resources') +root = Path.home() / 'model-server' / 'testing' + filename = 'D3-selection-01.czi' czifile = { 'filename': filename, - 'path': root / 'testdata' / filename, + 'path': root / filename, 'w': 1274, 'h': 1274, 'c': 5, 'z': 1, } -# ilastik = { -# 'pixel_classifier': root / 'testdata' / 'ilastik' / 'demo_px.ilp', -# 'object_classifier': root / 'testdata' / 'ilastik' / 'demo_obj.ilp', -# } - ilastik = { 'pixel_classifier': 'demo_px.ilp', 'object_classifier': 'demo_obj.ilp', diff --git a/imagej/infer_ilastik_by_api.py b/imagej/infer_ilastik_by_api.py deleted file mode 100644 index 4e0d58577328a8fb0bcb84f2e9f9d12c3da33e9c..0000000000000000000000000000000000000000 --- a/imagej/infer_ilastik_by_api.py +++ /dev/null @@ -1,56 +0,0 @@ -import httplib -import json -import urllib - -import os -import sys -print(sys.version) - -from ij import IJ - -host = '127.0.0.1' -port = 8001 -uri = 'http://{}:{}/'.format(host, port) - -abspath = IJ.getImage().getProp('Location') -input_filename = os.path.split(abspath)[-1] - -outpath = 'C:\\Users\\rhodes\\projects\\proj0015-model-server\\resources\\testdata' - -def hit_endpoint(method, endpoint, params=None, verbose=True): - connection = httplib.HTTPConnection(host, port) - if not method in ['GET', 'PUT']: - raise Exception('Can only handle GET and PUT requests') - if params: - url = endpoint + '?' + urllib.urlencode(params) - else: - url = endpoint - connection.request(method, url) - resp = connection.getresponse() - resp_str = resp.read() - if verbose: - print(method + ' ' + url + ', status ' + str(resp.status) + ':\n' + resp_str) - return json.loads(resp_str) - -hit_endpoint('GET', '/restart') -resp = hit_endpoint('PUT', '/models/ilastik/pixel_classification/load/', {'project_file': 'demo_px.ilp'}) -pxmid = resp['model_id'] -resp = hit_endpoint('GET', '/models', verbose=True) - -infer_params = { - 'model_id': pxmid, - 'input_filename': input_filename, - 'channel': 0 - } - -hit_endpoint('PUT', '/infer/from_image_file', infer_params) - -import time - -dt_arr = [] -for i in range(0, 10): - t0 = time() - time.sleep(0.1) - dt_arr.append(time() - t0) -print(dt_arr) -print('mean [s]': sum(dt_arr)/len(dt_arr)) diff --git a/model_server/ilastik.py b/model_server/ilastik.py index 54419f78195024c4b4d31e6229f77dec5118ac6a..d2490ba9007ad7e6e8390fb31187a9bac2fe2d63 100644 --- a/model_server/ilastik.py +++ b/model_server/ilastik.py @@ -1,10 +1,10 @@ import os -import pathlib +from pathlib import Path import numpy as np import vigra -import conf.server +import conf.ilastik from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor from model_server.model import ImageToImageModel, ParameterExpectedError @@ -12,10 +12,12 @@ from model_server.model import ImageToImageModel, ParameterExpectedError class IlastikImageToImageModel(ImageToImageModel): def __init__(self, params, autoload=True): - self.project_file = str(params['project_file']) - self.project_file_abspath = pathlib.Path( - conf.server.paths['ilastik']['projects'] / self.project_file, - ) + self.project_file = Path(params['project_file']) + params['project_file'] = self.project_file.__str__() + pap = conf.ilastik.paths['project_files'] / self.project_file + self.project_file_abspath = pap + if not pap.exists(): + raise FileNotFoundError(f'Project file does not exist: {pap}') if 'project_file' not in params or not self.project_file_abspath.exists(): raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file') diff --git a/model_server/image.py b/model_server/image.py index c700ea66bb00e810fb4acc5887b20a7d8e5679e2..5d16d980f26cf037bcab7e72379a722eba5ec0f0 100644 --- a/model_server/image.py +++ b/model_server/image.py @@ -24,7 +24,7 @@ class GenericImageDataAccessor(ABC): @staticmethod def conform_data(data): - if len(data.shape) > 4: + if len(data.shape) > 4 or (0 in data.shape): raise DataShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z: {data.shape}') ones = [1 for i in range(0, 4 - len(data.shape))] return data.reshape(*data.shape, *ones) @@ -33,7 +33,8 @@ class GenericImageDataAccessor(ABC): return True if self.shape_dict['Z'] > 1 else False def get_one_channel_data (self, channel: int): - return InMemoryDataAccessor(self.data[:, :, int(channel), :]) + c = int(channel) + return InMemoryDataAccessor(self.data[:, :, c:(c+1), :]) @property def data(self): diff --git a/model_server/session.py b/model_server/session.py index 1a9b38b9d3cf01c6f24ec5a5af3bf9818e2f47f6..d53654f84b9e3c74f9fd621f5cab6836653bdea6 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -5,7 +5,7 @@ from pathlib import Path from time import strftime, localtime from typing import Dict -from conf.server import paths +import conf.defaults from model_server.model import Model from model_server.share import SharedImageDirectory from model_server.workflow import WorkflowRunRecord @@ -17,25 +17,52 @@ class Session(object): """ Singleton class for a server session that persists data between API calls """ - inbound = SharedImageDirectory(paths['images']['inbound']) - outbound = SharedImageDirectory(paths['images']['outbound']) - where_records = Path(paths['logs']['session']) + # inbound = SharedImageDirectory(conf.defaults.paths['inbound_images']) + # outbound = SharedImageDirectory(conf.defaults.paths['outbound_images']) + # where_records = conf.defaults.paths['logs'] def __new__(cls): if not hasattr(cls, 'instance'): cls.instance = super(Session, cls).__new__(cls) return cls.instance - def __init__(self): + def __init__(self, root: str = None): print('Initializing session') self.models = {} # model_id : model object self.manifest = [] # paths to data as well as other metadata from each inference run - self.session_id = self.create_session_id(self.where_records) - self.session_log = self.where_records / f'{self.session_id}.log' + self.paths = self.make_paths(root) + # self.session_id = self.create_session_id(self.paths['logs']) + # self.session_id = self.create_session_id(self.paths) + self.session_log = self.paths['logs'] / f'session.log' self.log_event('Initialized session') - self.manifest_json = self.where_records / f'{self.session_id}-manifest.json' + self.manifest_json = self.paths['logs'] / f'manifest.json' open(self.manifest_json, 'w').close() # instantiate empty json file + def get_paths(self): + return self.paths + + @staticmethod + def make_paths(root: str = None) -> dict: + """ + Set paths where images, logs, etc. are located in this session + :param root: absolute path to top-level directory + :return: dictionary of session paths + """ + if root is None: + root_path = Path(conf.defaults.root) + else: + root_path = Path(root) + sid = Session.create_session_id(root_path) + paths = {'root': root_path} + for pk in ['inbound_images', 'outbound_images', 'logs']: + pa = root_path / sid / conf.defaults.subdirectories[pk] + paths[pk] = pa + try: + pa.mkdir(parents=True, exist_ok=True) + except Exception: + raise CouldNotCreateDirectory(f'Could not create directory: {pa}') + return paths + @staticmethod def create_session_id(look_where: Path) -> str: """ @@ -43,7 +70,7 @@ class Session(object): """ yyyymmdd = strftime('%Y%m%d', localtime()) idx = 0 - while os.path.exists(look_where / f'{yyyymmdd}-{idx:04d}.log'): + while os.path.exists(look_where / f'{yyyymmdd}-{idx:04d}'): idx += 1 return f'{yyyymmdd}-{idx:04d}' @@ -52,8 +79,8 @@ class Session(object): Write an event string to this session's log file. """ timestamp = strftime('%m/%d/%Y, %H:%M:%S', localtime()) - with open(self.session_log, 'w+') as fh: - fh.write(f'{timestamp} -- {event}') + with open(self.session_log, 'a') as fh: + fh.write(f'{timestamp} -- {event}\n') def record_workflow_run(self, record: WorkflowRunRecord or None): """ @@ -98,8 +125,8 @@ class Session(object): for k in self.models.keys() } - def restart(self): - self.__init__() + def restart(self, **kwargs): + self.__init__(**kwargs) class Error(Exception): pass @@ -108,4 +135,7 @@ class InferenceRecordError(Error): pass class CouldNotInstantiateModelError(Error): + pass + +class CouldNotCreateDirectory(Error): pass \ No newline at end of file diff --git a/model_server/workflow.py b/model_server/workflow.py index a89f900a95536da22bb3acea88369648fbb7bec8..7c09bc46835eacb04a31871534a91786991267ec 100644 --- a/model_server/workflow.py +++ b/model_server/workflow.py @@ -1,11 +1,12 @@ """ Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session. """ - +from pathlib import Path from time import perf_counter from typing import Dict from model_server.image import generate_file_accessor, write_accessor_data_to_file +from model_server.model import Model from pydantic import BaseModel @@ -28,7 +29,15 @@ class WorkflowRunRecord(BaseModel): timer_results: Dict[str, float] -def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict: +def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) -> WorkflowRunRecord: + """ + Generic workflow where a model processes an input image into an output image + :param fpi: Path object that references input image file + :param model: model object + :param where_output: Path object that references output image directory + :param kwargs: variable-length keyword arguments + :return: record object + """ ti = Timer() ch = kwargs.get('channel') img = generate_file_accessor(fpi).get_one_channel_data(ch) diff --git a/model_server/workflow_ilastik.py b/model_server/workflow_ilastik.py new file mode 100644 index 0000000000000000000000000000000000000000..db412bb4b9b2ba1e33b0ba42d9d5108bbfcdce5d --- /dev/null +++ b/model_server/workflow_ilastik.py @@ -0,0 +1,73 @@ +""" +Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session. +""" +from pathlib import Path +from time import perf_counter +from typing import Dict + +from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel +from model_server.image import generate_file_accessor, write_accessor_data_to_file +from model_server.model import Model +from model_server.workflow import Timer + +from pydantic import BaseModel + +class WorkflowRunRecord(BaseModel): + pixel_model_id: str + object_model_id: str + input_filepath: str + pixel_map_filepath: str + object_map_filepath: str + success: bool + timer_results: Dict[str, float] + + +def infer_px_then_ob_model( + fpi: Path, + px_model: IlastikPixelClassifierModel, + ob_model: IlastikObjectClassifierModel, + where_output: Path, + **kwargs +) -> WorkflowRunRecord: + """ + Workflow that specifically runs an ilastik pixel classifier, then passes results to an object classifier, + saving intermediate images + :param fpi: Path object that references input image file + :param px_model: model instance for pixel classification + :param ob_model: model instance for object classification + :param where_output: Path object that references output image directory + :param kwargs: variable-length keyword arguments + :return: + """ + assert isinstance(px_model, IlastikPixelClassifierModel) + assert isinstance(ob_model, IlastikObjectClassifierModel) + + ti = Timer() + ch = kwargs.get('channel') + img = generate_file_accessor(fpi).get_one_channel_data(ch) + ti.click('file_input') + + px_map, _ = px_model.infer(img) + ti.click('pixel_probability_inference') + + px_map_path = where_output / (px_model.model_id + '_pxmap_' + fpi.stem + '.tif') + write_accessor_data_to_file(px_map_path, px_map) + ti.click('pixel_map_output') + + ob_map, _ = ob_model.infer(img, px_map) + ti.click('object_classification') + + ob_map_path = where_output / (ob_model.model_id + '_obmap_' + fpi.stem + '.tif') + write_accessor_data_to_file(ob_map_path, ob_map) + ti.click('object_map_output') + + return WorkflowRunRecord( + pixel_model_id=px_model.model_id, + object_model_id=ob_model.model_id, + input_filepath=str(fpi), + pixel_map_filepath=str(px_map_path), + object_map_filepath=str(ob_map_path), + success=True, + timer_results=ti.events, + ) + diff --git a/tests/test_api.py b/tests/test_api.py index b7835ffd2e5b7f4e24834d690c6223d9f21b1791..40ecfd2985e6a9450fe1bf273eddd18a8890fb55 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,4 +1,5 @@ from multiprocessing import Process +from pathlib import Path import requests import unittest @@ -20,14 +21,12 @@ class TestServerBaseClass(unittest.TestCase): self.uri = f'http://{host}:{port}/' self.server_process.start() - @staticmethod - def copy_input_file_to_server(): - import pathlib + def copy_input_file_to_server(self): from shutil import copyfile - from conf.server import paths - - outpath = pathlib.Path(paths['images']['inbound'] / czifile['filename']) + resp = requests.get(self.uri + 'paths') + pa = resp.json()['inbound_images'] + outpath = Path(pa) / czifile['filename'] copyfile( czifile['path'], outpath @@ -47,6 +46,15 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json()) self.assertEqual(resp.json()['params']['par2'], None, resp.json()) + def test_default_session_paths(self): + import conf.defaults + resp = requests.get(self.uri + 'paths') + conf_root = conf.defaults.root + for p in ['inbound_images', 'outbound_images', 'logs']: + self.assertTrue(resp.json()[p].startswith(conf_root.__str__())) + suffix = Path(conf.defaults.subdirectories[p]).__str__() + self.assertTrue(resp.json()[p].endswith(suffix)) + def test_list_empty_loaded_models(self): resp = requests.get(self.uri + 'models') self.assertEqual(resp.status_code, 200) diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 2cb9b5cdecdc08659e7510a107be32b118d5e3e6..25c876b945fdac880f78af76e7e96edb0e4e1cf3 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -6,7 +6,6 @@ import numpy as np import conf.testing from model_server.image import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.ilastik import IlastikObjectClassifierModel, IlastikPixelClassifierModel -from model_server.model import Model from model_server.workflow import infer_image_to_image from tests.test_api import TestServerBaseClass @@ -110,6 +109,15 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertGreater(result.timer_results['inference'], 1.0) class TestIlastikOverApi(TestServerBaseClass): + + def test_httpexception_if_incorrect_project_file_loaded(self): + resp_load = requests.put( + self.uri + 'models/ilastik/pixel_classification/load/', + params={'project_file': 'improper.ilp'}, + ) + self.assertEqual(resp_load.status_code, 404) + + def test_load_ilastik_pixel_model(self): resp_load = requests.put( self.uri + 'models/ilastik/pixel_classification/load/', @@ -122,7 +130,6 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel') - return model_id @@ -138,9 +145,10 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() self.assertEqual(rj[model_id]['class'], 'IlastikObjectClassifierModel') + return model_id def test_ilastik_infer_pixel_probability(self): - TestServerBaseClass.copy_input_file_to_server() + self.copy_input_file_to_server() model_id = self.test_load_ilastik_pixel_model() resp_infer = requests.put( @@ -153,20 +161,18 @@ class TestIlastikOverApi(TestServerBaseClass): ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) - def test_load_ilastik_pixel_model_in_subdirectory(self): - px_ilp = 'proj0011-exp0004/px3d.ilp' - pf = px_ilp + def test_ilastik_infer_px_then_ob(self): + self.copy_input_file_to_server() + px_model_id = self.test_load_ilastik_pixel_model() + ob_model_id = self.test_load_ilastik_object_model() - resp_load = requests.put( - self.uri + 'models/ilastik/pixel_classification/load/', - params={'project_file': pf}, + resp_infer = requests.put( + self.uri + f'models/ilastik/pixel_then_object_classification/infer/', + params={ + 'px_model_id': px_model_id, + 'ob_model_id': ob_model_id, + 'input_filename': conf.testing.czifile['filename'], + 'channel': 0, + } ) - model_id = resp_load.json()['model_id'] - - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - resp_list = requests.get(self.uri + 'models') - self.assertEqual(resp_list.status_code, 200) - rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel') - - return model_id + self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) diff --git a/tests/test_image.py b/tests/test_image.py index 367041e60046bc91b19dcf4ddd33a5c871f744a2..83d25ba1488080aa657c7bfe212c84a87387a1f0 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -6,6 +6,7 @@ from conf.testing import czifile, output_path from model_server.image import CziImageFileAccessor, DataShapeError, InMemoryDataAccessor, write_accessor_data_to_file class TestCziImageFileAccess(unittest.TestCase): + def setUp(self) -> None: pass @@ -19,6 +20,16 @@ class TestCziImageFileAccess(unittest.TestCase): self.assertEqual(cf.shape[0], czifile['h']) self.assertEqual(cf.shape[1], czifile['w']) + def test_get_single_channel_from_zstack(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + c = 3 + cf = InMemoryDataAccessor(np.random.rand(w, h, nc, nz)) + sc = cf.get_one_channel_data(c) + self.assertEqual(sc.shape, (w, h, 1, nz)) + def test_write_single_channel_tif(self): ch = 4 cf = CziImageFileAccessor(czifile['path']) diff --git a/tests/test_session.py b/tests/test_session.py index 6a15c68be361d5e9eab2a668c5b186d359dc1047..614ad5d107635ed19a6c651493389789d7eb2939 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -14,12 +14,27 @@ class TestGetSessionObject(unittest.TestCase): self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place') self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place') + def test_changing_session_root_creates_new_directory(self): + from conf.defaults import root + from shutil import rmtree + + sesh = Session() + old_paths = sesh.get_paths() + newroot = root / 'subdir' + sesh.restart(root=newroot) + new_paths = sesh.get_paths() + for k in old_paths.keys(): + self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__())) + rmtree(newroot) + self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory') + + def test_restart_session(self): sesh = Session() logfile1 = sesh.session_log sesh.restart() logfile2 = sesh.session_log - self.assertIsNot(logfile1, logfile2, 'Restarting session does not generate new logfile') + self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile') def test_session_records_workflow(self): import json @@ -44,7 +59,6 @@ class TestGetSessionObject(unittest.TestCase): success = sesh.load_model(MC) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() - print(loaded_models) self.assertTrue( (MC.__name__ + '_00') in loaded_models.keys() ) @@ -58,7 +72,6 @@ class TestGetSessionObject(unittest.TestCase): MC = DummyImageToImageModel sesh.load_model(MC) sesh.load_model(MC) - print(sesh.models.keys()) self.assertIn(MC.__name__ + '_00', sesh.models.keys()) self.assertIn(MC.__name__ + '_01', sesh.models.keys()) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 9bca0b4af3d440206041ddb16c6c56891a18ed56..6bc8d56d88f960d63cdada4efe68e65c6ecd572f 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -34,6 +34,4 @@ class TestGetSessionObject(unittest.TestCase): img[0, 0], 0, 'First pixel is not black as expected' - ) - - print(result.timer_results) \ No newline at end of file + ) \ No newline at end of file