Commit 86692770 authored by Christopher Randolph Rhodes

Merge branch 'release-2024.04.19' into 'master'

Release 2024.04.19

See merge request rhodes/model_server!37
Showing with 1419 additions and 393 deletions
......@@ -13,6 +13,8 @@ from model_server.base.process import is_mask
class GenericImageDataAccessor(ABC):
axes = 'YXCZ'
@abstractmethod
def __init__(self):
"""
......@@ -25,6 +27,13 @@ class GenericImageDataAccessor(ABC):
def chroma(self):
return self.shape_dict['C']
@staticmethod
def _derived_accessor(data):
"""
Create a new accessor given np.ndarray data; used for example in slicing operations
"""
return InMemoryDataAccessor(data)
@staticmethod
def conform_data(data):
if len(data.shape) > 4 or (0 in data.shape):
......@@ -38,12 +47,23 @@ class GenericImageDataAccessor(ABC):
def is_mask(self):
return is_mask(self._data)
def get_one_channel_data (self, channel: int, mip: bool = False):
c = int(channel)
def get_channels(self, channels: list, mip: bool = False):
carr = [int(c) for c in channels]
if mip:
return InMemoryDataAccessor(self.data[:, :, c:(c+1), :].max(axis=-1))
nda = self.data.take(indices=carr, axis=self._ga('C')).max(axis=self._ga('Z'), keepdims=True)
return self._derived_accessor(nda)
else:
return InMemoryDataAccessor(self.data[:, :, c:(c+1), :])
nda = self.data.take(indices=carr, axis=self._ga('C'))
return self._derived_accessor(nda)
def get_one_channel_data(self, channel: int, mip: bool = False):
return self.get_channels([channel], mip=mip)
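A minimal usage sketch of the new channel-selection API (array shapes are illustrative):

    import numpy as np
    from model_server.base.accessors import InMemoryDataAccessor

    acc = InMemoryDataAccessor(np.random.randint(0, 255, (512, 256, 4, 7), dtype='uint8'))
    sub = acc.get_channels([0, 1])                    # YXCZ subset, shape (512, 256, 2, 7)
    mono_mip = acc.get_one_channel_data(2, mip=True)  # z-projected, shape (512, 256, 1, 1)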
def _gc(self, channels):
return self.get_channels(list(channels))
def _unique(self):
return np.unique(self.data, return_counts=True)
@property
def pixel_scale_in_micrometers(self):
......@@ -53,6 +73,12 @@ class GenericImageDataAccessor(ABC):
def dtype(self):
return self.data.dtype
def get_axis(self, ch):
return self.axes.index(ch.upper())
def _ga(self, arg):
return self.get_axis(arg)
@property
def hw(self):
"""
......@@ -162,6 +188,14 @@ class CziImageFileAccessor(GenericImageFileAccessor):
except Exception:
raise FileAccessorError(f'Unable to access CZI data in {fpath}')
try:
md = cf.metadata(raw=False)
compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod']
except KeyError:
raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod')
if compmet.upper() != 'UNCOMPRESSED':
raise InvalidCziCompression(f'Unsupported compression method {compmet}')
sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes}
if (sd.get('S') and (sd['S'] > 1)) or (sd.get('T') and (sd['T'] > 1)):
raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}')
......@@ -250,40 +284,84 @@ def generate_file_accessor(fpath):
class PatchStack(InMemoryDataAccessor):
def __init__(self, data):
axes = 'PYXCZ'
def __init__(self, data, force_ydim_longest=False):
"""
A sequence of n (generally) color 3D images of the same size
:param data: either a list of np.ndarrays of size YXCZ, or np.ndarray of size PYXCZ
:param force_ydim_longest: if creating a PatchStack from a list of different-sized patches, rotate each
as needed so that height is always greater than or equal to width
"""
self._slices = []
if isinstance(data, list): # list of YXCZ patches
n = len(data)
if force_ydim_longest:
psh = np.array([e.shape[0:2] for e in data]).max(axis=1).max()
psw = np.array([e.shape[0:2] for e in data]).min(axis=1).max()
psc, psz = np.array([e.shape[2:] for e in data]).max(axis=0)
yxcz_shape = np.array([psh, psw, psc, psz])
else:
yxcz_shape = np.array([e.shape for e in data]).max(axis=0)
nda = np.zeros(
(n, *yxcz_shape), dtype=data[0].dtype
)
for i in range(0, len(data)):
s = tuple([slice(0, c) for c in data[i].shape])
nda[i][s] = data[i]
h, w = data[i].shape[0:2]
if force_ydim_longest and w > h:
patch = np.rot90(data[i], axes=(0, 1))
else:
patch = data[i]
s = tuple([slice(0, c) for c in patch.shape])
nda[i][s] = patch
self._slices.append(s)
elif isinstance(data, np.ndarray) and len(data.shape) == 5: # interpret as PYXCZ
nda = data
for i in range(0, len(data)):
self._slices.append(tuple([slice(0, c) for c in data[i].shape]))
else:
raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}')
assert nda.ndim == 5
self._data = nda
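A short sketch of the force_ydim_longest behavior (patch sizes are hypothetical):

    import numpy as np
    from model_server.base.accessors import PatchStack

    patches = [np.zeros((256, 128, 1, 1), dtype='uint8'),  # tall patch, kept as-is
               np.zeros((128, 256, 1, 1), dtype='uint8')]  # wide patch, rotated 90 degrees
    ps = PatchStack(patches, force_ydim_longest=True)
    print(ps.shape_dict)  # {'P': 2, 'Y': 256, 'X': 128, 'C': 1, 'Z': 1}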
def iat(self, i):
@staticmethod
def _derived_accessor(data):
return PatchStack(data)
def get_slice_at(self, i):
return self._slices[i]
def iat(self, i, crop=False):
if crop:
return InMemoryDataAccessor(self.data[i, :, :, :, :][self._slices[i]])
else:
return InMemoryDataAccessor(self.data[i, :, :, :, :])
def iat_yxcz(self, i):
return self.iat(i)
def iat_yxcz(self, i, crop=False):
return self.iat(i, crop=crop)
@property
def count(self):
return self.shape_dict['P']
def export_pyxcz(self, fpath: Path):
tzcyx = np.moveaxis(
self.pyxcz, # yxcz
[0, 4, 3, 1, 2],
[0, 1, 2, 3, 4]
)
if self.is_mask():
if self.dtype == 'bool':
data = (tzcyx * 255).astype('uint8')
else:
data = tzcyx.astype('uint8')
tifffile.imwrite(fpath, data, imagej=True)
else:
tifffile.imwrite(fpath, tzcyx, imagej=True)
@property
def shape_dict(self):
return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
......@@ -313,16 +391,32 @@ class PatchStack(InMemoryDataAccessor):
return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position
def make_patch_stack_from_file(fpath): # interpret t-dimension as patch position
if not Path(fpath).exists():
raise FileNotFoundError(f'Could not find {fpath}')
pyxc = np.moveaxis(
generate_file_accessor(fpath).data, # yxcz
[0, 1, 2, 3],
[1, 2, 3, 0]
try:
tf = tifffile.TiffFile(fpath)
except Exception:
raise FileAccessorError(f'Unable to access data in {fpath}')
if len(tf.series) != 1:
raise DataShapeError(f'Expect only one series in {fpath}')
se = tf.series[0]
axs = [a for a in se.axes if a in [*'TZCYX']]
sd = dict(zip(axs, se.shape))
for a in [*'TZC']:
if a not in axs:
sd[a] = 1
tzcyx = se.asarray().reshape([sd[k] for k in [*'TZCYX']])
pyxcz = np.moveaxis(
tzcyx,
[0, 3, 4, 2, 1],
[0, 1, 2, 3, 4],
)
pyxcz = np.expand_dims(pyxc, axis=3)
return PatchStack(pyxcz)
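A round-trip sketch of export and re-import, assuming a writable working directory:

    import numpy as np
    from pathlib import Path
    from model_server.base.accessors import PatchStack, make_patch_stack_from_file

    ps = PatchStack(np.random.randint(0, 255, (4, 512, 256, 1, 1), dtype='uint8'))
    ps.export_pyxcz(Path('patches.tif'))  # ImageJ TZCYX hyperstack; T is the patch index
    ps2 = make_patch_stack_from_file(Path('patches.tif'))
    assert ps.shape == ps2.shape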
......@@ -345,6 +439,9 @@ class FileWriteError(Error):
class InvalidAxisKey(Error):
pass
class InvalidCziCompression(Error):
pass
class InvalidDataShape(Error):
pass
......
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from model_server.base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
from model_server.base.session import Session, InvalidPathError
......@@ -20,9 +21,13 @@ def startup():
def read_root():
return {'success': True}
class BounceBackParams(BaseModel):
par1: str
par2: list
@app.put('/bounce_back')
def list_bounce_back(par1=None, par2=None):
return {'success': True, 'params': {'par1': par1, 'par2': par2}}
def list_bounce_back(params: BounceBackParams):
return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}}
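A client-side sketch of the new JSON-body endpoint (host and port are illustrative):

    import requests

    resp = requests.put(
        'http://127.0.0.1:6221/bounce_back',
        json={'par1': 'hello', 'par2': ['ab', 'cd']},
    )
    print(resp.json())  # {'success': True, 'params': {'par1': 'hello', 'par2': ['ab', 'cd']}}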
@app.get('/paths')
def list_session_paths():
......@@ -46,6 +51,7 @@ def change_path(key, path):
status_code=404,
detail=e.__str__(),
)
session.log_info(f'Change {key} path to {path}')
return session.get_paths()
@app.put('/paths/watch_input')
......@@ -53,25 +59,33 @@ def watch_input_path(path: str):
return change_path('inbound_images', path)
@app.put('/paths/watch_output')
def watch_input_path(path: str):
def watch_output_path(path: str):
return change_path('outbound_images', path)
@app.get('/restart')
@app.get('/session/restart')
def restart_session(root: str = None) -> dict:
session.restart(root=root)
return session.describe_loaded_models()
@app.get('/session/logs')
def list_session_log() -> list:
return session.get_log_data()
@app.get('/models')
def list_active_models():
return session.describe_loaded_models()
@app.put('/models/dummy_semantic/load/')
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummySemanticSegmentationModel)}
mid = session.load_model(DummySemanticSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
@app.put('/models/dummy_instance/load/')
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyInstanceSegmentationModel)}
mid = session.load_model(DummyInstanceSegmentationModel)
session.log_info(f'Loaded model {mid}')
return {'model_id': mid}
@app.put('/workflows/segment')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......@@ -83,5 +97,5 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
session.paths['outbound_images'],
channel=channel,
)
session.record_workflow_run(record)
session.log_info(f'Completed segmentation of {input_filename}')
return record
\ No newline at end of file
......@@ -3,7 +3,7 @@ from math import floor
import numpy as np
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack
class Model(ABC):
......@@ -70,6 +70,16 @@ class SemanticSegmentationModel(ImageToImageModel):
"""
pass
def label_patch_stack(self, img: PatchStack, **kwargs) -> PatchStack:
"""
Iterate over a patch stack, calling pixel labeling (to a boolean array) separately on each cropped patch
"""
data = np.zeros((img.count, *img.hw, 1, img.nz), dtype=bool)
for i in range(0, img.count):
sl = img.get_slice_at(i)
data[i][sl] = self.label_pixel_class(img.iat(i, crop=True), **kwargs)
return PatchStack(data)
class InstanceSegmentationModel(ImageToImageModel):
"""
......@@ -85,9 +95,32 @@ class InstanceSegmentationModel(ImageToImageModel):
"""
if not mask.is_mask():
raise InvalidInputImageError('Expecting a binary mask')
if not img.shape == mask.shape:
if img.hw != mask.hw or img.nz != mask.nz:
raise InvalidInputImageError('Expect input image and mask to be the same shape')
def label_patch_stack(self, img: PatchStack, mask: PatchStack, allow_multiple=True, force_single=False, **kwargs):
"""
Call inference on all patches in a PatchStack at once
:param img: raw image data
:param mask: binary masks of same shape as img
:param allow_multiple: allow multiple nonzero pixel values in inferred patches if True
:param force_single: if True, and if allow_multiple is False, convert all nonzero pixels in a patch to the most
frequent nonzero label; otherwise raise an exception
:return: PatchStack of labeled objects
"""
res = self.label_instance_class(img, mask, **kwargs)
data = res.data
for i in range(0, res.count): # interpret as PYXCZ
la_patch = data[i, :, :, :, :]
la, ct = np.unique(la_patch, return_counts=True)
if len(la[la > 0]) > 1 and not allow_multiple:
if force_single:
la_patch[la_patch > 0] = la[1:][ct[1:].argsort()[-1]] # most common nonzero value
else:
raise InvalidObjectLabelsError(f'Found more than one nonzero label: {la}, counts: {ct}')
data[i, :, :, :, :] = la_patch
return PatchStack(data)
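A worked numpy sketch of the force_single relabeling step above:

    import numpy as np

    la_patch = np.array([0, 0, 1, 2, 2, 2, 3])
    la, ct = np.unique(la_patch, return_counts=True)       # la = [0 1 2 3], ct = [2 1 3 1]
    la_patch[la_patch > 0] = la[1:][ct[1:].argsort()[-1]]  # 2, the most frequent nonzero label
    print(la_patch)  # [0 0 2 2 2 2 2]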
class DummySemanticSegmentationModel(SemanticSegmentationModel):
......@@ -143,3 +176,6 @@ class ParameterExpectedError(Error):
class InvalidInputImageError(Error):
pass
class InvalidObjectLabelsError(Error):
pass
\ No newline at end of file
......@@ -6,7 +6,7 @@ from math import ceil, floor
import numpy as np
import skimage
from skimage.exposure import rescale_intensity
from skimage.measure import find_contours
def is_mask(img):
"""
......@@ -117,6 +117,17 @@ def mask_largest_object(
else:
return img
def get_safe_contours(mask):
"""
Return a list of contour coordinates even if a mask is only one pixel across
"""
if mask.shape[0] == 1 or mask.shape[1] == 1:
c0 = mask.shape[0] - 1
c1 = mask.shape[1] - 1
return [np.array([(0, 0), (c0, c1)])]
else:
return find_contours(mask)
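A sketch of the degenerate case this guards against:

    import numpy as np
    from model_server.base.process import get_safe_contours

    mask = np.ones((1, 20), dtype=bool)  # only one pixel tall; find_contours cannot trace this
    con = get_safe_contours(mask)        # [array([[0, 0], [0, 19]])], the strip's endpoints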
class Error(Exception):
pass
......
This diff is collapsed.
import json
import logging
import os
from pathlib import Path
from time import strftime, localtime
from typing import Dict
import pandas as pd
import model_server.conf.defaults
from model_server.base.models import Model
from model_server.base.workflows import WorkflowRunRecord
def create_manifest_json():
pass
logger = logging.getLogger(__name__)
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
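A minimal illustration of the metaclass (the class name is hypothetical):

    class Config(metaclass=Singleton):
        def __init__(self):
            self.loaded = True

    assert Config() is Config()  # second call returns the cached instance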
class CsvTable(object):
def __init__(self, fpath: Path):
self.path = fpath
self.empty = True
def append(self, coords: dict, data: pd.DataFrame) -> bool:
assert isinstance(data, pd.DataFrame)
for c in reversed(coords.keys()):
data.insert(0, c, coords[c])
if self.empty:
data.to_csv(self.path, index=False, mode='w', header=True)
else:
data.to_csv(self.path, index=False, mode='a', header=False)
self.empty = False
return True
class Session(object):
class Session(object, metaclass=Singleton):
"""
Singleton class for a server session that persists data between API calls
"""
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(Session, cls).__new__(cls)
return cls.instance
log_format = '%(asctime)s - %(levelname)s - %(message)s'
def __init__(self, root: str = None):
print('Initializing session')
self.models = {} # model_id : model object
self.manifest = [] # paths to data as well as other metadata from each inference run
self.paths = self.make_paths(root)
self.session_log = self.paths['logs'] / f'session.log'
self.log_event('Initialized session')
self.manifest_json = self.paths['logs'] / f'manifest.json'
open(self.manifest_json, 'w').close() # instantiate empty json file
self.logfile = self.paths['logs'] / f'session.log'
logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format)
self.log_info('Initialized session')
self.tables = {}
def write_to_table(self, name: str, coords: dict, data: pd.DataFrame):
"""
Write data to a named data table, initializing if it does not yet exist.
:param name: name of the table to persist through session
:param coords: dictionary of coordinates to associate with all rows in this method call
:param data: DataFrame containing data
:return: True if successful
"""
try:
if name in self.tables.keys():
table = self.tables.get(name)
else:
table = CsvTable(self.paths['tables'] / (name + '.csv'))
self.tables[name] = table
except Exception:
raise CouldNotCreateTable(f'Unable to create table named {name}')
try:
table.append(coords, data)
return True
except Exception:
raise CouldNotAppendToTable(f'Unable to append data to table named {name}')
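A usage sketch, assuming a running Session; the table name and coordinates are illustrative:

    import pandas as pd

    session = Session()
    session.write_to_table('counts', {'well': 'A1'}, pd.DataFrame({'n': [3, 5]}))
    # rows land in <session root>/tables/counts.csv with a leading 'well' column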
def get_paths(self):
return self.paths
......@@ -55,7 +101,7 @@ class Session(object):
root_path = Path(root)
sid = Session.create_session_id(root_path)
paths = {'root': root_path}
for pk in ['inbound_images', 'outbound_images', 'logs']:
for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']:
pa = root_path / sid / model_server.conf.defaults.subdirectories[pk]
paths[pk] = pa
try:
......@@ -75,22 +121,23 @@ class Session(object):
idx += 1
return f'{yyyymmdd}-{idx:04d}'
def log_event(self, event: str):
"""
Write an event string to this session's log file.
"""
timestamp = strftime('%m/%d/%Y, %H:%M:%S', localtime())
with open(self.session_log, 'a') as fh:
fh.write(f'{timestamp} -- {event}\n')
def get_log_data(self) -> list:
log = []
with open(self.logfile, 'r') as fh:
for line in fh:
k = ['datetime', 'level', 'message']
v = line.strip().split(' - ')[0:3]
log.insert(0, dict(zip(k, v)))
return log
def record_workflow_run(self, record: WorkflowRunRecord or None):
"""
Append a JSON describing inference data to this session's manifest
"""
self.log_event(f'Ran model {record.model_id} on {record.input_filepath} to infer {record.output_filepath}')
with open(self.manifest_json, 'w+') as fh:
json.dump(record.dict(), fh)
def log_info(self, msg):
logger.info(msg)
def log_warning(self, msg):
logger.warning(msg)
def log_error(self, msg):
logger.error(msg)
def load_model(self, ModelClass: Model, params: Dict[str, str] = None) -> dict:
"""
......@@ -114,7 +161,7 @@ class Session(object):
'object': mi,
'params': params
}
self.log_event(f'Loaded model {key}')
self.log_info(f'Loaded model {key}')
return key
def describe_loaded_models(self) -> dict:
......@@ -157,5 +204,11 @@ class CouldNotInstantiateModelError(Error):
class CouldNotCreateDirectory(Error):
pass
class CouldNotCreateTable(Error):
pass
class CouldNotAppendToTable(Error):
pass
class InvalidPathError(Error):
pass
\ No newline at end of file
"""
Implementation of image analysis work behind API endpoints, without knowledge of persistent data in server session.
"""
from collections import OrderedDict
from pathlib import Path
from time import perf_counter
from typing import Dict
......@@ -14,7 +15,7 @@ class Timer(object):
tfunc = perf_counter
def __init__(self):
self.events = {}
self.events = OrderedDict()
self.last = self.tfunc()
def click(self, key):
......
......@@ -6,18 +6,20 @@ import httplib
import json
import urllib
from ij import IJ
from ij import ImagePlus
HOST = '127.0.0.1'
PORT = 6221
uri = 'http://{}:{}/'.format(HOST, PORT)
def hit_endpoint(method, endpoint, params=None):
def hit_endpoint(method, endpoint, params=None, body=None):
"""
Python 2.7 implementation of HTTP client
:param method: (str) either 'GET' or 'PUT'
:param endpoint: (str) endpoint of HTTP request
:param params: (dict) of parameters required by client request
:param params: (dict) of parameters that are embedded in the client request URL
:param body: (dict) of parameters that are JSON-encoded and attached as the request payload
:return: (dict) of response status and content, formatted as dict if request is successful
"""
connection = httplib.HTTPConnection(HOST, PORT)
......@@ -27,7 +29,7 @@ def hit_endpoint(method, endpoint, params=None):
url = endpoint + '?' + urllib.urlencode(params)
else:
url = endpoint
connection.request(method, url)
connection.request(method, url, body=json.dumps(body))
resp = connection.getresponse()
resp_str = resp.read()
try:
......@@ -36,6 +38,28 @@ def hit_endpoint(method, endpoint, params=None):
content = {'str': str(resp_str)}
return {'status': resp.status, 'content': content}
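An illustrative call using the new body argument (Python 2.7, matching this client):

    resp = hit_endpoint('PUT', '/bounce_back', body={'par1': 'hi', 'par2': [1, 2]})
    print(resp['status'])  # 200 if the server accepted the JSON payload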
def verify_server(popup=True):
try:
resp = hit_endpoint('GET', '/')
except Exception as e:
print(e)
msg = 'Could not find server at: ' + uri
IJ.log(msg)
if popup:
IJ.error(msg)
raise e
return False
if resp['status'] != 200:
msg = 'Unknown error verifying server at: ' + uri
if popup:
IJ.error(msg)
raise Exception(msg)
return False
else:
IJ.log('Verified server is online at: ' + uri)
return True
def run_request_sequence(imp, func, params):
"""
Execute a sequence of client requests in the ImageJ scripting environment
......
......@@ -6,6 +6,7 @@ subdirectories = {
'logs': 'logs',
'inbound_images': 'images/inbound',
'outbound_images': 'images/outbound',
'tables': 'tables',
}
server_conf = {
......
......@@ -67,6 +67,7 @@ roiset_test_data = {
'c': 5,
'z': 7,
'mask_path': root / 'zmask-test-stack-mask.tif',
'mask_path_3d': root / 'zmask-test-stack-mask-3d.tif',
},
'pipeline_params': {
'segmentation_channel': 0,
......
import json
import os
from pathlib import Path
......@@ -12,8 +13,17 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat
class IlastikModel(Model):
def __init__(self, params, autoload=True):
def __init__(self, params, autoload=True, enforce_embedded=True):
"""
Base class for models that run via ilastik shell API
:param params:
project_file: path to ilastik project file
:param autoload: automatically load model into memory if true
:param enforce_embedded:
raise an error if any input data are not embedded in the project file, i.e. if they are referenced on the filesystem
"""
self.project_file = Path(params['project_file'])
self.enforce_embedded = enforce_embedded
params['project_file'] = self.project_file.__str__()
if self.project_file.is_absolute():
pap = self.project_file
......@@ -42,14 +52,42 @@ class IlastikModel(Model):
args.project = self.project_file_abspath.__str__()
shell = app.main(args, init_logging=False)
# validate if inputs are embedded in project file
h5 = shell.projectManager.currentProjectFile
for lane in h5['Input Data/infos'].keys():
for role in h5[f'Input Data/infos/{lane}'].keys():
grp = h5[f'Input Data/infos/{lane}/{role}']
if self.enforce_embedded and ('location' in grp.keys()) and grp['location'][()] != b'ProjectInternal':
raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem')
assert True
if not isinstance(shell.workflow, self.get_workflow()):
raise ParameterExpectedError(
f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
f'Ilastik project file {self.project_file} does not describe an instance of {self.__class__}'
)
self.shell = shell
return True
@property
def model_shape_dict(self):
raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
ax = raw_info['axistags'][()]
ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']]
shape = raw_info['shape'][()]
dd = dict(zip(ax_keys, shape))
for ci in 'TCZ':
if ci not in dd.keys():
dd[ci] = 1
return dd
@property
def model_chroma(self):
return self.model_shape_dict['C']
@property
def model_3d(self):
return self.model_shape_dict['Z'] > 1
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
......@@ -60,7 +98,15 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
from ilastik.workflows import PixelClassificationWorkflow
return PixelClassificationWorkflow
def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict):
@property
def labels(self):
h5 = self.shell.projectManager.currentProjectFile
return [l.decode() for l in h5['PixelClassification/LabelNames'][()]]
def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
dsi = [
{
......@@ -78,6 +124,22 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict):
"""
Iterate over a patch stack, calling inference separately on each cropped patch
"""
from ilastik.applets.featureSelection.opFeatureSelection import FeatureSelectionConstraintError
nc = len(self.labels)
data = np.zeros((img.count, *img.hw, nc, img.nz), dtype=float) # interpret as PYXCZ
for i in range(0, img.count):
sl = img.get_slice_at(i)
try:
data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True))[0].data
except FeatureSelectionConstraintError: # occurs occasionally on small patches
continue
return PatchStack(data), {'success': True}
def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs):
pxmap, _ = self.infer(img)
mask = pxmap.data[:, :, px_class, :] > px_prob_threshold
......@@ -87,22 +149,43 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_segmentation'
@staticmethod
def _make_8bit_mask(nda):
if nda.dtype == 'bool':
return 255 * nda.astype('uint8')
else:
return nda
@staticmethod
def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
if self.model_chroma != input_img.chroma:
raise IlastikInputShapeError(
f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}'
)
if self.model_3d != input_img.is_3d():
if self.model_3d:
raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
else:
raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
assert segmentation_img.is_mask()
if segmentation_img.dtype == 'bool':
seg = 255 * segmentation_img.data.astype('uint8')
if isinstance(input_img, PatchStack):
assert isinstance(segmentation_img, PatchStack)
tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(
255 * segmentation_img.data.astype('uint8'),
'yxcz'
self._make_8bit_mask(segmentation_img.pczyx),
'tczyx'
)
else:
tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_seg_data = vigra.taggedView(
self._make_8bit_mask(segmentation_img.data),
'yxcz'
)
dsi = [
{
......@@ -115,6 +198,15 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
assert len(obmaps) == 1, 'ilastik generated more than one object map'
if isinstance(input_img, PatchStack):
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
return PatchStack(data=pyxcz), {'success': True}
else:
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
......@@ -137,6 +229,14 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
if isinstance(input_img, PatchStack):
assert isinstance(pxmap_img, PatchStack)
tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx')
tagged_pxmap_data = vigra.taggedView(pxmap_img.pczyx, 'tczyx')
else:
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
......@@ -151,6 +251,14 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
assert len(obmaps) == 1, 'ilastik generated more than one object map'
if isinstance(input_img, PatchStack):
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
return PatchStack(data=pyxcz), {'success': True}
else:
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
......@@ -172,48 +280,40 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
"""
if not img.shape == pxmap.shape:
raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
# TODO: check that pxmap is in-range
pxch = kwargs.get('pixel_classification_channel', 0)
pxtr = kwargs('pixel_classification_threshold', 0.5)
mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
# super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
pxtr = kwargs.get('pixel_classification_threshold', 0.5)
mask = img._derived_accessor(pxmap.get_one_channel_data(pxch).data > pxtr)
obmap, _ = self.infer(img, mask)
return obmap
class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
def make_instance_segmentation_model(self, px_ch: int):
"""
Wrap ilastik object classification for inputs comprising single-object series of raw images and binary
segmentation masks.
Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a
second input.
:param px_ch: channel of pixel probability map to use
:return:
InstanceSegmentationModel object
"""
class _Mod(self.__class__, InstanceSegmentationModel):
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor:
if mask.dtype == 'bool':
norm_mask = 1.0 * mask.data
else:
norm_mask = mask.data / np.iinfo(mask.dtype).max
norm_mask_acc = mask._derived_accessor(norm_mask.astype('float32'))
return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch)
return _Mod(params={'project_file': self.project_file})
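A hedged usage sketch; the project file path and the two accessor objects are placeholders:

    from model_server.extensions.ilastik import models as ilm

    model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
        params={'project_file': 'path/to/classifier.ilp'}  # hypothetical project file
    )
    inst = model.make_instance_segmentation_model(px_ch=0)
    obmap = inst.label_instance_class(img_acc, mask_acc)  # img_acc, mask_acc: assumed accessors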
def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
assert segmentation_acc.is_mask()
if not input_acc.chroma == 1:
raise InvalidInputImageError('Object classifier expects only monochrome patches')
if not input_acc.nz == 1:
raise InvalidInputImageError('Object classifier expects only 2d patches')
tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
dsi = [
{
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
}
]
obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
assert len(obmaps) == 1, 'ilastik generated more than one object map'
class Error(Exception):
pass
# for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C
assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1)
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
class IlastikInputEmbedding(Error):
pass
return PatchStack(data=pyxcz), {'success': True}
\ No newline at end of file
class IlastikInputShapeError(Error):
"""Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape"""
pass
\ No newline at end of file
......@@ -25,14 +25,10 @@ def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplica
if not duplicate:
existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
if existing_model_id is not None:
session.log_info(f'An ilastik model from {project_file} already exists; did not load a duplicate')
return {'model_id': existing_model_id}
try:
result = session.load_model(model_class, {'project_file': project_file})
except (FileNotFoundError, ParameterExpectedError):
raise HTTPException(
status_code=404,
detail=f'Could not load project file {project_file}',
)
session.log_info(f'Loaded ilastik model {result} from {project_file}')
return {'model_id': result}
@router.put('/seg/load/')
......@@ -60,6 +56,7 @@ def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: st
channel=channel,
mip=mip,
)
session.log_info(f'Completed pixel and object classification of {input_filename}')
except AssertionError:
raise HTTPException(status_code=409, detail=f'Incompatible models {px_model_id} and/or {ob_model_id}')
return record
\ No newline at end of file
......@@ -5,12 +5,16 @@ import unittest
import numpy as np
from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data
from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
from model_server.extensions.ilastik import models as ilm
from model_server.base.models import InvalidObjectLabelsError
from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
from model_server.base.workflows import classify_pixels
from tests.test_api import TestServerBaseClass
def _random_int(*args):
return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
class TestIlastikPixelClassification(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path'])
......@@ -83,6 +87,66 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.mono_image = mono_image
self.mask = mask
def test_pixel_classifier_enforces_input_shape(self):
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
)
self.assertEqual(model.model_chroma, 1)
self.assertEqual(model.model_3d, False)
# correct data
self.assertIsInstance(
model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 1, 1)
)
),
InMemoryDataAccessor
)
# raise exception with input of multiple channels
with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 3, 1)
)
)
# raise exception with input of multiple z-slices
with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 1, 15)
)
)
def test_ilastik_infer_pxmap_from_patchstack(self):
def _r(h):
return np.random.randint(0, 2 ** 8, size=(h, 512, 1, 1), dtype='uint8')
acc = PatchStack([_r(256), _r(512), _r(256)])
self.assertEqual(acc.hw, (512, 512))
self.assertEqual(acc.iat(0, crop=True).hw, (256, 512))
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
)
mask = model.label_patch_stack(acc)
self.assertEqual(mask.dtype, bool)
self.assertEqual(mask.chroma, 1)
self.assertEqual(mask.hw, acc.hw)
self.assertEqual(mask.nz, acc.nz)
self.assertEqual(mask.count, acc.count)
pxmap, _ = model.infer_patch_stack(acc)
self.assertEqual(pxmap.dtype, float)
self.assertEqual(pxmap.chroma, len(model.labels))
self.assertEqual(pxmap.hw, acc.hw)
self.assertEqual(pxmap.nz, acc.nz)
self.assertEqual(pxmap.count, acc.count)
def test_run_object_classifier_from_pixel_predictions(self):
self.test_run_pixel_classifier()
fp = czifile['path']
......@@ -97,7 +161,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
objmap,
)
)
self.assertEqual(objmap.data.max(), 3)
self.assertEqual(objmap.data.max(), 2)
def test_run_object_classifier_from_segmentation(self):
self.test_run_pixel_classifier()
......@@ -113,7 +178,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
objmap,
)
)
self.assertEqual(objmap.data.max(), 3)
self.assertEqual(objmap.data.max(), 2)
def test_ilastik_pixel_classification_as_workflow(self):
result = classify_pixels(
......@@ -132,15 +197,15 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_httpexception_if_incorrect_project_file_loaded(self):
resp_load = self._put(
'ilastik/seg/load/',
{'project_file': 'improper.ilp'},
query={'project_file': 'improper.ilp'},
)
self.assertEqual(resp_load.status_code, 404)
self.assertEqual(resp_load.status_code, 500)
def test_load_ilastik_pixel_model(self):
resp_load = self._put(
'ilastik/seg/load/',
{'project_file': str(ilastik_classifiers['px'])},
query={'project_file': str(ilastik_classifiers['px'])},
)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
model_id = resp_load.json()['model_id']
......@@ -156,19 +221,19 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
resp_load_2nd = self._put(
'ilastik/seg/load/',
{'project_file': str(ilastik_classifiers['px']), 'duplicate': True, },
query={'project_file': str(ilastik_classifiers['px']), 'duplicate': True, },
)
resp_list_2nd = self._get('models').json()
self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
resp_load_3rd = self._put(
'ilastik/seg/load/',
{'project_file': str(ilastik_classifiers['px']), 'duplicate': False},
query={'project_file': str(ilastik_classifiers['px']), 'duplicate': False},
)
resp_list_3rd = self._get('models').json()
self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
def test_no_duplicate_model_with_different_path_formats(self):
self._get('restart')
self._get('session/restart')
resp_list_1 = self._get('models').json()
self.assertEqual(len(resp_list_1), 0)
ilp = ilastik_classifiers['px']
......@@ -185,11 +250,11 @@ class TestIlastikOverApi(TestServerBaseClass):
# load models with these paths
resp1 = self._put(
'ilastik/seg/load/',
{'project_file': ilp_win, 'duplicate': False },
query={'project_file': ilp_win, 'duplicate': False },
)
resp2 = self._put(
'ilastik/seg/load/',
{'project_file': ilp_posx, 'duplicate': False},
query={'project_file': ilp_posx, 'duplicate': False},
)
self.assertEqual(resp1.json(), resp2.json())
......@@ -202,7 +267,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_pxmap_to_obj_model(self):
resp_load = self._put(
'ilastik/pxmap_to_obj/load/',
{'project_file': str(ilastik_classifiers['pxmap_to_obj'])},
query={'project_file': str(ilastik_classifiers['pxmap_to_obj'])},
)
model_id = resp_load.json()['model_id']
......@@ -216,7 +281,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_seg_to_obj_model(self):
resp_load = self._put(
'ilastik/seg_to_obj/load/',
{'project_file': str(ilastik_classifiers['seg_to_obj'])},
query={'project_file': str(ilastik_classifiers['seg_to_obj'])},
)
model_id = resp_load.json()['model_id']
......@@ -233,10 +298,11 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_infer = self._put(
f'workflows/segment',
{'model_id': model_id, 'input_filename': czifile['filename'], 'channel': 0},
query={'model_id': model_id, 'input_filename': czifile['filename'], 'channel': 0},
)
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
def test_ilastik_infer_px_then_ob(self):
self.copy_input_file_to_server()
px_model_id = self.test_load_ilastik_pixel_model()
......@@ -244,7 +310,7 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_infer = self._put(
'ilastik/pixel_then_object_classification/infer/',
{
query={
'px_model_id': px_model_id,
'ob_model_id': ob_model_id,
'input_filename': czifile['filename'],
......@@ -253,6 +319,7 @@ class TestIlastikOverApi(TestServerBaseClass):
)
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
class TestIlastikObjectClassification(unittest.TestCase):
def setUp(self):
stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
......@@ -269,16 +336,36 @@ class TestIlastikObjectClassification(unittest.TestCase):
)
)
self.object_classifier = ilm.PatchStackObjectClassifier(
self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
params={'project_file': ilastik_classifiers['seg_to_obj']}
)
self.raw = self.roiset.get_patches_acc()
self.masks = self.roiset.get_patch_masks_acc()
def test_classify_patches(self):
raw_patches = self.roiset.get_raw_patches()
patch_masks = self.roiset.get_patch_masks()
res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks)
self.assertEqual(res_patches.count, self.roiset.count)
for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch
unique = np.unique(res_patches.iat(pi).data)
self.assertEqual(len(unique), 2)
self.assertEqual(unique[0], 0)
res = self.classifier.label_patch_stack(self.raw, self.masks)
self.assertEqual(res.count, self.roiset.count)
res.export_pyxcz(output_path / 'res_patches.tif')
for pi in range(0, res.count): # assert that there is only one nonzero label per patch
la, ct = np.unique(res.iat(pi).data, return_counts=True)
self.assertEqual(np.sum(ct > 1), 2) # exclude single-pixel anomaly
self.assertEqual(la[0], 0)
def test_multiple_objects_in_patch(self):
# allow multiple labels in a classified patch
res1 = self.classifier.label_patch_stack(self.raw, self.masks, allow_multiple=True)
la1, cts1 = np.unique(res1.iat(2).data, return_counts=True)
self.assertGreater(len(la1[la1 > 0]), 1)
self.assertEqual(la1[0], 0)
self.assertTrue(np.all(cts1[1] > cts1[2:]))
# raise exception if there are multiple labels in any patch
with self.assertRaises(InvalidObjectLabelsError):
res2 = self.classifier.label_patch_stack(self.raw, self.masks, allow_multiple=False, force_single=False)
# convert all nonzero pixels to the label with highest occurrence
res3 = self.classifier.label_patch_stack(self.raw, self.masks, allow_multiple=False, force_single=True)
la3, cts3 = np.unique(res3.iat(2).data, return_counts=True)
self.assertEqual(len(la3[la3 > 0]), 1)
self.assertEqual(la3[1], la1[1])
\ No newline at end of file
......@@ -27,6 +27,11 @@ def parse_args():
action='store_true',
help='display extra information that is helpful for debugging'
)
parser.add_argument(
'--reload',
action='store_true',
help='automatically restart server when changes are noticed, for development purposes'
)
return parser.parse_args()
......@@ -41,8 +46,9 @@ def main(args, app_name='model_server.base.api:app') -> None:
'host': args.host,
'port': int(args.port),
'log_level': 'debug',
'reload': args.reload,
},
daemon=True,
daemon=(args.reload is False),
)
url = f'http://{args.host}:{int(args.port):04d}/status'
print(url)
......
......@@ -7,6 +7,9 @@ from model_server.base.accessors import PatchStack, make_patch_stack_from_file,
from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask
from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
def _random_int(*args):
return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
......@@ -23,6 +26,7 @@ class TestCziImageFileAccess(unittest.TestCase):
self.assertEqual(len(tf.data.shape), 4)
self.assertEqual(tf.shape[0], tifffile['h'])
self.assertEqual(tf.shape[1], tifffile['w'])
self.assertEqual(tf.get_axis('x'), 1)
def test_czifile_is_correct_shape(self):
cf = CziImageFileAccessor(czifile['path'])
......@@ -40,7 +44,7 @@ class TestCziImageFileAccess(unittest.TestCase):
nc = 4
nz = 11
c = 3
cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz))
cf = InMemoryDataAccessor(_random_int(h, w, nc, nz))
sc = cf.get_one_channel_data(c)
self.assertEqual(sc.shape, (h, w, 1, nz))
......@@ -70,7 +74,7 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_conform_data_shorter_than_xycz(self):
h = 256
w = 512
data = np.random.rand(h, w, 1)
data = _random_int(h, w, 1)
acc = InMemoryDataAccessor(data)
self.assertEqual(
InMemoryDataAccessor.conform_data(data).shape,
......@@ -82,7 +86,7 @@ class TestCziImageFileAccess(unittest.TestCase):
)
def test_conform_data_longer_than_xycz(self):
data = np.random.rand(256, 512, 12, 8, 3)
data = _random_int(256, 512, 12, 8, 3)
with self.assertRaises(DataShapeError):
acc = InMemoryDataAccessor(data)
......@@ -93,7 +97,7 @@ class TestCziImageFileAccess(unittest.TestCase):
c = 3
nz = 10
yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8')
yxcz = _random_int(h, w, c, nz)
acc = InMemoryDataAccessor(yxcz)
fp = output_path / f'rand3d.tif'
self.assertTrue(
......@@ -138,16 +142,18 @@ class TestPatchStackAccessor(unittest.TestCase):
w = 256
h = 512
n = 4
acc = PatchStack(np.random.rand(n, h, w, 1, 1))
acc = PatchStack(_random_int(n, h, w, 1, 1))
self.assertEqual(acc.count, n)
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))
self.assertEqual(acc.shape[1:], acc.iat(0, crop=True).shape)
def test_make_patch_stack_from_list(self):
w = 256
h = 512
n = 4
acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)])
acc = PatchStack([_random_int(h, w, 1, 1) for _ in range(0, n)])
self.assertEqual(acc.count, n)
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))
......@@ -176,8 +182,8 @@ class TestPatchStackAccessor(unittest.TestCase):
nz = 5
n = 4
patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)]
patches.append(np.random.rand(h, 2 * w, c, nz))
patches = [_random_int(h, w, c, nz) for _ in range(0, n)]
patches.append(_random_int(h, 2 * w, c, nz))
acc = PatchStack(patches)
self.assertEqual(acc.count, n + 1)
self.assertEqual(acc.hw, (h, 2 * w))
......@@ -185,13 +191,70 @@ class TestPatchStackAccessor(unittest.TestCase):
self.assertEqual(acc.iat(0).shape, (h, 2 * w, c, nz))
self.assertEqual(acc.iat_yxcz(0).shape, (h, 2 * w, c, nz))
# test that initial patches are maintained
for i in range(0, acc.count):
self.assertEqual(patches[i].shape, acc.iat(i, crop=True).shape)
self.assertEqual(acc.shape[1:], acc.iat(i, crop=False).shape)
def test_make_3d_patch_stack_from_list_force_long_dim(self):
def _r(h, w):
return np.random.randint(0, 2 ** 8, size=(h, w, 1, 1), dtype='uint8')
patches = [_r(256, 128), _r(128, 256), _r(512, 10), _r(10, 512)]
acc_ref = PatchStack(patches, force_ydim_longest=False)
self.assertEqual(acc_ref.hw, (512, 512))
self.assertEqual(acc_ref.iat(-1, crop=False).hw, (512, 512))
self.assertEqual(acc_ref.iat(-1, crop=True).hw, (10, 512))
acc_rot = PatchStack(patches, force_ydim_longest=True)
self.assertEqual(acc_rot.hw, (512, 128))
self.assertEqual(acc_rot.iat(-1, crop=False).hw, (512, 128))
self.assertEqual(acc_rot.iat(-1, crop=True).hw, (512, 10))
nda_rot_rot = np.rot90(acc_rot.iat(-1, crop=True).data, axes=(1, 0))
nda_ref = acc_ref.iat(-1, crop=True).data
self.assertTrue(np.all(nda_ref == nda_rot_rot))
self.assertLess(acc_rot.data.size, acc_ref.data.size)
def test_pczyx(self):
w = 256
h = 512
n = 4
nz = 15
nc = 2
acc = PatchStack(np.random.rand(n, h, w, nc, nz))
nc = 3
acc = PatchStack(_random_int(n, h, w, nc, nz))
self.assertEqual(acc.count, n)
self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w))
self.assertEqual(acc.hw, (h, w))
return acc
def test_get_one_channel(self):
acc = self.test_pczyx()
mono = acc.get_one_channel_data(channel=1)
for a in 'PXYZ':
self.assertEqual(mono.shape_dict[a], acc.shape_dict[a])
self.assertEqual(mono.shape_dict['C'], 1)
def test_get_multiple_channels(self):
acc = self.test_pczyx()
channels = [0, 1]
mcacc = acc.get_channels(channels=channels)
for a in 'PXYZ':
self.assertEqual(mcacc.shape_dict[a], acc.shape_dict[a])
self.assertEqual(mcacc.shape_dict['C'], len(channels))
def test_get_one_channel_mip(self):
acc = self.test_pczyx()
mono_mip = acc.get_one_channel_data(channel=1, mip=True)
for a in 'PXY':
self.assertEqual(mono_mip.shape_dict[a], acc.shape_dict[a])
for a in 'CZ':
self.assertEqual(mono_mip.shape_dict[a], 1)
def test_export_pczyx_patch_hyperstack(self):
acc = self.test_pczyx()
fp = output_path / 'patch_hyperstack.tif'
acc.export_pyxcz(fp)
acc2 = make_patch_stack_from_file(fp)
self.assertEqual(acc.shape, acc2.shape)
\ No newline at end of file
import json
from multiprocessing import Process
from pathlib import Path
import requests
......@@ -36,8 +38,12 @@ class TestServerBaseClass(unittest.TestCase):
def _get(self, endpoint):
return self._get_sesh().get(self.uri + endpoint)
def _put(self, endpoint, params=None):
return self._get_sesh().put(self.uri + endpoint, params=params)
def _put(self, endpoint, query=None, body=None):
return self._get_sesh().put(
self.uri + endpoint,
params=query,
data=json.dumps(body)
)
def copy_input_file_to_server(self):
from shutil import copyfile
......@@ -59,10 +65,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp.status_code, 200)
def test_bounceback_parameters(self):
resp = self._put('bounce_back', {'par1': 'hello'})
resp = self._put('bounce_back', body={'par1': 'hello', 'par2': ['ab', 'cd']})
self.assertEqual(resp.status_code, 200, resp.json())
self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json())
self.assertEqual(resp.json()['params']['par2'], None, resp.json())
self.assertEqual(resp.json()['params']['par2'], ['ab', 'cd'], resp.json())
def test_default_session_paths(self):
import model_server.conf.defaults
......@@ -103,7 +109,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
resp = self._put(
f'infer/from_image_file',
{'model_id': model_id, 'input_filename': 'not_a_real_file.name'}
query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'}
)
self.assertEqual(resp.status_code, 404, resp.content.decode())
......@@ -112,7 +118,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
model_id = 'not_a_real_model'
resp = self._put(
f'workflows/segment',
{'model_id': model_id, 'input_filename': 'not_a_real_file.name'}
query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'}
)
self.assertEqual(resp.status_code, 409, resp.content.decode())
......@@ -121,7 +127,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.copy_input_file_to_server()
resp_infer = self._put(
f'workflows/segment',
{
query={
'model_id': model_id,
'input_filename': czifile['filename'],
'channel': 2,
......@@ -136,7 +142,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp_list_0.status_code, 200)
rj0 = resp_list_0.json()
self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}')
resp_restart = self._get('restart')
resp_restart = self._get('session/restart')
resp_list_1 = self._get('models')
rj1 = resp_list_1.json()
self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')
......@@ -145,7 +151,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
resp_inpath = self._get('paths')
resp_change = self._put(
f'paths/watch_output',
{'path': resp_inpath.json()['inbound_images']}
query={'path': resp_inpath.json()['inbound_images']}
)
self.assertEqual(resp_change.status_code, 200)
resp_check = self._get('paths')
......@@ -156,7 +162,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
fakepath = 'c:/fake/path/to/nowhere'
resp_change = self._put(
f'paths/watch_output',
{'path': fakepath}
query={'path': fakepath}
)
self.assertEqual(resp_change.status_code, 404)
self.assertIn(fakepath, resp_change.json()['detail'])
......@@ -167,8 +173,13 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
resp_inpath = self._get('paths')
resp_change = self._put(
f'paths/watch_output',
{'path': resp_inpath.json()['outbound_images']}
query={'path': resp_inpath.json()['outbound_images']}
)
self.assertEqual(resp_change.status_code, 200)
resp_check = self._get('paths')
self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images'])
def test_get_logs(self):
resp = self._get('session/logs')
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.json()[0]['message'], 'Initialized session')
\ No newline at end of file
import unittest
import numpy as np
from skimage.measure import find_contours
from model_server.base.process import mask_largest_object
from model_server.base.process import pad
from model_server.base.annotators import draw_contours_on_patch
from model_server.base.process import get_safe_contours, mask_largest_object, pad
class TestProcessingUtilityMethods(unittest.TestCase):
def setUp(self) -> None:
......@@ -56,3 +57,26 @@ class TestMaskLargestObject(unittest.TestCase):
self.assertTrue(np.all(np.unique(masked) == [0, 255]))
self.assertTrue(np.all(masked[:, 3:5] == 0))
self.assertTrue(np.all(masked[3:5, :] == 0))
class TestSafeContours(unittest.TestCase):
def setUp(self) -> None:
self.patch = np.ones((10, 20), dtype='uint8')
self.mask_ref = np.zeros((10, 20), dtype=bool)
self.mask_ref[0:5, 0:10] = True
self.mask_test = np.ones((1, 20), dtype=bool)
def test_contours_on_compliant_mask(self):
con = get_safe_contours(self.mask_ref)
patch = self.patch.copy()
self.assertEqual((patch == 0).sum(), 0)
patch = draw_contours_on_patch(patch, con)
self.assertEqual((patch == 0).sum(), 14)
def test_contours_on_noncompliant_mask(self):
con = get_safe_contours(self.mask_test)
patch = self.patch.copy()
self.assertEqual((patch == 0).sum(), 0)
patch = draw_contours_on_patch(self.patch, con)
self.assertEqual((patch == 0).sum(), 20)
self.assertEqual((patch[0, :] == 0).sum(), 20)
\ No newline at end of file
This diff is collapsed.
import json
from os.path import exists
import pathlib
import unittest
from model_server.base.models import DummySemanticSegmentationModel
from model_server.base.session import Session
from model_server.base.workflows import WorkflowRunRecord
class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None:
pass
self.sesh = Session()
def test_single_session_instance(self):
sesh = Session()
self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')
def tearDown(self) -> None:
print('Tearing down...')
Session._instances = {}
from os.path import exists
self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place')
self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
def test_session_is_singleton(self):
Session._instances = {}
self.assertEqual(len(Session._instances), 0)
s = Session()
self.assertEqual(len(Session._instances), 1)
self.assertIs(s, Session())
self.assertEqual(len(Session._instances), 1)
def test_session_logfile_is_valid(self):
self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
def test_changing_session_root_creates_new_directory(self):
from model_server.conf.defaults import root
from shutil import rmtree
sesh = Session()
old_paths = sesh.get_paths()
old_paths = self.sesh.get_paths()
newroot = root / 'subdir'
sesh.restart(root=newroot)
new_paths = sesh.get_paths()
self.sesh.restart(root=newroot)
new_paths = self.sesh.get_paths()
for k in old_paths.keys():
self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
# this is necessary because logger itself is a singleton class
self.tearDown()
self.setUp()
rmtree(newroot)
self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
def test_change_session_subdirectory(self):
sesh = Session()
old_paths = sesh.get_paths()
old_paths = self.sesh.get_paths()
print(old_paths)
sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images'])
def test_restart_session(self):
sesh = Session()
logfile1 = sesh.session_log
sesh.restart()
logfile2 = sesh.session_log
self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
def test_restarting_session_creates_new_logfile(self):
logfile1 = self.sesh.logfile
self.assertTrue(logfile1.exists())
self.sesh.restart()
logfile2 = self.sesh.logfile
self.assertTrue(logfile2.exists())
self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
def test_session_records_workflow(self):
import json
from model_server.base.workflows import WorkflowRunRecord
sesh = Session()
di = WorkflowRunRecord(
model_id='test_model',
input_filepath='/test/input/directory',
output_filepath='/test/output/fi.le',
success=True,
timer_results={'start': 0.123},
)
sesh.record_workflow_run(di)
with open(sesh.manifest_json, 'r') as fh:
do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_log_warning(self):
msg = 'A test warning'
self.sesh.log_info(msg)
with open(self.sesh.logfile, 'r') as fh:
log = fh.read()
self.assertTrue(msg in log)
def test_get_logs(self):
self.sesh.log_info('Info example 1')
self.sesh.log_warning('Example warning')
self.sesh.log_info('Info example 2')
logs = self.sesh.get_log_data()
self.assertEqual(len(logs), 4)
self.assertEqual(logs[1]['level'], 'WARNING')
self.assertEqual(logs[-1]['message'], 'Initialized session')
def test_session_loads_model(self):
sesh = Session()
MC = DummySemanticSegmentationModel
success = sesh.load_model(MC)
success = self.sesh.load_model(MC)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
loaded_models = self.sesh.describe_loaded_models()
self.assertTrue(
(MC.__name__ + '_00') in loaded_models.keys()
)
......@@ -76,46 +87,56 @@ class TestGetSessionObject(unittest.TestCase):
)
def test_session_loads_second_instance_of_same_model(self):
sesh = Session()
MC = DummySemanticSegmentationModel
sesh.load_model(MC)
sesh.load_model(MC)
self.assertIn(MC.__name__ + '_00', sesh.models.keys())
self.assertIn(MC.__name__ + '_01', sesh.models.keys())
self.sesh.load_model(MC)
self.sesh.load_model(MC)
self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())
def test_session_loads_model_with_params(self):
sesh = Session()
MC = DummySemanticSegmentationModel
p1 = {'p1': 'abc'}
success = sesh.load_model(MC, params=p1)
success = self.sesh.load_model(MC, params=p1)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
loaded_models = self.sesh.describe_loaded_models()
mid = MC.__name__ + '_00'
self.assertEqual(loaded_models[mid]['params'], p1)
# load a second model and confirm that the first is locatable by its param entry
p2 = {'p2': 'def'}
sesh.load_model(MC, params=p2)
find_mid = sesh.find_param_in_loaded_models('p1', 'abc')
self.sesh.load_model(MC, params=p2)
find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc')
self.assertEqual(mid, find_mid)
self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1)
self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)
def test_session_finds_existing_model_with_different_path_formats(self):
sesh = Session()
MC = DummySemanticSegmentationModel
p1 = {'path': 'c:\\windows\\dummy.pa'}
p2 = {'path': 'c:/windows/dummy.pa'}
mid = sesh.load_model(MC, params=p1)
mid = self.sesh.load_model(MC, params=p1)
assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
self.assertEqual(mid, find_mid)
def test_change_output_path(self):
import pathlib
sesh = Session()
pa = sesh.get_paths()['inbound_images']
pa = self.sesh.get_paths()['inbound_images']
self.assertIsInstance(pa, pathlib.Path)
sesh.set_data_directory('outbound_images', pa.__str__())
self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images'])
self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path)
\ No newline at end of file
self.sesh.set_data_directory('outbound_images', pa.__str__())
self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)
def test_make_table(self):
import pandas as pd
data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)]
self.sesh.write_to_table(
'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
)
self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
self.sesh.write_to_table(
'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
)
dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
self.assertEqual(len(dfv), len(data))
self.assertEqual(dfv.columns[0], 'X')
self.assertEqual(dfv.columns[1], 'Y')