Commit 1851ba66 authored by Christopher Randolph Rhodes

Merge branch 'issue0027-serialize-roiset' into 'master'

Completed (de)serialization of RoiSet

See merge request !16
parents 88f0bff9 c0b4e877
......@@ -162,6 +162,14 @@ class CziImageFileAccessor(GenericImageFileAccessor):
except Exception:
raise FileAccessorError(f'Unable to access CZI data in {fpath}')
try:
md = cf.metadata(raw=False)
compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod']
except KeyError:
raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod')
if compmet.upper() != 'UNCOMPRESSED':
raise InvalidCziCompression(f'Unsupported compression method {compmet}')
sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes}
if (sd.get('S') and (sd['S'] > 1)) or (sd.get('T') and (sd['T'] > 1)):
raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}')
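A minimal sketch of how the new compression guard surfaces to callers; the file path here is hypothetical:
import model_server.base.accessors as accessors
try:
    acc = accessors.CziImageFileAccessor('images/compressed_stack.czi')  # hypothetical path
except accessors.InvalidCziCompression as e:
    # raised when the OriginalCompressionMethod key is missing or not 'UNCOMPRESSED'
    print(f'Cannot read CZI: {e}')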
......@@ -284,6 +292,29 @@ class PatchStack(InMemoryDataAccessor):
def count(self):
return self.shape_dict['P']
def export_pyxcz(self, fpath: Path):
tzcyx = np.moveaxis(
self.pyxcz, # pyxcz -> tzcyx
[0, 4, 3, 1, 2],
[0, 1, 2, 3, 4]
)
if self.is_mask():
if self.dtype == 'bool':
data = (tzcyx * 255).astype('uint8')
else:
data = tzcyx.astype('uint8')
tifffile.imwrite(fpath, data, imagej=True)
else:
tifffile.imwrite(fpath, tzcyx, imagej=True)
def get_one_channel_data(self, channel: int, mip: bool = False):
c = int(channel)
if mip:
return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :].max(axis=-1, keepdims=True))
else:
return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :])
@property
def shape_dict(self):
return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
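A hedged usage sketch of the two new PatchStack methods, consistent with the accessor tests further down; the array values and output filename are illustrative:
from pathlib import Path
import numpy as np
from model_server.base.accessors import PatchStack
acc = PatchStack(np.random.randint(0, 255, (4, 512, 256, 2, 15), dtype='uint8'))  # PYXCZ
mono = acc.get_one_channel_data(channel=1)           # shape (4, 512, 256, 1, 15)
mip = acc.get_one_channel_data(channel=1, mip=True)  # shape (4, 512, 256, 1, 1), max over Z
acc.export_pyxcz(Path('patch_hyperstack.tif'))       # ImageJ-style TZCYX hyperstack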
......@@ -313,16 +344,32 @@ class PatchStack(InMemoryDataAccessor):
return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position
def make_patch_stack_from_file(fpath): # interpret t-dimension as patch position
if not Path(fpath).exists():
raise FileNotFoundError(f'Could not find {fpath}')
pyxc = np.moveaxis(
generate_file_accessor(fpath).data, # yxcz
[0, 1, 2, 3],
[1, 2, 3, 0]
try:
tf = tifffile.TiffFile(fpath)
except Exception:
raise FileAccessorError(f'Unable to access data in {fpath}')
if len(tf.series) != 1:
raise DataShapeError(f'Expect only one series in {fpath}')
se = tf.series[0]
axs = [a for a in se.axes if a in [*'TZCYX']]
sd = dict(zip(axs, se.shape))
for a in [*'TZC']:
if a not in axs:
sd[a] = 1
tzcyx = se.asarray().reshape([sd[k] for k in [*'TZCYX']])
pyxcz = np.moveaxis(
tzcyx,
[0, 3, 4, 2, 1],
[0, 1, 2, 3, 4],
)
pyxcz = np.expand_dims(pyxc, axis=3)
return PatchStack(pyxcz)
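export_pyxcz and this loader are meant to round-trip; a sketch assuming the hyperstack written in the previous example exists:
from pathlib import Path
from model_server.base.accessors import make_patch_stack_from_file
acc2 = make_patch_stack_from_file(Path('patch_hyperstack.tif'))
# the TIFF's T axis is reinterpreted as patch position P
print(acc2.shape_dict)  # e.g. {'P': 4, 'Y': 512, 'X': 256, 'C': 2, 'Z': 15}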
......@@ -345,6 +392,9 @@ class FileWriteError(Error):
class InvalidAxisKey(Error):
pass
class InvalidCziCompression(Error):
pass
class InvalidDataShape(Error):
pass
......
......@@ -46,6 +46,7 @@ def change_path(key, path):
status_code=404,
detail=e.__str__(),
)
session.log_info(f'Changed {key} path to {path}')
return session.get_paths()
@app.put('/paths/watch_input')
......@@ -56,22 +57,30 @@ def watch_input_path(path: str):
def watch_input_path(path: str):
return change_path('outbound_images', path)
@app.get('/restart')
@app.get('/session/restart')
def restart_session(root: str = None) -> dict:
session.restart(root=root)
return session.describe_loaded_models()
@app.get('/session/logs')
def list_session_log() -> list:
return session.get_log_data()
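A sketch of exercising the renamed restart route and the new log route from a client; the base URL is hypothetical:
import requests
base = 'http://localhost:8000'  # hypothetical server address
requests.get(f'{base}/session/restart')
logs = requests.get(f'{base}/session/logs').json()
print(logs[0])  # newest entry first, e.g. {'datetime': ..., 'level': 'INFO', 'message': 'Initialized session'}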
@app.get('/models')
def list_active_models():
return session.describe_loaded_models()
@app.put('/models/dummy_semantic/load/')
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummySemanticSegmentationModel)}
mid = session.load_model(DummySemanticSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
@app.put('/models/dummy_instance/load/')
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyInstanceSegmentationModel)}
mid = session.load_model(DummyInstanceSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
@app.put('/workflows/segment')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......@@ -83,5 +92,5 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
session.paths['outbound_images'],
channel=channel,
)
session.record_workflow_run(record)
session.log_info(f'Completed segmentation of {input_filename}')
return record
\ No newline at end of file
......@@ -3,7 +3,7 @@ from math import floor
import numpy as np
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack
class Model(ABC):
......@@ -88,6 +88,16 @@ class InstanceSegmentationModel(ImageToImageModel):
if not img.shape == mask.shape:
raise InvalidInputImageError('Expect input image and mask to be the same shape')
def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):
"""
Iterate over a patch stack, calling inference on each patch
:return: PatchStack of the same shape as the input
"""
res_data = np.zeros(img.shape, dtype='uint16')
for i in range(0, img.count): # interpret as PYXCZ
res_data[i, :, :, :, :] = self.label_instance_class(img.iat(i), mask.iat(i), **kwargs).data
return PatchStack(res_data)
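A hedged sketch of driving the new hook, assuming model is any concrete InstanceSegmentationModel and the two stacks come from an RoiSet as in the tests below:
# raw_patches and patch_masks are PatchStack objects of identical PYXCZ shape
labeled = model.label_patch_stack(raw_patches, patch_masks)
assert labeled.count == raw_patches.count  # one labeled patch per input patch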
class DummySemanticSegmentationModel(SemanticSegmentationModel):
......
This diff is collapsed.
import json
import logging
import os
from pathlib import Path
from time import strftime, localtime
from typing import Dict
import pandas as pd
import model_server.conf.defaults
from model_server.base.models import Model
from model_server.base.workflows import WorkflowRunRecord
def create_manifest_json():
pass
logger = logging.getLogger(__name__)
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
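The metaclass returns the cached instance on every later call; a minimal illustration with a hypothetical class:
class Config(metaclass=Singleton):  # hypothetical example class
    pass
assert Config() is Config()  # both calls yield the same cached object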
class CsvTable(object):
def __init__(self, fpath: Path):
self.path = fpath
self.empty = True
class Session(object):
def append(self, coords: dict, data: pd.DataFrame) -> bool:
assert isinstance(data, pd.DataFrame)
for c in reversed(coords.keys()):
data.insert(0, c, coords[c])
if self.empty:
data.to_csv(self.path, index=False, mode='w', header=True)
else:
data.to_csv(self.path, index=False, mode='a', header=False)
self.empty = False
return True
class Session(object, metaclass=Singleton):
"""
Singleton class for a server session that persists data between API calls
"""
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(Session, cls).__new__(cls)
return cls.instance
log_format = '%(asctime)s - %(levelname)s - %(message)s'
def __init__(self, root: str = None):
print('Initializing session')
self.models = {} # model_id : model object
self.manifest = [] # paths to data as well as other metadata from each inference run
self.paths = self.make_paths(root)
self.session_log = self.paths['logs'] / f'session.log'
self.log_event('Initialized session')
self.manifest_json = self.paths['logs'] / f'manifest.json'
open(self.manifest_json, 'w').close() # instantiate empty json file
self.logfile = self.paths['logs'] / f'session.log'
logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format)
self.log_info('Initialized session')
self.tables = {}
def write_to_table(self, name: str, coords: dict, data: pd.DataFrame):
"""
Write data to a named data table, initializing if it does not yet exist.
:param name: name of the table to persist through session
:param coords: dictionary of coordinates to associate with all rows in this method call
:param data: DataFrame containing data
:return: True if successful
"""
try:
if name in self.tables.keys():
table = self.tables.get(name)
else:
table = CsvTable(self.paths['tables'] / (name + '.csv'))
self.tables[name] = table
except Exception:
raise CouldNotCreateTable(f'Unable to create table named {name}')
try:
table.append(coords, data)
return True
except Exception:
raise CouldNotAppendToTable(f'Unable to append data to table named {name}')
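A usage sketch of the new table API, consistent with test_make_table at the bottom of this diff; the table name and coordinates are illustrative:
import pandas as pd
from model_server.base.session import Session
session = Session()
df = pd.DataFrame({'area': [120, 340]})
session.write_to_table('roi_stats', {'well': 'A1'}, df)  # creates <session root>/tables/roi_stats.csv
session.write_to_table('roi_stats', {'well': 'A2'}, df)  # appends rows; header is written only once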
def get_paths(self):
return self.paths
......@@ -55,7 +101,7 @@ class Session(object):
root_path = Path(root)
sid = Session.create_session_id(root_path)
paths = {'root': root_path}
for pk in ['inbound_images', 'outbound_images', 'logs']:
for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']:
pa = root_path / sid / model_server.conf.defaults.subdirectories[pk]
paths[pk] = pa
try:
......@@ -75,22 +121,23 @@ class Session(object):
idx += 1
return f'{yyyymmdd}-{idx:04d}'
def log_event(self, event: str):
"""
Write an event string to this session's log file.
"""
timestamp = strftime('%m/%d/%Y, %H:%M:%S', localtime())
with open(self.session_log, 'a') as fh:
fh.write(f'{timestamp} -- {event}\n')
def get_log_data(self) -> list:
log = []
with open(self.logfile, 'r') as fh:
for line in fh:
k = ['datetime', 'level', 'message']
v = line.strip().split(' - ')[0:3]
log.insert(0, dict(zip(k, v)))
return log
def record_workflow_run(self, record: WorkflowRunRecord or None):
"""
Append a JSON describing inference data to this session's manifest
"""
self.log_event(f'Ran model {record.model_id} on {record.input_filepath} to infer {record.output_filepath}')
with open(self.manifest_json, 'w+') as fh:
json.dump(record.dict(), fh)
def log_info(self, msg):
logger.info(msg)
def log_warning(self, msg):
logger.warning(msg)
def log_error(self, msg):
logger.error(msg)
def load_model(self, ModelClass: Model, params: Dict[str, str] = None) -> dict:
"""
......@@ -114,7 +161,7 @@ class Session(object):
'object': mi,
'params': params
}
self.log_event(f'Loaded model {key}')
self.log_info(f'Loaded model {key}')
return key
def describe_loaded_models(self) -> dict:
......@@ -157,5 +204,11 @@ class CouldNotInstantiateModelError(Error):
class CouldNotCreateDirectory(Error):
pass
class CouldNotCreateTable(Error):
pass
class CouldNotAppendToTable(Error):
pass
class InvalidPathError(Error):
pass
\ No newline at end of file
......@@ -6,6 +6,7 @@ subdirectories = {
'logs': 'logs',
'inbound_images': 'images/inbound',
'outbound_images': 'images/outbound',
'tables': 'tables',
}
server_conf = {
......
......@@ -67,6 +67,7 @@ roiset_test_data = {
'c': 5,
'z': 7,
'mask_path': root / 'zmask-test-stack-mask.tif',
'mask_path_3d': root / 'zmask-test-stack-mask-3d.tif',
},
'pipeline_params': {
'segmentation_channel': 0,
......
import json
import os
from pathlib import Path
......@@ -12,8 +13,17 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat
class IlastikModel(Model):
def __init__(self, params, autoload=True):
def __init__(self, params, autoload=True, enforce_embedded=True):
"""
Base class for models that run via the ilastik shell API
:param params:
project_file: path to ilastik project file
:param autoload: automatically load model into memory if true
:param enforce_embedded:
raise an error if any input data are not embedded in the project file, i.e. they reside on the filesystem
"""
self.project_file = Path(params['project_file'])
self.enforce_embedded = enforce_embedded
params['project_file'] = self.project_file.__str__()
if self.project_file.is_absolute():
pap = self.project_file
......@@ -42,6 +52,15 @@ class IlastikModel(Model):
args.project = self.project_file_abspath.__str__()
shell = app.main(args, init_logging=False)
# validate that inputs are embedded in the project file
input_groups = shell.projectManager.currentProjectFile['Input Data']['infos']
lanes = input_groups.keys()
for ll in lanes:
input_types = input_groups[ll]
for tt in input_types:
ds_loc = input_groups[ll][tt].get('location', False)
if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem':
raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem')
if not isinstance(shell.workflow, self.get_workflow()):
raise ParameterExpectedError(
f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
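Since .ilp project files are HDF5, the layout this check walks can be inspected directly; a sketch using h5py with a hypothetical project file:
import h5py
with h5py.File('classifier.ilp', 'r') as pf:  # hypothetical project file
    infos = pf['Input Data']['infos']
    for lane in infos:
        for role in infos[lane]:
            loc = infos[lane][role].get('location')
            if loc is not None and loc[()] == b'FileSystem':
                print(f'{lane}/{role} is referenced on the filesystem, not embedded')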
......@@ -55,12 +74,35 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@property
def model_shape_dict(self):
raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
ax = raw_info['axistags'][()]
ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']]
shape = raw_info['shape'][()]
dd = dict(zip(ax_keys, shape))
for ci in 'TCZ':
if ci not in dd.keys():
dd[ci] = 1
return dd
@property
def model_chroma(self):
return self.model_shape_dict['C']
@property
def model_3d(self):
return self.model_shape_dict['Z'] > 1
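These properties let callers validate inputs before inference; a hedged sketch with a hypothetical project file path:
model = IlastikPixelClassifierModel({'project_file': 'classifiers/px.ilp'})  # hypothetical path
print(model.model_shape_dict)  # e.g. {'Y': 512, 'X': 256, 'T': 1, 'C': 1, 'Z': 1}
print(model.model_chroma, model.model_3d)  # e.g. 1 False for a mono 2d classifier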
@staticmethod
def get_workflow():
from ilastik.workflows import PixelClassificationWorkflow
return PixelClassificationWorkflow
def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
dsi = [
{
......@@ -87,22 +129,33 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_segmentation'
@staticmethod
def _make_8bit_mask(nda):
if nda.dtype == 'bool':
return 255 * nda.astype('uint8')
else:
return nda
@staticmethod
def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
assert segmentation_img.is_mask()
if segmentation_img.dtype == 'bool':
seg = 255 * segmentation_img.data.astype('uint8')
if isinstance(input_img, PatchStack):
assert isinstance(segmentation_img, PatchStack)
tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(
255 * segmentation_img.data.astype('uint8'),
'yxcz'
self._make_8bit_mask(segmentation_img.pczyx),
'tczyx'
)
else:
tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_seg_data = vigra.taggedView(
self._make_8bit_mask(segmentation_img.data),
'yxcz'
)
dsi = [
{
......@@ -115,12 +168,21 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
assert len(obmaps) == 1, 'ilastik generated more than one object map'
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
if isinstance(input_img, PatchStack):
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
return PatchStack(data=pyxcz), {'success': True}
else:
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
......@@ -172,48 +234,40 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
"""
if not img.shape == pxmap.shape:
raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
# TODO: check that pxmap is in-range
pxch = kwargs.get('pixel_classification_channel', 0)
pxtr = kwargs('pixel_classification_threshold', 0.5)
pxtr = kwargs.get('pixel_classification_threshold', 0.5)
mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
# super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
obmap, _ = self.infer(img, mask)
return obmap
def make_instance_segmentation_model(self, px_ch: int):
"""
Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a
second input.
:param px_ch: channel of pixel probability map to use
:return:
InstanceSegmentationModel object
"""
class _Mod(self.__class__, InstanceSegmentationModel):
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor:
if mask.dtype == 'bool':
norm_mask = 1.0 * mask.data
else:
norm_mask = mask.data / np.iinfo(mask.dtype).max
norm_mask_acc = InMemoryDataAccessor(norm_mask.astype('float32'))
return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch)
return _Mod(params={'project_file': self.project_file})
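Usage mirrors test_make_seg_obj_model_from_pxmap_obj further down; img_acc and mask_acc stand in for any compatible accessors:
pxmap_model = IlastikObjectClassifierFromPixelPredictionsModel(
    {'project_file': 'classifiers/pxmap_to_obj.ilp'}  # hypothetical path
)
seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0)
objmap = seg_model.label_instance_class(img_acc, mask_acc)  # mask is normalized to float32 internally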
class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
"""
Wrap ilastik object classification for inputs comprising single-object series of raw images and binary
segmentation masks.
"""
def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
assert segmentation_acc.is_mask()
if not input_acc.chroma == 1:
raise InvalidInputImageError('Object classifier expects only monochrome patches')
if not input_acc.nz == 1:
raise InvalidInputImageError('Object classifier expects only 2d patches')
tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
dsi = [
{
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
}
]
obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
assert len(obmaps) == 1, 'ilastik generated more than one object map'
class Error(Exception):
pass
# for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C
assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1)
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
class IlastikInputEmbedding(Error):
pass
return PatchStack(data=pyxcz), {'success': True}
\ No newline at end of file
class IlastikInputShapeError(Error):
"""Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape"""
pass
\ No newline at end of file
......@@ -25,9 +25,11 @@ def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplica
if not duplicate:
existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
if existing_model_id is not None:
session.log_info(f'An ilastik model from {project_file} already exists; did not load a duplicate')
return {'model_id': existing_model_id}
try:
result = session.load_model(model_class, {'project_file': project_file})
session.log_info(f'Loaded ilastik model {result} from {project_file}')
except (FileNotFoundError, ParameterExpectedError):
raise HTTPException(
status_code=404,
......@@ -60,6 +62,7 @@ def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: st
channel=channel,
mip=mip,
)
session.log_info(f'Completed pixel and object classification of {input_filename}')
except AssertionError:
raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}')
return record
\ No newline at end of file
......@@ -11,6 +11,9 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
from model_server.base.workflows import classify_pixels
from tests.test_api import TestServerBaseClass
def _random_int(*args):
return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
class TestIlastikPixelClassification(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path'])
......@@ -83,6 +86,40 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.mono_image = mono_image
self.mask = mask
def test_pixel_classifier_enforces_input_shape(self):
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
)
self.assertEqual(model.model_chroma, 1)
self.assertEqual(model.model_3d, False)
# correct data
self.assertIsInstance(
model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 1, 1)
)
),
InMemoryDataAccessor
)
# raise exception with input of multiple channels
with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 3, 1)
)
)
# raise exception with input of multiple z-slices
with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 1, 15)
)
)
def test_run_object_classifier_from_pixel_predictions(self):
self.test_run_pixel_classifier()
fp = czifile['path']
......@@ -97,7 +134,24 @@ class TestIlastikPixelClassification(unittest.TestCase):
objmap,
)
)
self.assertEqual(objmap.data.max(), 3)
self.assertEqual(objmap.data.max(), 2)
def test_make_seg_obj_model_from_pxmap_obj(self):
self.test_run_pixel_classifier()
fp = czifile['path']
pxmap_model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
{'project_file': ilastik_classifiers['pxmap_to_obj']}
)
seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0)
objmap = seg_model.label_instance_class(self.mono_image, self.mask)
self.assertTrue(
write_accessor_data_to_file(
output_path / f'obmap_seg_from_pxmap_{fp.stem}.tif',
objmap,
)
)
self.assertEqual(objmap.data.max(), 2)
def test_run_object_classifier_from_segmentation(self):
self.test_run_pixel_classifier()
......@@ -113,7 +167,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
objmap,
)
)
self.assertEqual(objmap.data.max(), 3)
self.assertEqual(objmap.data.max(), 2)
def test_ilastik_pixel_classification_as_workflow(self):
result = classify_pixels(
......@@ -168,7 +222,7 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
def test_no_duplicate_model_with_different_path_formats(self):
self._get('restart')
self._get('session/restart')
resp_list_1 = self._get('models').json()
self.assertEqual(len(resp_list_1), 0)
ilp = ilastik_classifiers['px']
......@@ -269,16 +323,18 @@ class TestIlastikObjectClassification(unittest.TestCase):
)
)
self.object_classifier = ilm.PatchStackObjectClassifier(
self.object_classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
params={'project_file': ilastik_classifiers['seg_to_obj']}
)
def test_classify_patches(self):
raw_patches = self.roiset.get_raw_patches()
patch_masks = self.roiset.get_patch_masks()
res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks)
raw_patches = self.roiset.get_patches_acc()
patch_masks = self.roiset.get_patch_masks_acc()
res_patches = self.object_classifier.label_instance_class(raw_patches, patch_masks)
self.assertEqual(res_patches.count, self.roiset.count)
res_patches.export_pyxcz(output_path / 'res_patches.tif')
for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch
unique = np.unique(res_patches.iat(pi).data)
self.assertEqual(len(unique), 2)
self.assertEqual(unique[0], 0)
la, ct = np.unique(res_patches.iat(pi).data, return_counts=True)
self.assertEqual(np.sum(ct > 1), 2) # exclude single-pixel anomaly
self.assertEqual(la[0], 0)
......@@ -7,6 +7,9 @@ from model_server.base.accessors import PatchStack, make_patch_stack_from_file,
from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask
from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
def _random_int(*args):
return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
......@@ -40,7 +43,7 @@ class TestCziImageFileAccess(unittest.TestCase):
nc = 4
nz = 11
c = 3
cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz))
cf = InMemoryDataAccessor(_random_int(h, w, nc, nz))
sc = cf.get_one_channel_data(c)
self.assertEqual(sc.shape, (h, w, 1, nz))
......@@ -70,7 +73,7 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_conform_data_shorter_than_xycz(self):
h = 256
w = 512
data = np.random.rand(h, w, 1)
data = _random_int(h, w, 1)
acc = InMemoryDataAccessor(data)
self.assertEqual(
InMemoryDataAccessor.conform_data(data).shape,
......@@ -82,7 +85,7 @@ class TestCziImageFileAccess(unittest.TestCase):
)
def test_conform_data_longer_than_xycz(self):
data = np.random.rand(256, 512, 12, 8, 3)
data = _random_int(256, 512, 12, 8, 3)
with self.assertRaises(DataShapeError):
acc = InMemoryDataAccessor(data)
......@@ -93,7 +96,7 @@ class TestCziImageFileAccess(unittest.TestCase):
c = 3
nz = 10
yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8')
yxcz = _random_int(h, w, c, nz)
acc = InMemoryDataAccessor(yxcz)
fp = output_path / f'rand3d.tif'
self.assertTrue(
......@@ -138,7 +141,7 @@ class TestPatchStackAccessor(unittest.TestCase):
w = 256
h = 512
n = 4
acc = PatchStack(np.random.rand(n, h, w, 1, 1))
acc = PatchStack(_random_int(n, h, w, 1, 1))
self.assertEqual(acc.count, n)
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))
......@@ -147,7 +150,7 @@ class TestPatchStackAccessor(unittest.TestCase):
w = 256
h = 512
n = 4
acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)])
acc = PatchStack([_random_int(h, w, 1, 1) for _ in range(0, n)])
self.assertEqual(acc.count, n)
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))
......@@ -176,8 +179,8 @@ class TestPatchStackAccessor(unittest.TestCase):
nz = 5
n = 4
patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)]
patches.append(np.random.rand(h, 2 * w, c, nz))
patches = [_random_int(h, w, c, nz) for _ in range(0, n)]
patches.append(_random_int(h, 2 * w, c, nz))
acc = PatchStack(patches)
self.assertEqual(acc.count, n + 1)
self.assertEqual(acc.hw, (h, 2 * w))
......@@ -191,7 +194,30 @@ class TestPatchStackAccessor(unittest.TestCase):
n = 4
nz = 15
nc = 2
acc = PatchStack(np.random.rand(n, h, w, nc, nz))
acc = PatchStack(_random_int(n, h, w, nc, nz))
self.assertEqual(acc.count, n)
self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w))
self.assertEqual(acc.hw, (h, w))
return acc
def test_get_one_channel(self):
acc = self.test_pczyx()
mono = acc.get_one_channel_data(channel=1)
for a in 'PXYZ':
self.assertEqual(mono.shape_dict[a], acc.shape_dict[a])
self.assertEqual(mono.shape_dict['C'], 1)
def test_get_one_channel_mip(self):
acc = self.test_pczyx()
mono_mip = acc.get_one_channel_data(channel=1, mip=True)
for a in 'PXY':
self.assertEqual(mono_mip.shape_dict[a], acc.shape_dict[a])
for a in 'CZ':
self.assertEqual(mono_mip.shape_dict[a], 1)
def test_export_pczyx_patch_hyperstack(self):
acc = self.test_pczyx()
fp = output_path / 'patch_hyperstack.tif'
acc.export_pyxcz(fp)
acc2 = make_patch_stack_from_file(fp)
self.assertEqual(acc.shape, acc2.shape)
\ No newline at end of file
......@@ -136,7 +136,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp_list_0.status_code, 200)
rj0 = resp_list_0.json()
self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}')
resp_restart = self._get('restart')
resp_restart = self._get('session/restart')
resp_list_1 = self._get('models')
rj1 = resp_list_1.json()
self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')
......@@ -171,4 +171,9 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
)
self.assertEqual(resp_change.status_code, 200)
resp_check = self._get('paths')
self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images'])
\ No newline at end of file
self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images'])
def test_get_logs(self):
resp = self._get('session/logs')
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.json()[0]['message'], 'Initialized session')
\ No newline at end of file
import os
import re
import unittest
import numpy as np
from pathlib import Path
import pandas as pd
from model_server.conf.testing import output_path, roiset_test_data
from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams
......@@ -29,31 +33,36 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
params=RoiSetMetaParams(
mask_type=mask_type,
filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
expand_box_by=(64, 2)
expand_box_by=(128, 2)
)
)
return roiset
def test_roi_mask_shape(self, **kwargs):
roiset = self._make_roi_set(**kwargs)
# all masks' bounding boxes are at least as big as ROI area
for roi in roiset.get_df().itertuples():
self.assertEqual(roi.binary_mask.dtype, 'bool')
sh = roi.binary_mask.shape
self.assertEqual(sh, (roi.h, roi.w))
self.assertGreaterEqual(sh[0] * sh[1], roi.area)
def test_roi_zmask(self, **kwargs):
roiset = self._make_roi_set(**kwargs)
zmask = roiset.get_zmask()
zmask_acc = InMemoryDataAccessor(zmask)
self.assertTrue(zmask_acc.is_mask())
# assert dimensionality of zmask
self.assertGreater(zmask_acc.shape_dict['Z'], 1)
self.assertEqual(zmask_acc.shape_dict['C'], 1)
self.assertEqual(zmask_acc.nz, roiset.acc_raw.nz)
self.assertEqual(zmask_acc.chroma, 1)
write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc)
# mask values are not just all True or all False
self.assertTrue(np.any(zmask))
self.assertFalse(np.all(zmask))
# assert non-trivial meta info in boxes
self.assertGreater(roiset.count, 1)
sh = roiset.get_df().iloc[1]['mask'].shape
ar = roiset.get_df().iloc[1]['area']
self.assertGreaterEqual(sh[0] * sh[1], ar)
def test_roiset_from_non_zstacks(self, **kwargs):
acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
......@@ -73,37 +82,75 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertEqual(len(ebb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
def test_dataframe_and_mask_array_in_iterator(self):
roiset = self._make_roi_set()
for roi in roiset:
ma = roi.binary_mask
self.assertEqual(ma.dtype, 'bool')
self.assertEqual(ma.shape, (roi.h, roi.w))
def test_rel_slices_are_valid(self):
roiset = self._make_roi_set()
for roi in roiset:
ebb = roiset.acc_raw.data[roi.slice]
ebb = roiset.acc_raw.data[roi.expanded_slice]
self.assertEqual(len(ebb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
rbb = ebb[roi.relative_slice]
self.assertEqual(len(rbb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in rbb.shape]))
def test_make_2d_patches(self):
def test_make_expanded_2d_patches(self):
roiset = self._make_roi_set()
files = roiset.export_patches(
output_path / '2d_patches',
output_path / 'expanded_2d_patches',
draw_bounding_box=True,
expanded=True,
pad_to=256,
)
self.assertGreaterEqual(len(files), 1)
def test_make_3d_patches(self):
df = roiset.get_df()
for f in files:
acc = generate_file_accessor(f)
la = int(re.search(r'la([\d]+)', str(f)).group(1))
roi_q = df.loc[df.label == la, :]
self.assertEqual(len(roi_q), 1)
self.assertEqual((256, 256), acc.hw)
def test_make_tight_2d_patches(self):
roiset = self._make_roi_set()
files = roiset.export_patches(
output_path / 'tight_2d_patches',
draw_bounding_box=True,
expanded=False
)
df = roiset.get_df()
for f in files: # all exported files are same shape as bounding boxes in RoiSet's datatable
acc = generate_file_accessor(f)
la = int(re.search(r'la([\d]+)', str(f)).group(1))
roi_q = df.loc[df.label == la, :]
self.assertEqual(len(roi_q), 1)
roi = roi_q.iloc[0]
self.assertEqual((roi.h, roi.w), acc.hw)
def test_make_expanded_3d_patches(self):
roiset = self._make_roi_set()
files = roiset.export_patches(
output_path / '3d_patches',
make_3d=True)
make_3d=True,
expanded=True
)
self.assertGreaterEqual(len(files), 1)
for f in files:
acc = generate_file_accessor(f)
self.assertGreater(acc.nz, 1)
def test_export_annotated_zstack(self):
roiset = self._make_roi_set()
file = roiset.export_annotated_zstack(
output_path / 'annotated_zstack',
)
result = generate_file_accessor(Path(file['location']) / file['filename'])
result = generate_file_accessor(file)
self.assertEqual(result.shape, roiset.acc_raw.shape)
def test_flatten_image(self):
......@@ -132,7 +179,15 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_make_binary_masks(self):
roiset = self._make_roi_set()
files = roiset.export_patch_masks(output_path / '2d_mask_patches', )
self.assertGreaterEqual(len(files), 1)
df = roiset.get_df()
for f in files: # all exported files are same shape as bounding boxes in RoiSet's datatable
acc = generate_file_accessor(output_path / '2d_mask_patches' / f)
la = int(re.search(r'la([\d]+)', str(f)).group(1))
roi_q = df.loc[df.label == la, :]
self.assertEqual(len(roi_q), 1)
roi = roi_q.iloc[0]
self.assertEqual((roi.h, roi.w), acc.hw)
def test_classify_by(self):
roiset = self._make_roi_set()
......@@ -156,17 +211,21 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_raw_patches_are_correct_shape(self):
roiset = self._make_roi_set()
patches = roiset.get_raw_patches()
patches = roiset.get_patches_acc()
np, h, w, nc, nz = patches.shape
self.assertEqual(np, roiset.count)
self.assertEqual(nc, roiset.acc_raw.chroma)
self.assertEqual(nz, 1)
def test_patch_masks_are_correct_shape(self):
roiset = self._make_roi_set()
patch_masks = roiset.get_patch_masks()
np, h, w, nc, nz = patch_masks.shape
self.assertEqual(np, roiset.count)
self.assertEqual(nc, 1)
df_patch_masks = roiset.get_patch_masks()
for roi in df_patch_masks.itertuples():
h, w, nc, nz = roi.patch_mask.shape
self.assertEqual(nc, 1)
self.assertEqual(nz, 1)
self.assertEqual(h, roi.h)
self.assertEqual(w, roi.w)
class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
......@@ -189,8 +248,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
output_path / 'multichannel' / 'mono_2d_patches',
white_channel=3,
draw_bounding_box=True,
expanded=True,
pad_to=256,
)
result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
result = generate_file_accessor(files[0])
self.assertEqual(result.chroma, 1)
def test_multichannnel_to_mono_2d_patches_rgb_bbox(self):
......@@ -199,8 +260,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
white_channel=3,
draw_bounding_box=True,
bounding_box_channel=1,
expanded=True,
pad_to=256,
)
result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
result = generate_file_accessor(files[0])
self.assertEqual(result.chroma, 3)
def test_multichannnel_to_rgb_2d_patches_bbox(self):
......@@ -208,11 +271,28 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
output_path / 'multichannel' / 'rgb_2d_patches_bbox',
white_channel=4,
rgb_overlay_channels=(3, None, None),
draw_mask=False,
draw_bounding_box=True,
bounding_box_channel=1,
rgb_overlay_weights=(0.1, 1.0, 1.0),
expanded=True,
pad_to=256,
)
result = generate_file_accessor(files[0])
self.assertEqual(result.chroma, 3)
def test_multichannnel_to_rgb_2d_patches_mask(self):
files = self.roiset.export_patches(
output_path / 'multichannel' / 'rgb_2d_patches_mask',
white_channel=4,
rgb_overlay_channels=(3, None, None),
draw_mask=True,
mask_channel=0,
rgb_overlay_weights=(0.1, 1.0, 1.0)
rgb_overlay_weights=(0.1, 1.0, 1.0),
expanded=True,
pad_to=256,
)
result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
result = generate_file_accessor(files[0])
self.assertEqual(result.chroma, 3)
def test_multichannnel_to_rgb_2d_patches_contour(self):
......@@ -221,25 +301,32 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
rgb_overlay_channels=(3, None, None),
draw_contour=True,
contour_channel=1,
rgb_overlay_weights=(0.1, 1.0, 1.0)
rgb_overlay_weights=(0.1, 1.0, 1.0),
expanded=True,
pad_to=256,
)
result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
result = generate_file_accessor(files[0])
self.assertEqual(result.chroma, 3)
self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black
def test_multichannel_to_multichannel_tif_patches(self):
files = self.roiset.export_patches(
output_path / 'multichannel' / 'multichannel_tif_patches',
expanded=True,
pad_to=256,
)
result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
result = generate_file_accessor(files[0])
self.assertEqual(result.chroma, 5)
self.assertEqual(result.nz, 1)
def test_multichannel_annotated_zstack(self):
file = self.roiset.export_annotated_zstack(
output_path / 'multichannel' / 'annotated_zstack',
'test_multichannel_annotated_zstack',
expanded=True,
pad_to=256,
)
result = generate_file_accessor(Path(file['location']) / file['filename'])
result = generate_file_accessor(file)
self.assertEqual(result.chroma, self.stack.chroma)
self.assertEqual(result.nz, self.stack.nz)
......@@ -247,9 +334,104 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
file = self.roiset.export_annotated_zstack(
output_path / 'annotated_zstack',
channel=3,
expanded=True,
pad_to=256,
)
result = generate_file_accessor(Path(file['location']) / file['filename'])
result = generate_file_accessor(file)
self.assertEqual(result.hw, self.roiset.acc_raw.hw)
self.assertEqual(result.nz, self.roiset.acc_raw.nz)
self.assertEqual(result.chroma, 1)
class TestRoiSetFromZmask(unittest.TestCase):
def setUp(self) -> None:
# set up test raw data and segmentation from file
self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['segmentation_channel'])
self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d'])
@staticmethod
def _label_is_2d(id_map, la): # single label's zmask has same counts as its MIP
mask_3d = (id_map == la)
mask_mip = mask_3d.max(axis=-1)
return mask_3d.sum() == mask_mip.sum()
def test_id_map_connects_z(self):
id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True)
labels = np.unique(id_map.data)[1:]
is_2d = all([self._label_is_2d(id_map.data, la) for la in labels])
self.assertFalse(is_2d)
def test_id_map_disconnects_z(self):
id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
labels = np.unique(id_map.data)[1:]
is_2d = all([self._label_is_2d(id_map.data, la) for la in labels])
self.assertTrue(is_2d)
def test_create_roiset_from_3d_obj_ids(self):
id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
self.assertEqual(self.stack_ch_pa.shape, id_map.shape)
roiset = RoiSet(
self.stack_ch_pa,
id_map,
params=RoiSetMetaParams(mask_type='contours')
)
self.assertEqual(roiset.count, id_map.data.max())
self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
def test_create_roiset_from_2d_obj_ids(self):
id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False)
self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3])
self.assertEqual(id_map.nz, 1)
roiset = RoiSet(
self.stack_ch_pa,
id_map,
params=RoiSetMetaParams(mask_type='contours')
)
self.assertEqual(roiset.count, id_map.data.max())
self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
return roiset
def test_create_roiset_from_df_and_patch_masks(self):
ref_roiset = self.test_create_roiset_from_2d_obj_ids()
where_ser = output_path / 'serialize'
ref_roiset.serialize(where_ser, prefix='ref')
where_df = where_ser / 'dataframe' / 'ref.csv'
self.assertTrue(where_df.exists())
df_test = pd.read_csv(where_df)
# check that patches are correct size
where_patch_masks = where_ser / 'tight_patch_masks'
patch_filenames = []
for pmf in where_patch_masks.iterdir():
self.assertTrue(pmf.suffix.upper() == '.PNG')
la = int(re.search(r'la([\d]+)', str(pmf)).group(1))
roi_q = df_test.loc[df_test.label == la, :]
self.assertEqual(len(roi_q), 1)
roi = roi_q.iloc[0]
m_acc = generate_file_accessor(pmf)
self.assertEqual((roi.h, roi.w), m_acc.hw)
patch_filenames.append(pmf.name)
# make another RoiSet from just the data table, raw images, and (tight) patch masks
test_roiset = RoiSet.deserialize(self.stack_ch_pa, where_ser, prefix='ref')
self.assertEqual(ref_roiset.get_zmask().shape, test_roiset.get_zmask().shape,)
self.assertTrue((ref_roiset.get_zmask() == test_roiset.get_zmask()).all())
self.assertTrue(np.all(test_roiset.get_df().label == ref_roiset.get_df().label))
cols = ['label', 'y1', 'y0', 'x1', 'x0', 'zi']
self.assertTrue((test_roiset.get_df()[cols] == ref_roiset.get_df()[cols]).all().all())
# re-serialize and check that patch masks are the same
where_dser = output_path / 'deserialize'
test_roiset.serialize(where_dser, prefix='test')
for fr in patch_filenames:
pr = (where_ser / 'tight_patch_masks' / fr)
self.assertTrue(pr.exists())
pt = (where_dser / 'tight_patch_masks' / fr.replace('ref', 'test'))
self.assertTrue(pt.exists())
r_acc = generate_file_accessor(pr)
t_acc = generate_file_accessor(pt)
self.assertTrue(np.all(r_acc.data == t_acc.data))
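Stripped of assertions, the round trip this test exercises reduces to the following sketch (paths illustrative):
where = output_path / 'serialize'
roiset.serialize(where, prefix='ref')  # writes dataframe/ref.csv and tight_patch_masks/*.png
restored = RoiSet.deserialize(acc_raw, where, prefix='ref')
assert (restored.get_zmask() == roiset.get_zmask()).all()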
import json
from os.path import exists
import pathlib
import unittest
from model_server.base.models import DummySemanticSegmentationModel
from model_server.base.session import Session
from model_server.base.workflows import WorkflowRunRecord
class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None:
pass
self.sesh = Session()
def tearDown(self) -> None:
print('Tearing down...')
Session._instances = {}
def test_single_session_instance(self):
sesh = Session()
self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')
def test_session_is_singleton(self):
Session._instances = {}
self.assertEqual(len(Session._instances), 0)
s = Session()
self.assertEqual(len(Session._instances), 1)
self.assertIs(s, Session())
self.assertEqual(len(Session._instances), 1)
from os.path import exists
self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place')
self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
def test_session_logfile_is_valid(self):
self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
def test_changing_session_root_creates_new_directory(self):
from model_server.conf.defaults import root
from shutil import rmtree
sesh = Session()
old_paths = sesh.get_paths()
old_paths = self.sesh.get_paths()
newroot = root / 'subdir'
sesh.restart(root=newroot)
new_paths = sesh.get_paths()
self.sesh.restart(root=newroot)
new_paths = self.sesh.get_paths()
for k in old_paths.keys():
self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
# this is necessary because logger itself is a singleton class
self.tearDown()
self.setUp()
rmtree(newroot)
self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
def test_change_session_subdirectory(self):
sesh = Session()
old_paths = sesh.get_paths()
old_paths = self.sesh.get_paths()
print(old_paths)
sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images'])
def test_restart_session(self):
sesh = Session()
logfile1 = sesh.session_log
sesh.restart()
logfile2 = sesh.session_log
self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
def test_restarting_session_creates_new_logfile(self):
logfile1 = self.sesh.logfile
self.assertTrue(logfile1.exists())
self.sesh.restart()
logfile2 = self.sesh.logfile
self.assertTrue(logfile2.exists())
self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
def test_session_records_workflow(self):
import json
from model_server.base.workflows import WorkflowRunRecord
sesh = Session()
di = WorkflowRunRecord(
model_id='test_model',
input_filepath='/test/input/directory',
output_filepath='/test/output/fi.le',
success=True,
timer_results={'start': 0.123},
)
sesh.record_workflow_run(di)
with open(sesh.manifest_json, 'r') as fh:
do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_log_warning(self):
msg = 'A test warning'
self.sesh.log_info(msg)
with open(self.sesh.logfile, 'r') as fh:
log = fh.read()
self.assertTrue(msg in log)
def test_get_logs(self):
self.sesh.log_info('Info example 1')
self.sesh.log_warning('Example warning')
self.sesh.log_info('Info example 2')
logs = self.sesh.get_log_data()
self.assertEqual(len(logs), 4)
self.assertEqual(logs[1]['level'], 'WARNING')
self.assertEqual(logs[-1]['message'], 'Initialized session')
def test_session_loads_model(self):
sesh = Session()
MC = DummySemanticSegmentationModel
success = sesh.load_model(MC)
success = self.sesh.load_model(MC)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
loaded_models = self.sesh.describe_loaded_models()
self.assertTrue(
(MC.__name__ + '_00') in loaded_models.keys()
)
......@@ -76,46 +87,56 @@ class TestGetSessionObject(unittest.TestCase):
)
def test_session_loads_second_instance_of_same_model(self):
sesh = Session()
MC = DummySemanticSegmentationModel
sesh.load_model(MC)
sesh.load_model(MC)
self.assertIn(MC.__name__ + '_00', sesh.models.keys())
self.assertIn(MC.__name__ + '_01', sesh.models.keys())
self.sesh.load_model(MC)
self.sesh.load_model(MC)
self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())
def test_session_loads_model_with_params(self):
sesh = Session()
MC = DummySemanticSegmentationModel
p1 = {'p1': 'abc'}
success = sesh.load_model(MC, params=p1)
success = self.sesh.load_model(MC, params=p1)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
loaded_models = self.sesh.describe_loaded_models()
mid = MC.__name__ + '_00'
self.assertEqual(loaded_models[mid]['params'], p1)
# load a second model and confirm that the first is locatable by its param entry
p2 = {'p2': 'def'}
sesh.load_model(MC, params=p2)
find_mid = sesh.find_param_in_loaded_models('p1', 'abc')
self.sesh.load_model(MC, params=p2)
find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc')
self.assertEqual(mid, find_mid)
self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1)
self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)
def test_session_finds_existing_model_with_different_path_formats(self):
sesh = Session()
MC = DummySemanticSegmentationModel
p1 = {'path': 'c:\\windows\\dummy.pa'}
p2 = {'path': 'c:/windows/dummy.pa'}
mid = sesh.load_model(MC, params=p1)
mid = self.sesh.load_model(MC, params=p1)
assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
self.assertEqual(mid, find_mid)
def test_change_output_path(self):
import pathlib
sesh = Session()
pa = sesh.get_paths()['inbound_images']
pa = self.sesh.get_paths()['inbound_images']
self.assertIsInstance(pa, pathlib.Path)
sesh.set_data_directory('outbound_images', pa.__str__())
self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images'])
self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path)
\ No newline at end of file
self.sesh.set_data_directory('outbound_images', pa.__str__())
self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)
def test_make_table(self):
import pandas as pd
data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)]
self.sesh.write_to_table(
'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
)
self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
self.sesh.write_to_table(
'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
)
dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
self.assertEqual(len(dfv), len(data))
self.assertEqual(dfv.columns[0], 'X')
self.assertEqual(dfv.columns[1], 'Y')