diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 07eba484b24c7b64b729ee1bb71de708ecf1f934..12807bf2a49a1ce0d5a49710b520d02b3936a8db 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -162,6 +162,14 @@ class CziImageFileAccessor(GenericImageFileAccessor): except Exception: raise FileAccessorError(f'Unable to access CZI data in {fpath}') + try: + md = cf.metadata(raw=False) + compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod'] + except KeyError: + raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod') + if compmet.upper() != 'UNCOMPRESSED': + raise InvalidCziCompression(f'Unsupported compression method {compmet}') + sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes} if (sd.get('S') and (sd['S'] > 1)) or (sd.get('T') and (sd['T'] > 1)): raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}') @@ -284,6 +292,29 @@ class PatchStack(InMemoryDataAccessor): def count(self): return self.shape_dict['P'] + def export_pyxcz(self, fpath: Path): + tzcyx = np.moveaxis( + self.pyxcz, # yxcz + [0, 4, 3, 1, 2], + [0, 1, 2, 3, 4] + ) + + if self.is_mask(): + if self.dtype == 'bool': + data = (tzcyx * 255).astype('uint8') + else: + data = tzcyx.astype('uint8') + tifffile.imwrite(fpath, data, imagej=True) + else: + tifffile.imwrite(fpath, tzcyx, imagej=True) + + def get_one_channel_data(self, channel: int, mip: bool = False): + c = int(channel) + if mip: + return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :].max(axis=-1, keepdims=True)) + else: + return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :]) + @property def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) @@ -313,16 +344,32 @@ class PatchStack(InMemoryDataAccessor): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) -def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position +def make_patch_stack_from_file(fpath): # interpret t-dimension as patch position if not Path(fpath).exists(): raise FileNotFoundError(f'Could not find {fpath}') - pyxc = np.moveaxis( - generate_file_accessor(fpath).data, # yxcz - [0, 1, 2, 3], - [1, 2, 3, 0] + try: + tf = tifffile.TiffFile(fpath) + except Exception: + raise FileAccessorError(f'Unable to access data in {fpath}') + + if len(tf.series) != 1: + raise DataShapeError(f'Expect only one series in {fpath}') + + se = tf.series[0] + + axs = [a for a in se.axes if a in [*'TZCYX']] + sd = dict(zip(axs, se.shape)) + for a in [*'TZC']: + if a not in axs: + sd[a] = 1 + tzcyx = se.asarray().reshape([sd[k] for k in [*'TZCYX']]) + + pyxcz = np.moveaxis( + tzcyx, + [0, 3, 4, 2, 1], + [0, 1, 2, 3, 4], ) - pyxcz = np.expand_dims(pyxc, axis=3) return PatchStack(pyxcz) @@ -345,6 +392,9 @@ class FileWriteError(Error): class InvalidAxisKey(Error): pass +class InvalidCziCompression(Error): + pass + class InvalidDataShape(Error): pass diff --git a/model_server/base/api.py b/model_server/base/api.py index 8259a98c8f74a5ca25b6b2b3c6e4d8edc95fa547..752ed48a0dcc41f62490f6e05b650348410f6231 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -46,6 +46,7 @@ def change_path(key, path): status_code=404, detail=e.__str__(), ) + session.log_info(f'Change {key} path to {path}') return session.get_paths() @app.put('/paths/watch_input') @@ -56,22 +57,30 @@ def watch_input_path(path: str): def watch_input_path(path: str): return change_path('outbound_images', path) -@app.get('/restart') 
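# --- Reviewer note (annotation, not part of the patch): the two axis
# permutations added in accessors.py above are inverses. export_pyxcz maps
# PYXCZ -> TZCYX (patch position P is written as ImageJ's T axis), and
# make_patch_stack_from_file maps TZCYX back to PYXCZ. A minimal sketch with
# arbitrary distinct sizes, runnable on its own:
#
#     import numpy as np
#     p, y, x, c, z = 2, 8, 6, 3, 4
#     pyxcz = np.random.rand(p, y, x, c, z)
#     tzcyx = np.moveaxis(pyxcz, [0, 4, 3, 1, 2], [0, 1, 2, 3, 4])
#     assert tzcyx.shape == (p, z, c, y, x)  # ImageJ hyperstack order
#     back = np.moveaxis(tzcyx, [0, 3, 4, 2, 1], [0, 1, 2, 3, 4])
#     assert np.array_equal(back, pyxcz)
# ---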
+@app.get('/session/restart')
 def restart_session(root: str = None) -> dict:
     session.restart(root=root)
     return session.describe_loaded_models()
 
+@app.get('/session/logs')
+def list_session_log() -> list:
+    return session.get_log_data()
+
 @app.get('/models')
 def list_active_models():
     return session.describe_loaded_models()
 
 @app.put('/models/dummy_semantic/load/')
 def load_dummy_model() -> dict:
-    return {'model_id': session.load_model(DummySemanticSegmentationModel)}
+    mid = session.load_model(DummySemanticSegmentationModel)
+    session.log_info(f'Loaded model {mid}')
+    return {'model_id': mid}
 
 @app.put('/models/dummy_instance/load/')
 def load_dummy_model() -> dict:
-    return {'model_id': session.load_model(DummyInstanceSegmentationModel)}
+    mid = session.load_model(DummyInstanceSegmentationModel)
+    session.log_info(f'Loaded model {mid}')
+    return {'model_id': mid}
 
 @app.put('/workflows/segment')
 def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
@@ -83,5 +92,5 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
         session.paths['outbound_images'],
         channel=channel,
     )
-    session.record_workflow_run(record)
+    session.log_info(f'Completed segmentation of {input_filename}')
     return record
\ No newline at end of file
diff --git a/model_server/base/models.py b/model_server/base/models.py
index 8413f68d4420fe31add5e25818b4198bdc1a76eb..4fa0849031bcd318a66f49c8843763b0a99c2279 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -3,7 +3,7 @@ from math import floor
 
 import numpy as np
 
-from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
+from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack
 
 class Model(ABC):
@@ -88,6 +88,16 @@ class InstanceSegmentationModel(ImageToImageModel):
         if not img.shape == mask.shape:
             raise InvalidInputImageError('Expect input image and mask to be the same shape')
 
+    def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):
+        """
+        Iterate over a patch stack, calling inference on each patch
+        :return: PatchStack of the same shape as the input
+        """
+        res_data = np.zeros(img.shape, dtype='uint16')
+        for i in range(0, img.count):  # interpret as PYXCZ
+            res_data[i, :, :, :, :] = self.label_instance_class(img.iat(i), mask.iat(i), **kwargs).data
+        return PatchStack(res_data)
+
 
 class DummySemanticSegmentationModel(SemanticSegmentationModel):
 
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index c44023db6622237b5fe40dd82b90d2f161187517..d537e92af11f458868acf14bc7c006774b4b58b8 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1,5 +1,6 @@
 from math import sqrt, floor
 from pathlib import Path
+import re
 from typing import List, Union
 from uuid import uuid4
 
@@ -17,7 +18,7 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc
 from model_server.base.models import InstanceSegmentationModel
 from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb
 from model_server.base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
-from model_server.base.accessors import PatchStack
+from model_server.base.accessors import generate_file_accessor, PatchStack
 from model_server.base.process import mask_largest_object
 
@@ -44,7 +45,6 @@ class RoiFilterRange(BaseModel):
 
 class RoiFilter(BaseModel):
     area: Union[RoiFilterRange, None] = None
-    solidity: Union[RoiFilterRange, None] = None
 
 class
RoiSetMetaParams(BaseModel): @@ -57,16 +57,39 @@ class RoiSetExportParams(BaseModel): patches_3d: Union[PatchParams, None] = None annotated_patches_2d: Union[PatchParams, None] = None patches_2d: Union[PatchParams, None] = None - patch_masks: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False - dataframe: bool = False - -def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: - return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')) +def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: + """ + Convert binary segmentation mask into either a 2D or 3D object identities map + :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions + :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False + :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False + :return: object identities map + """ + if allow_3d and connect_3d: + nda_la = label( + acc_seg_mask.data[:, :, 0, :] + ).astype('uint16') + return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) + elif allow_3d and not connect_3d: + nla = 0 + la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') + for zi in range(0, acc_seg_mask.nz): + la_2d = label(acc_seg_mask.data[:, :, 0, zi]).astype('uint16') + la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla + nla = la_2d.max() + la_3d[:, :, 0, zi] = la_2d + return InMemoryDataAccessor(la_3d) + else: + return InMemoryDataAccessor( + label( + acc_seg_mask.data[:, :, 0, :].max(axis=-1) + ).astype('uint16') + ) def _focus_metrics(): @@ -109,7 +132,6 @@ class RoiSet(object): :param params: optional arguments that influence the definition and representation of ROIs """ assert acc_obj_ids.chroma == 1 - assert acc_obj_ids.nz == 1 self.acc_obj_ids = acc_obj_ids self.acc_raw = acc_raw self.params = params @@ -135,26 +157,32 @@ class RoiSet(object): :param acc_raw: accessor to raw image data :param acc_obj_ids: accessor to map of object IDs :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) + # :param deproject: assign object's z-position based on argmax of raw data if True :return: pd.DataFrame """ # build dataframe of objects, assign z index to each object - argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') - df = ( - pd.DataFrame( - regionprops_table( - acc_obj_ids.data[:, :, 0, 0], - intensity_image=argmax, - properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid') - ) - ) - .rename( - columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', } - ) - ) - df['zi'] = df['intensity_mean'].round().astype('int') + + if acc_obj_ids.nz == 1: # deproject objects' z-coordinates from argmax of raw image + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data[:, :, 0, 0], + intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'), + properties=('label', 'area', 'intensity_mean', 'bbox', 'centroid') + )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) + df['zi'] = df['intensity_mean'].round().astype('int') + + else: # objects' z-coordinates come from arg of max count in object identities map + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data[:, :, 0, :], + properties=('label', 'area', 'bbox', 
'centroid') + )).rename(columns={ + 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' + }) + df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax()) # compute expanded bounding boxes h, w, c, nz = acc_raw.shape + df['h'] = df['y1'] - df['y0'] + df['w'] = df['x1'] - df['x0'] ebxy, ebz = expand_box_by df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) @@ -176,6 +204,11 @@ class RoiSet(object): assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0'])) df['slice'] = df.apply( + lambda r: + np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], + axis=1 + ) + df['expanded_slice'] = df.apply( lambda r: np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], axis=1 @@ -185,19 +218,18 @@ class RoiSet(object): np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], axis=1 ) - df['mask'] = df.apply( - lambda r: (acc_obj_ids.data == r.label)[r.y0: r.y1, r.x0: r.x1, 0, 0], + df['binary_mask'] = df.apply( + lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], axis=1 ) return df - @staticmethod def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: query_str = 'label > 0' # always true if filters is not None: # parse filters for k, val in filters.dict(exclude_unset=True).items(): - assert k in ('area', 'solidity') + assert k in ('area') vmin = val['min'] vmax = val['max'] assert vmin >= 0 @@ -226,18 +258,18 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected - def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches + def get_patches_acc(self, channel=None, **kwargs): # padded, un-annotated 2d patches if channel: - patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) + patches_df = self.get_patches(white_channel=channel, **kwargs) else: - patches_df = self.get_patches(pad_to=pad_to) - patches = list(patches_df['patch']) - return PatchStack(patches) + patches_df = self.get_patches(**kwargs) + return PatchStack(list(patches_df.patch)) - def export_annotated_zstack(self, where, prefix='zstack', **kwargs): + def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> Path: annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) - success = write_accessor_data_to_file(where / (prefix + '.tif'), annotated) - return {'location': where.__str__(), 'filename': prefix + '.tif'} + fp = where / (prefix + '.tif') + write_accessor_data_to_file(fp, annotated) + return fp def get_zmask(self, mask_type='boxes'): """ @@ -271,7 +303,7 @@ class RoiSet(object): elif mask_type == 'boxes': for roi in self: - zi_st[roi.relative_slice] = 1 + zi_st[roi.slice] = True return zi_st @@ -279,9 +311,9 @@ class RoiSet(object): def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): # do this on a patch basis, i.e. 
only one object per frame
-        obmap_patches = object_classification_model.label_instance_class(
-            self.get_raw_patches(channel=channel),
-            self.get_patch_masks()
+        obmap_patches = object_classification_model.label_patch_stack(
+            self.get_patches_acc(channel=channel, expanded=False, pad_to=None),
+            self.get_patch_masks_acc(expanded=False, pad_to=None)
         )
 
         om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype)
@@ -299,20 +331,29 @@
             om[self.acc_obj_ids.data == roi.label] = oc
         self.object_class_maps[name] = InMemoryDataAccessor(om)
 
-    def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list:
-        patches_acc = self.get_patch_masks(pad_to=pad_to)
+    def export_dataframe(self, csv_path: Path):
+        csv_path.parent.mkdir(parents=True, exist_ok=True)
+        self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1).to_csv(csv_path, index=False)
+        return csv_path
+
+
+    def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> list:
+        patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded)
         exported = []
-        for i, roi in enumerate(self):  # assumes index of patches_acc is same as dataframe
-            patch = patches_acc.iat_yxcz(i)
+        def _export_patch_mask(roi):
+            patch = InMemoryDataAccessor(roi.patch_mask)
             ext = 'png'
             fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
             write_accessor_data_to_file(where / fname, patch)
             exported.append(fname)
+
+        for roi in patches_df.itertuples():  # just used for label info
+            _export_patch_mask(roi)
         return exported
 
-    def export_patches(self, where: Path, prefix='patch', **kwargs):
+    def export_patches(self, where: Path, prefix='patch', **kwargs) -> list:
         make_3d = kwargs.get('make_3d', False)
         patches_df = self.get_patches(**kwargs)
 
@@ -327,11 +368,7 @@
             else:
                 write_accessor_data_to_file(where / fname, patch)
 
-            exported.append({
-                'df_index': roi.Index,
-                'patch_filename': fname,
-                'location': where.__str__(),
-            })
+            exported.append(where / fname)
 
         exported = []
         for roi in patches_df.itertuples():  # just used for label info
@@ -339,28 +376,36 @@
 
         return exported
 
-    def get_patch_masks(self, pad_to: int = 256) -> PatchStack:
-        patches = []
-        for roi in self:
-            patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
-            patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255
+    def get_patch_masks(self, pad_to: int = None, expanded: bool = False) -> pd.DataFrame:
+        def _make_patch_mask(roi):
+            if expanded:
+                patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
+                patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255
+            else:
+                patch = np.zeros((roi.y1 - roi.y0, roi.x1 - roi.x0, 1, 1), dtype='uint8')
+                patch[:, :, 0, 0] = roi.binary_mask * 255
             if pad_to:
                 patch = pad(patch, pad_to)
+            return patch
 
-            patches.append(patch)
-        return PatchStack(patches)
+        dfe = self._df.copy()
+        dfe['patch_mask'] = dfe.apply(lambda r: _make_patch_mask(r), axis=1)
+        return dfe
 
+    def get_patch_masks_acc(self, **kwargs) -> PatchStack:
+        return PatchStack(list(self.get_patch_masks(**kwargs).patch_mask))
 
     def get_patches(
             self,
             rescale_clip: float = 0.0,
-            pad_to: int = 256,
+            pad_to: int = None,
             make_3d: bool = False,
             focus_metric: str = None,
             rgb_overlay_channels: list = None,
             rgb_overlay_weights: list = [1.0, 1.0, 1.0],
             white_channel: int = None,
+            expanded=False,
             **kwargs
     ) -> pd.DataFrame:
 
@@ -407,12 +452,17 @@
 
         stack = raw.data
 
         def _make_patch(roi):
-            patch3d = stack[roi.slice]
+            if expanded:
+
patch3d = stack[roi.expanded_slice] + subpatch = patch3d[roi.relative_slice] + else: + patch3d = stack[roi.slice] + subpatch = patch3d + ph, pw, pc, pz = patch3d.shape - subpatch = patch3d[roi.relative_slice] # make a 3d patch - if make_3d: + if make_3d or not expanded: patch = patch3d # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately @@ -433,10 +483,16 @@ class RoiSet(object): assert len(patch.shape) == 4 + mask = np.zeros(patch3d.shape[0:2], dtype=bool) + if expanded: + mask[roi.relative_slice[0:2]] = roi.binary_mask + else: + mask = roi.binary_mask + if rescale_clip is not None: patch = rescale(patch, rescale_clip) - if kwargs.get('draw_bounding_box') is True: + if kwargs.get('draw_bounding_box') is True and expanded: bci = kwargs.get('bounding_box_channel', 0) assert bci < 3 if bci > 0: @@ -450,15 +506,15 @@ class RoiSet(object): if kwargs.get('draw_mask'): mci = kwargs.get('mask_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[roi.relative_slice[0:2]] = roi.mask + # mask = np.zeros(patch.shape[0:2], dtype=bool) + # mask[roi.relative_slice[0:2]] = roi.binary_mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] if kwargs.get('draw_contour'): mci = kwargs.get('contour_channel', 0) - mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[roi.relative_slice[0:2]] = roi.mask + # mask = np.zeros(patch.shape[0:2], dtype=bool) + # mask[roi.relative_slice[0:2]] = roi.binary_mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = draw_contours_on_patch( @@ -466,12 +522,12 @@ class RoiSet(object): find_contours(mask) ) - if pad_to: + if pad_to and expanded: patch = pad(patch, pad_to) return patch - dfe = self._df - dfe['patch'] = self._df.apply(lambda r: _make_patch(r), axis=1) + dfe = self._df.copy() + dfe['patch'] = dfe.apply(lambda r: _make_patch(r), axis=1) return dfe def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict: @@ -481,12 +537,15 @@ class RoiSet(object): :param channel: color channel of products to export :param prefix: prefix of the name of each product's file or subfolder :param params: RoiSetExportParams object describing which products to export and with which parameters - :return: dict of Path objects describing the location of single-file export products + :return: nested dict of Path objects describing the location of export products """ record = {} if not self.count: return - raw_ch = self.acc_raw.get_one_channel_data(channel) + + # export dataframe and patch masks + record = self.serialize(where, prefix=prefix) + for k in params.dict().keys(): subdir = where / k pr = prefix @@ -494,11 +553,11 @@ class RoiSet(object): if kp is None: continue if k == 'patches_3d': - files = self.export_patches( - subdir, white_channel=channel, prefix=pr, make_3d=True, **kp + record[k] = self.export_patches( + subdir, white_channel=channel, prefix=pr, make_3d=True, expanded=True, **kp ) if k == 'annotated_patches_2d': - files = self.export_patches( + record[k] = self.export_patches( subdir, prefix=pr, make_3d=False, white_channel=channel, bounding_box_channel=1, bounding_box_linewidth=2, **kp, ) @@ -509,22 +568,53 @@ class RoiSet(object): df_patches = pd.DataFrame(files) self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1) - if k == 'patch_masks': - self.export_patch_masks(subdir, prefix=pr, **kp) + record[k] = 
files if k == 'annotated_zstacks': - self.export_annotated_zstack(subdir, prefix=pr, **kp) + record[k] = self.export_annotated_zstack(subdir, prefix=pr, **kp) if k == 'object_classes': for kc, acc in self.object_class_maps.items(): fp = subdir / kc / (pr + '.tif') write_accessor_data_to_file(fp, acc) record[f'{k}_{kc}'] = fp - if k == 'dataframe': - dfpa = subdir / (pr + '.csv') - dfpa.parent.mkdir(parents=True, exist_ok=True) - self._df.to_csv(dfpa, index=False) - record[k] = dfpa + + return record + + def serialize(self, where: Path, prefix='') -> dict: + """ + Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks + :param where: path of directory in which to write files + :param prefix: (optional) prefix + :return: nested dict of Path objects describing the locations of export products + """ + record = {} + record['dataframe'] = self.export_dataframe(where / 'dataframe' / (prefix + '.csv')) + record['tight_patch_masks'] = self.export_patch_masks( + where / 'tight_patch_masks', + prefix=prefix, + pad_to=None, + expanded=False + ) return record + @staticmethod + def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''): + df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] + + id_mask = np.zeros((*acc_raw.hw, 1, acc_raw.nz), dtype='uint16') + def _label_obj(r): + sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1] + ext = 'png' + fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' + try: + ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname) + bool_mask = ma_acc.data / np.iinfo(ma_acc.data.dtype).max + id_mask[sl] = r.label * bool_mask + except Exception as e: + raise DeserializeRoiSet(e) + + df.apply(_label_obj, axis=1) + return RoiSet(acc_raw, InMemoryDataAccessor(id_mask)) + def project_stack_from_focal_points( xx: np.ndarray, @@ -573,3 +663,9 @@ def project_stack_from_focal_points( ) + +class Error(Exception): + pass + +class DeserializeRoiSet(Error): + pass \ No newline at end of file diff --git a/model_server/base/session.py b/model_server/base/session.py index 1b91b12a570ebd807cc89bf0b4e3a78771a524c9..3d5b789de7f644fa6f16fe9f9170b83d540f2328 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -1,36 +1,82 @@ -import json +import logging import os from pathlib import Path from time import strftime, localtime from typing import Dict +import pandas as pd + import model_server.conf.defaults from model_server.base.models import Model -from model_server.base.workflows import WorkflowRunRecord -def create_manifest_json(): - pass +logger = logging.getLogger(__name__) + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + +class CsvTable(object): + def __init__(self, fpath: Path): + self.path = fpath + self.empty = True -class Session(object): + def append(self, coords: dict, data: pd.DataFrame) -> bool: + assert isinstance(data, pd.DataFrame) + for c in reversed(coords.keys()): + data.insert(0, c, coords[c]) + if self.empty: + data.to_csv(self.path, index=False, mode='w', header=True) + else: + data.to_csv(self.path, index=False, mode='a', header=False) + self.empty = False + return True + +class Session(object, metaclass=Singleton): """ Singleton class for a server session that persists data between API calls """ - def __new__(cls): - if not hasattr(cls, 'instance'): 
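# --- Reviewer note (annotation, not part of the patch): the Singleton
# metaclass introduced above replaces the removed __new__-based pattern below.
# A minimal sketch of the new behavior; _Example is a hypothetical class:
#
#     class _Example(metaclass=Singleton):
#         pass
#
#     a = _Example()
#     assert a is _Example()          # every call returns the same instance
#     Singleton._instances.clear()    # reset, as the tests do via
#     assert _Example() is not a      # Session._instances = {} in tearDown
# ---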
- cls.instance = super(Session, cls).__new__(cls) - return cls.instance + log_format = '%(asctime)s - %(levelname)s - %(message)s' def __init__(self, root: str = None): print('Initializing session') self.models = {} # model_id : model object - self.manifest = [] # paths to data as well as other metadata from each inference run self.paths = self.make_paths(root) - self.session_log = self.paths['logs'] / f'session.log' - self.log_event('Initialized session') - self.manifest_json = self.paths['logs'] / f'manifest.json' - open(self.manifest_json, 'w').close() # instantiate empty json file + + self.logfile = self.paths['logs'] / f'session.log' + logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format) + + self.log_info('Initialized session') + self.tables = {} + + def write_to_table(self, name: str, coords: dict, data: pd.DataFrame): + """ + Write data to a named data table, initializing if it does not yet exist. + :param name: name of the table to persist through session + :param coords: dictionary of coordinates to associate with all rows in this method call + :param data: DataFrame containing data + :return: True if successful + """ + try: + if name in self.tables.keys(): + table = self.tables.get(name) + else: + table = CsvTable(self.paths['tables'] / (name + '.csv')) + self.tables[name] = table + except Exception: + raise CouldNotCreateTable(f'Unable to create table named {name}') + + try: + table.append(coords, data) + return True + except Exception: + raise CouldNotAppendToTable(f'Unable to append data to table named {name}') def get_paths(self): return self.paths @@ -55,7 +101,7 @@ class Session(object): root_path = Path(root) sid = Session.create_session_id(root_path) paths = {'root': root_path} - for pk in ['inbound_images', 'outbound_images', 'logs']: + for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']: pa = root_path / sid / model_server.conf.defaults.subdirectories[pk] paths[pk] = pa try: @@ -75,22 +121,23 @@ class Session(object): idx += 1 return f'{yyyymmdd}-{idx:04d}' - def log_event(self, event: str): - """ - Write an event string to this session's log file. 
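# --- Reviewer note (annotation, not part of the patch): intended use of the
# write_to_table API added above; the table name, coordinates, and data here
# are illustrative only:
#
#     import pandas as pd
#     session = Session()
#     df = pd.DataFrame({'label': [1, 2], 'area': [120, 80]})
#     session.write_to_table('rois', {'input_file': 'stack01.czi'}, df)
#
# The first call creates <session root>/tables/rois.csv with a header; the
# coordinate columns (here input_file) are inserted before the data columns,
# and later calls append rows without rewriting the header.
# ---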
-        """
-        timestamp = strftime('%m/%d/%Y, %H:%M:%S', localtime())
-        with open(self.session_log, 'a') as fh:
-            fh.write(f'{timestamp} -- {event}\n')
+    def get_log_data(self) -> list:
+        log = []
+        with open(self.logfile, 'r') as fh:
+            for line in fh:
+                k = ['datetime', 'level', 'message']
+                v = line.strip().split(' - ')[0:3]
+                log.insert(0, dict(zip(k, v)))
+        return log
 
-    def record_workflow_run(self, record: WorkflowRunRecord or None):
-        """
-        Append a JSON describing inference data to this session's manifest
-        """
-        self.log_event(f'Ran model {record.model_id} on {record.input_filepath} to infer {record.output_filepath}')
-        with open(self.manifest_json, 'w+') as fh:
-            json.dump(record.dict(), fh)
+    def log_info(self, msg):
+        logger.info(msg)
+
+    def log_warning(self, msg):
+        logger.warning(msg)
+
+    def log_error(self, msg):
+        logger.error(msg)
 
     def load_model(self, ModelClass: Model, params: Dict[str, str] = None) -> dict:
         """
@@ -114,7 +161,7 @@
             'object': mi,
             'params': params
         }
-        self.log_event(f'Loaded model {key}')
+        self.log_info(f'Loaded model {key}')
         return key
 
     def describe_loaded_models(self) -> dict:
@@ -157,5 +204,11 @@
 class CouldNotInstantiateModelError(Error):
     pass
 
+class CouldNotCreateTable(Error):
+    pass
+
+class CouldNotAppendToTable(Error):
+    pass
+
 class InvalidPathError(Error):
     pass
\ No newline at end of file
diff --git a/model_server/conf/defaults.py b/model_server/conf/defaults.py
index 55b114f2686e9575fb2e3eb77b9e0577f40b5918..bdf7cfd0cf2786783b8f4c16dafdf7418af05976 100644
--- a/model_server/conf/defaults.py
+++ b/model_server/conf/defaults.py
@@ -6,6 +6,7 @@ subdirectories = {
     'logs': 'logs',
     'inbound_images': 'images/inbound',
     'outbound_images': 'images/outbound',
+    'tables': 'tables',
 }
 
 server_conf = {
diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py
index 3d07931ee087dba7eb56f416f232b16b89056161..97a13aaf75192c76a75ace8c8bcccf3678370a5d 100644
--- a/model_server/conf/testing.py
+++ b/model_server/conf/testing.py
@@ -67,6 +67,7 @@ roiset_test_data = {
         'c': 5,
         'z': 7,
         'mask_path': root / 'zmask-test-stack-mask.tif',
+        'mask_path_3d': root / 'zmask-test-stack-mask-3d.tif',
     },
     'pipeline_params': {
         'segmentation_channel': 0,
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index de25566a6571b6a07f8f6e1d453e49a1b13f9c72..7229b43eb601dc0703499392403e42e1bce3bf98 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -1,3 +1,4 @@
+import json
 import os
 from pathlib import Path
 
@@ -12,8 +13,17 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat
 
 class IlastikModel(Model):
 
-    def __init__(self, params, autoload=True):
+    def __init__(self, params, autoload=True, enforce_embedded=True):
+        """
+        Base class for models that run via the ilastik shell API
+        :param params:
+            project_file: path to ilastik project file
+        :param autoload: automatically load model into memory if true
+        :param enforce_embedded:
+            raise an error if input data are not embedded in the project file, i.e.
on the filesystem + """ self.project_file = Path(params['project_file']) + self.enforce_embedded = enforce_embedded params['project_file'] = self.project_file.__str__() if self.project_file.is_absolute(): pap = self.project_file @@ -42,6 +52,15 @@ class IlastikModel(Model): args.project = self.project_file_abspath.__str__() shell = app.main(args, init_logging=False) + # validate if inputs are embedded in project file + input_groups = shell.projectManager.currentProjectFile['Input Data']['infos'] + lanes = input_groups.keys() + for ll in lanes: + input_types = input_groups[ll] + for tt in input_types: + ds_loc = input_groups[ll][tt].get('location', False) + if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem': + raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem') if not isinstance(shell.workflow, self.get_workflow()): raise ParameterExpectedError( f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' @@ -55,12 +74,35 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' operations = ['segment', ] + @property + def model_shape_dict(self): + raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] + ax = raw_info['axistags'][()] + ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']] + shape = raw_info['shape'][()] + dd = dict(zip(ax_keys, shape)) + for ci in 'TCZ': + if ci not in dd.keys(): + dd[ci] = 1 + return dd + + @property + def model_chroma(self): + return self.model_shape_dict['C'] + + @property + def model_3d(self): + return self.model_shape_dict['Z'] > 1 + @staticmethod def get_workflow(): from ilastik.workflows import PixelClassificationWorkflow return PixelClassificationWorkflow def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict): + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') dsi = [ { @@ -87,22 +129,33 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): model_id = 'ilastik_object_classification_from_segmentation' + @staticmethod + def _make_8bit_mask(nda): + if nda.dtype == 'bool': + return 255 * nda.astype('uint8') + else: + return nda + @staticmethod def get_workflow(): from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') assert segmentation_img.is_mask() - if segmentation_img.dtype == 'bool': - seg = 255 * segmentation_img.data.astype('uint8') + if isinstance(input_img, PatchStack): + assert isinstance(segmentation_img, PatchStack) + tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx') tagged_seg_data = vigra.taggedView( - 255 * segmentation_img.data.astype('uint8'), - 'yxcz' + self._make_8bit_mask(segmentation_img.pczyx), + 'tczyx' ) else: - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + tagged_seg_data = vigra.taggedView( + self._make_8bit_mask(segmentation_img.data), + 'yxcz' + ) dsi = [ { @@ -115,12 
+168,21 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment assert len(obmaps) == 1, 'ilastik generated more than one object map' - yxcz = np.moveaxis( - obmaps[0], - [1, 2, 3, 0], - [0, 1, 2, 3] - ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + + if isinstance(input_img, PatchStack): + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) + return PatchStack(data=pyxcz), {'success': True} + else: + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=yxcz), {'success': True} def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) @@ -172,48 +234,40 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag """ if not img.shape == pxmap.shape: raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') - # TODO: check that pxmap is in-range pxch = kwargs.get('pixel_classification_channel', 0) - pxtr = kwargs('pixel_classification_threshold', 0.5) + pxtr = kwargs.get('pixel_classification_threshold', 0.5) mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) - # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) obmap, _ = self.infer(img, mask) return obmap + def make_instance_segmentation_model(self, px_ch: int): + """ + Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a + second input. + :param px_ch: channel of pixel probability map to use + :return: + InstanceSegmentationModel object + """ + class _Mod(self.__class__, InstanceSegmentationModel): + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + if mask.dtype == 'bool': + norm_mask = 1.0 * mask.data + else: + norm_mask = mask.data / np.iinfo(mask.dtype).max + norm_mask_acc = InMemoryDataAccessor(norm_mask.astype('float32')) + return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch) + return _Mod(params={'project_file': self.project_file}) -class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): - """ - Wrap ilastik object classification for inputs comprising single-object series of raw images and binary - segmentation masks. 
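# --- Reviewer note (annotation, not part of the patch): intended use of the
# make_instance_segmentation_model factory added above, assuming an object
# classifier trained on pixel probability maps (the path is hypothetical):
#
#     pxmap_model = IlastikObjectClassifierFromPixelPredictionsModel(
#         {'project_file': 'classifiers/pxmap_to_obj.ilp'}
#     )
#     seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0)
#     obmap = seg_model.label_instance_class(img_acc, mask_acc)
#
# The derived class rescales a boolean or integer mask to a float32 map in
# [0, 1] and forwards it as the pixel probability input on channel px_ch.
# ---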
- """ - - def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): - assert segmentation_acc.is_mask() - if not input_acc.chroma == 1: - raise InvalidInputImageError('Object classifier expects only monochrome patches') - if not input_acc.nz == 1: - raise InvalidInputImageError('Object classifier expects only 2d patches') - - tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') - tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') - - dsi = [ - { - 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), - } - ] - obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert len(obmaps) == 1, 'ilastik generated more than one object map' +class Error(Exception): + pass - # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C - assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) - pyxcz = np.moveaxis( - obmaps[0], - [0, 1, 2, 3, 4], - [0, 4, 1, 2, 3] - ) +class IlastikInputEmbedding(Error): + pass - return PatchStack(data=pyxcz), {'success': True} \ No newline at end of file +class IlastikInputShapeError(Error): + """Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape""" + pass \ No newline at end of file diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 3411ea7ca158a46c40d5f94943094491ad793a41..151e37915c16dd8a3943dc0b07e1e188d45606a3 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -25,9 +25,11 @@ def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplica if not duplicate: existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) if existing_model_id is not None: + session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') return {'model_id': existing_model_id} try: result = session.load_model(model_class, {'project_file': project_file}) + session.log_info(f'Loaded ilastik model {result} from {project_file}') except (FileNotFoundError, ParameterExpectedError): raise HTTPException( status_code=404, @@ -60,6 +62,7 @@ def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: st channel=channel, mip=mip, ) + session.log_info(f'Completed pixel and object classification of {input_filename}') except AssertionError: raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') return record \ No newline at end of file diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 32dd137683a1a65a0fcbfc48535dd8a88a221213..d042e57150413c18ca36a2bbb8cc379a122bdff5 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -11,6 +11,9 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels from tests.test_api import TestServerBaseClass +def _random_int(*args): + return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') + class TestIlastikPixelClassification(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) @@ -83,6 +86,40 @@ class 
TestIlastikPixelClassification(unittest.TestCase): self.mono_image = mono_image self.mask = mask + def test_pixel_classifier_enforces_input_shape(self): + model = ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']} + ) + self.assertEqual(model.model_chroma, 1) + self.assertEqual(model.model_3d, False) + + # correct data + self.assertIsInstance( + model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 1, 1) + ) + ), + InMemoryDataAccessor + ) + + # raise except with input of multiple channels + with self.assertRaises(ilm.IlastikInputShapeError): + mask = model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 3, 1) + ) + ) + + # raise except with input of multiple channels + with self.assertRaises(ilm.IlastikInputShapeError): + mask = model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 1, 15) + ) + ) + + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] @@ -97,7 +134,24 @@ class TestIlastikPixelClassification(unittest.TestCase): objmap, ) ) - self.assertEqual(objmap.data.max(), 3) + self.assertEqual(objmap.data.max(), 2) + + def test_make_seg_obj_model_from_pxmap_obj(self): + self.test_run_pixel_classifier() + fp = czifile['path'] + pxmap_model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( + {'project_file': ilastik_classifiers['pxmap_to_obj']} + ) + seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0) + objmap = seg_model.label_instance_class(self.mono_image, self.mask) + + self.assertTrue( + write_accessor_data_to_file( + output_path / f'obmap_seg_from_pxmap_{fp.stem}.tif', + objmap, + ) + ) + self.assertEqual(objmap.data.max(), 2) def test_run_object_classifier_from_segmentation(self): self.test_run_pixel_classifier() @@ -113,7 +167,7 @@ class TestIlastikPixelClassification(unittest.TestCase): objmap, ) ) - self.assertEqual(objmap.data.max(), 3) + self.assertEqual(objmap.data.max(), 2) def test_ilastik_pixel_classification_as_workflow(self): result = classify_pixels( @@ -168,7 +222,7 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) def test_no_duplicate_model_with_different_path_formats(self): - self._get('restart') + self._get('session/restart') resp_list_1 = self._get('models').json() self.assertEqual(len(resp_list_1), 0) ilp = ilastik_classifiers['px'] @@ -269,16 +323,18 @@ class TestIlastikObjectClassification(unittest.TestCase): ) ) - self.object_classifier = ilm.PatchStackObjectClassifier( + self.object_classifier = ilm.IlastikObjectClassifierFromSegmentationModel( params={'project_file': ilastik_classifiers['seg_to_obj']} ) + def test_classify_patches(self): - raw_patches = self.roiset.get_raw_patches() - patch_masks = self.roiset.get_patch_masks() - res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks) + raw_patches = self.roiset.get_patches_acc() + patch_masks = self.roiset.get_patch_masks_acc() + res_patches = self.object_classifier.label_instance_class(raw_patches, patch_masks) self.assertEqual(res_patches.count, self.roiset.count) + res_patches.export_pyxcz(output_path / 'res_patches.tif') for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch - unique = np.unique(res_patches.iat(pi).data) - self.assertEqual(len(unique), 2) - self.assertEqual(unique[0], 0) + la, ct = np.unique(res_patches.iat(pi).data, return_counts=True) + self.assertEqual(np.sum(ct > 1), 2) # exclude single-pixel anomaly + 
self.assertEqual(la[0], 0) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index bc5b4065f314eadd66692a806cb2356eccdc28e2..e84788d251f78363945ec331d8c74632cb03f926 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -7,6 +7,9 @@ from model_server.base.accessors import PatchStack, make_patch_stack_from_file, from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor +def _random_int(*args): + return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') + class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: @@ -40,7 +43,7 @@ class TestCziImageFileAccess(unittest.TestCase): nc = 4 nz = 11 c = 3 - cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) + cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) sc = cf.get_one_channel_data(c) self.assertEqual(sc.shape, (h, w, 1, nz)) @@ -70,7 +73,7 @@ class TestCziImageFileAccess(unittest.TestCase): def test_conform_data_shorter_than_xycz(self): h = 256 w = 512 - data = np.random.rand(h, w, 1) + data = _random_int(h, w, 1) acc = InMemoryDataAccessor(data) self.assertEqual( InMemoryDataAccessor.conform_data(data).shape, @@ -82,7 +85,7 @@ class TestCziImageFileAccess(unittest.TestCase): ) def test_conform_data_longer_than_xycz(self): - data = np.random.rand(256, 512, 12, 8, 3) + data = _random_int(256, 512, 12, 8, 3) with self.assertRaises(DataShapeError): acc = InMemoryDataAccessor(data) @@ -93,7 +96,7 @@ class TestCziImageFileAccess(unittest.TestCase): c = 3 nz = 10 - yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8') + yxcz = _random_int(h, w, c, nz) acc = InMemoryDataAccessor(yxcz) fp = output_path / f'rand3d.tif' self.assertTrue( @@ -138,7 +141,7 @@ class TestPatchStackAccessor(unittest.TestCase): w = 256 h = 512 n = 4 - acc = PatchStack(np.random.rand(n, h, w, 1, 1)) + acc = PatchStack(_random_int(n, h, w, 1, 1)) self.assertEqual(acc.count, n) self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) @@ -147,7 +150,7 @@ class TestPatchStackAccessor(unittest.TestCase): w = 256 h = 512 n = 4 - acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)]) + acc = PatchStack([_random_int(h, w, 1, 1) for _ in range(0, n)]) self.assertEqual(acc.count, n) self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) @@ -176,8 +179,8 @@ class TestPatchStackAccessor(unittest.TestCase): nz = 5 n = 4 - patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)] - patches.append(np.random.rand(h, 2 * w, c, nz)) + patches = [_random_int(h, w, c, nz) for _ in range(0, n)] + patches.append(_random_int(h, 2 * w, c, nz)) acc = PatchStack(patches) self.assertEqual(acc.count, n + 1) self.assertEqual(acc.hw, (h, 2 * w)) @@ -191,7 +194,30 @@ class TestPatchStackAccessor(unittest.TestCase): n = 4 nz = 15 nc = 2 - acc = PatchStack(np.random.rand(n, h, w, nc, nz)) + acc = PatchStack(_random_int(n, h, w, nc, nz)) self.assertEqual(acc.count, n) self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w)) self.assertEqual(acc.hw, (h, w)) + return acc + + def test_get_one_channel(self): + acc = self.test_pczyx() + mono = acc.get_one_channel_data(channel=1) + for a in 'PXYZ': + self.assertEqual(mono.shape_dict[a], acc.shape_dict[a]) + self.assertEqual(mono.shape_dict['C'], 1) + + def test_get_one_channel_mip(self): + 
acc = self.test_pczyx() + mono_mip = acc.get_one_channel_data(channel=1, mip=True) + for a in 'PXY': + self.assertEqual(mono_mip.shape_dict[a], acc.shape_dict[a]) + for a in 'CZ': + self.assertEqual(mono_mip.shape_dict[a], 1) + + def test_export_pczyx_patch_hyperstack(self): + acc = self.test_pczyx() + fp = output_path / 'patch_hyperstack.tif' + acc.export_pyxcz(fp) + acc2 = make_patch_stack_from_file(fp) + self.assertEqual(acc.shape, acc2.shape) \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 1b201440da4504e16f10573d31747603926e3b17..aa3023382890b14aada48e2f22aa00f035372931 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -136,7 +136,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp_list_0.status_code, 200) rj0 = resp_list_0.json() self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}') - resp_restart = self._get('restart') + resp_restart = self._get('session/restart') resp_list_1 = self._get('models') rj1 = resp_list_1.json() self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}') @@ -171,4 +171,9 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ) self.assertEqual(resp_change.status_code, 200) resp_check = self._get('paths') - self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images']) \ No newline at end of file + self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images']) + + def test_get_logs(self): + resp = self._get('session/logs') + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json()[0]['message'], 'Initialized session') \ No newline at end of file diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 3040e6d251305731d0c3dcad5b3727a1bb48b614..0d507b3f90f3346335f9798a84e77efea94baf35 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -1,8 +1,12 @@ +import os +import re import unittest import numpy as np from pathlib import Path +import pandas as pd + from model_server.conf.testing import output_path, roiset_test_data from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams @@ -29,31 +33,36 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): params=RoiSetMetaParams( mask_type=mask_type, filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), - expand_box_by=(64, 2) + expand_box_by=(128, 2) ) ) return roiset def test_roi_mask_shape(self, **kwargs): roiset = self._make_roi_set(**kwargs) + + # all masks' bounding boxes are at least as big as ROI area + for roi in roiset.get_df().itertuples(): + self.assertEqual(roi.binary_mask.dtype, 'bool') + sh = roi.binary_mask.shape + self.assertEqual(sh, (roi.h, roi.w)) + self.assertGreaterEqual(sh[0] * sh[1], roi.area) + + def test_roi_zmask(self, **kwargs): + roiset = self._make_roi_set(**kwargs) zmask = roiset.get_zmask() zmask_acc = InMemoryDataAccessor(zmask) self.assertTrue(zmask_acc.is_mask()) # assert dimensionality of zmask - self.assertGreater(zmask_acc.shape_dict['Z'], 1) - self.assertEqual(zmask_acc.shape_dict['C'], 1) + self.assertEqual(zmask_acc.nz, roiset.acc_raw.nz) + self.assertEqual(zmask_acc.chroma, 1) write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc) # mask values are not just all True or all False self.assertTrue(np.any(zmask)) self.assertFalse(np.all(zmask)) - # assert non-trivial meta info in boxes - self.assertGreater(roiset.count, 1) - sh = roiset.get_df().iloc[1]['mask'].shape - ar = roiset.get_df().iloc[1]['area'] - 
self.assertGreaterEqual(sh[0] * sh[1], ar) def test_roiset_from_non_zstacks(self, **kwargs): acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) @@ -73,37 +82,75 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) + def test_dataframe_and_mask_array_in_iterator(self): + roiset = self._make_roi_set() + for roi in roiset: + ma = roi.binary_mask + self.assertEqual(ma.dtype, 'bool') + self.assertEqual(ma.shape, (roi.h, roi.w)) + def test_rel_slices_are_valid(self): roiset = self._make_roi_set() for roi in roiset: - ebb = roiset.acc_raw.data[roi.slice] + ebb = roiset.acc_raw.data[roi.expanded_slice] self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) rbb = ebb[roi.relative_slice] self.assertEqual(len(rbb.shape), 4) self.assertTrue(np.all([si >= 1 for si in rbb.shape])) - def test_make_2d_patches(self): + + def test_make_expanded_2d_patches(self): roiset = self._make_roi_set() files = roiset.export_patches( - output_path / '2d_patches', + output_path / 'expanded_2d_patches', draw_bounding_box=True, + expanded=True, + pad_to=256, ) - self.assertGreaterEqual(len(files), 1) - - def test_make_3d_patches(self): + df = roiset.get_df() + for f in files: + acc = generate_file_accessor(f) + la = int(re.search(r'la([\d]+)', str(f)).group(1)) + roi_q = df.loc[df.label == la, :] + self.assertEqual(len(roi_q), 1) + self.assertEqual((256, 256), acc.hw) + + def test_make_tight_2d_patches(self): + roiset = self._make_roi_set() + files = roiset.export_patches( + output_path / 'tight_2d_patches', + draw_bounding_box=True, + expanded=False + ) + df = roiset.get_df() + for f in files: # all exported files are same shape as bounding boxes in RoiSet's datatable + acc = generate_file_accessor(f) + la = int(re.search(r'la([\d]+)', str(f)).group(1)) + roi_q = df.loc[df.label == la, :] + self.assertEqual(len(roi_q), 1) + roi = roi_q.iloc[0] + self.assertEqual((roi.h, roi.w), acc.hw) + + def test_make_expanded_3d_patches(self): roiset = self._make_roi_set() files = roiset.export_patches( output_path / '3d_patches', - make_3d=True) + make_3d=True, + expanded=True + ) self.assertGreaterEqual(len(files), 1) + for f in files: + acc = generate_file_accessor(f) + self.assertGreater(acc.nz, 1) + def test_export_annotated_zstack(self): roiset = self._make_roi_set() file = roiset.export_annotated_zstack( output_path / 'annotated_zstack', ) - result = generate_file_accessor(Path(file['location']) / file['filename']) + result = generate_file_accessor(file) self.assertEqual(result.shape, roiset.acc_raw.shape) def test_flatten_image(self): @@ -132,7 +179,15 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_binary_masks(self): roiset = self._make_roi_set() files = roiset.export_patch_masks(output_path / '2d_mask_patches', ) - self.assertGreaterEqual(len(files), 1) + + df = roiset.get_df() + for f in files: # all exported files are same shape as bounding boxes in RoiSet's datatable + acc = generate_file_accessor(output_path / '2d_mask_patches' / f) + la = int(re.search(r'la([\d]+)', str(f)).group(1)) + roi_q = df.loc[df.label == la, :] + self.assertEqual(len(roi_q), 1) + roi = roi_q.iloc[0] + self.assertEqual((roi.h, roi.w), acc.hw) def test_classify_by(self): roiset = self._make_roi_set() @@ -156,17 +211,21 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def 
test_raw_patches_are_correct_shape(self): roiset = self._make_roi_set() - patches = roiset.get_raw_patches() + patches = roiset.get_patches_acc() np, h, w, nc, nz = patches.shape self.assertEqual(np, roiset.count) self.assertEqual(nc, roiset.acc_raw.chroma) + self.assertEqual(nz, 1) def test_patch_masks_are_correct_shape(self): roiset = self._make_roi_set() - patch_masks = roiset.get_patch_masks() - np, h, w, nc, nz = patch_masks.shape - self.assertEqual(np, roiset.count) - self.assertEqual(nc, 1) + df_patch_masks = roiset.get_patch_masks() + for roi in df_patch_masks.itertuples(): + h, w, nc, nz = roi.patch_mask.shape + self.assertEqual(nc, 1) + self.assertEqual(nz, 1) + self.assertEqual(h, roi.h) + self.assertEqual(w, roi.w) class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): @@ -189,8 +248,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa output_path / 'multichannel' / 'mono_2d_patches', white_channel=3, draw_bounding_box=True, + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(files[0]) self.assertEqual(result.chroma, 1) def test_multichannnel_to_mono_2d_patches_rgb_bbox(self): @@ -199,8 +260,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa white_channel=3, draw_bounding_box=True, bounding_box_channel=1, + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(files[0]) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_bbox(self): @@ -208,11 +271,28 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa output_path / 'multichannel' / 'rgb_2d_patches_bbox', white_channel=4, rgb_overlay_channels=(3, None, None), + draw_mask=False, + draw_bounding_box=True, + bounding_box_channel=1, + rgb_overlay_weights=(0.1, 1.0, 1.0), + expanded=True, + pad_to=256, + ) + result = generate_file_accessor(files[0]) + self.assertEqual(result.chroma, 3) + + def test_multichannnel_to_rgb_2d_patches_mask(self): + files = self.roiset.export_patches( + output_path / 'multichannel' / 'rgb_2d_patches_mask', + white_channel=4, + rgb_overlay_channels=(3, None, None), draw_mask=True, mask_channel=0, - rgb_overlay_weights=(0.1, 1.0, 1.0) + rgb_overlay_weights=(0.1, 1.0, 1.0), + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(files[0]) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_contour(self): @@ -221,25 +301,32 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa rgb_overlay_channels=(3, None, None), draw_contour=True, contour_channel=1, - rgb_overlay_weights=(0.1, 1.0, 1.0) + rgb_overlay_weights=(0.1, 1.0, 1.0), + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = generate_file_accessor(files[0]) self.assertEqual(result.chroma, 3) self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black def test_multichannel_to_multichannel_tif_patches(self): files = self.roiset.export_patches( output_path / 'multichannel' / 'multichannel_tif_patches', + expanded=True, + pad_to=256, ) - result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) + result = 
 
     def test_multichannel_to_multichannel_tif_patches(self):
         files = self.roiset.export_patches(
             output_path / 'multichannel' / 'multichannel_tif_patches',
+            expanded=True,
+            pad_to=256,
         )
-        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
+        result = generate_file_accessor(files[0])
         self.assertEqual(result.chroma, 5)
+        self.assertEqual(result.nz, 1)
 
     def test_multichannel_annotated_zstack(self):
         file = self.roiset.export_annotated_zstack(
             output_path / 'multichannel' / 'annotated_zstack',
             'test_multichannel_annotated_zstack',
+            expanded=True,
+            pad_to=256,
         )
-        result = generate_file_accessor(Path(file['location']) / file['filename'])
+        result = generate_file_accessor(file)
         self.assertEqual(result.chroma, self.stack.chroma)
         self.assertEqual(result.nz, self.stack.nz)
 
@@ -247,9 +334,104 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         file = self.roiset.export_annotated_zstack(
             output_path / 'annotated_zstack',
             channel=3,
+            expanded=True,
+            pad_to=256,
         )
-        result = generate_file_accessor(Path(file['location']) / file['filename'])
+        result = generate_file_accessor(file)
         self.assertEqual(result.hw, self.roiset.acc_raw.hw)
         self.assertEqual(result.nz, self.roiset.acc_raw.nz)
         self.assertEqual(result.chroma, 1)
 
+
+class TestRoiSetFromZmask(unittest.TestCase):
+
+    def setUp(self) -> None:
+        # set up test raw data and segmentation from file
+        self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
+        self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['segmentation_channel'])
+        self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d'])
+
+    @staticmethod
+    def _label_is_2d(id_map, la):  # a label is 2d if its 3d mask has the same pixel count as its MIP
+        mask_3d = (id_map == la)
+        mask_mip = mask_3d.max(axis=-1)
+        return mask_3d.sum() == mask_mip.sum()
+
+    def test_id_map_connects_z(self):
+        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True)
+        labels = np.unique(id_map.data)[1:]
+        is_2d = all([self._label_is_2d(id_map.data, la) for la in labels])
+        self.assertFalse(is_2d)
+
+    def test_id_map_disconnects_z(self):
+        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
+        labels = np.unique(id_map.data)[1:]
+        is_2d = all([self._label_is_2d(id_map.data, la) for la in labels])
+        self.assertTrue(is_2d)
+
+    def test_create_roiset_from_3d_obj_ids(self):
+        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
+        self.assertEqual(self.stack_ch_pa.shape, id_map.shape)
+
+        roiset = RoiSet(
+            self.stack_ch_pa,
+            id_map,
+            params=RoiSetMetaParams(mask_type='contours')
+        )
+        self.assertEqual(roiset.count, id_map.data.max())
+        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
+
+    def test_create_roiset_from_2d_obj_ids(self):
+        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False)
+        self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3])
+        self.assertEqual(id_map.nz, 1)
+
+        roiset = RoiSet(
+            self.stack_ch_pa,
+            id_map,
+            params=RoiSetMetaParams(mask_type='contours')
+        )
+        self.assertEqual(roiset.count, id_map.data.max())
+        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
+        return roiset
+
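_label_is_2d works because a label confined to a single z-slice has exactly as many foreground voxels as its maximum-intensity projection. What connect_3d toggles can be sketched with skimage.measure.label (an assumed implementation for illustration only; the real _get_label_ids may differ): full 3D connectivity lets one object span slices, while labeling each slice independently and offsetting the label values guarantees every object is 2D:

    import numpy as np
    from skimage.measure import label

    def label_mask(mask_yxz: np.ndarray, connect_3d: bool) -> np.ndarray:
        if connect_3d:
            return label(mask_yxz)  # default full connectivity joins voxels across z
        labeled = np.zeros_like(mask_yxz, dtype='uint16')
        offset = 0
        for zi in range(mask_yxz.shape[-1]):  # label each z-slice independently
            la_2d = label(mask_yxz[:, :, zi])
            la_2d[la_2d > 0] += offset  # keep labels unique across slices
            labeled[:, :, zi] = la_2d
            offset = labeled.max()
        return labeled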
+    def test_create_roiset_from_df_and_patch_masks(self):
+        ref_roiset = self.test_create_roiset_from_2d_obj_ids()
+        where_ser = output_path / 'serialize'
+        ref_roiset.serialize(where_ser, prefix='ref')
+        where_df = where_ser / 'dataframe' / 'ref.csv'
+        self.assertTrue(where_df.exists())
+        df_test = pd.read_csv(where_df)
+
+        # check that the exported patch masks are the correct size
+        where_patch_masks = where_ser / 'tight_patch_masks'
+        patch_filenames = []
+        for pmf in where_patch_masks.iterdir():
+            self.assertTrue(pmf.suffix.upper() == '.PNG')
+            la = int(re.search(r'la([\d]+)', str(pmf)).group(1))
+            roi_q = df_test.loc[df_test.label == la, :]
+            self.assertEqual(len(roi_q), 1)
+            roi = roi_q.iloc[0]
+            m_acc = generate_file_accessor(pmf)
+            self.assertEqual((roi.h, roi.w), m_acc.hw)
+            patch_filenames.append(pmf.name)
+
+        # make another RoiSet from just the data table, raw images, and (tight) patch masks
+        test_roiset = RoiSet.deserialize(self.stack_ch_pa, where_ser, prefix='ref')
+        self.assertEqual(ref_roiset.get_zmask().shape, test_roiset.get_zmask().shape)
+        self.assertTrue((ref_roiset.get_zmask() == test_roiset.get_zmask()).all())
+        self.assertTrue(np.all(test_roiset.get_df().label == ref_roiset.get_df().label))
+        cols = ['label', 'y1', 'y0', 'x1', 'x0', 'zi']
+        self.assertTrue((test_roiset.get_df()[cols] == ref_roiset.get_df()[cols]).all().all())
+
+        # re-serialize and check that the patch masks are unchanged
+        where_dser = output_path / 'deserialize'
+        test_roiset.serialize(where_dser, prefix='test')
+        for fr in patch_filenames:
+            pr = (where_ser / 'tight_patch_masks' / fr)
+            self.assertTrue(pr.exists())
+            pt = (where_dser / 'tight_patch_masks' / fr.replace('ref', 'test'))
+            self.assertTrue(pt.exists())
+            r_acc = generate_file_accessor(pr)
+            t_acc = generate_file_accessor(pt)
+            self.assertTrue(np.all(r_acc.data == t_acc.data))
+
diff --git a/tests/test_session.py b/tests/test_session.py
index 9679aad61c699de8f07ff0a31e2cc4750cb6f9e3..aafda3c27f7146edc9e5235e660bc311d579e938 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,72 +1,83 @@
+import json
+from os.path import exists
 import pathlib
 import unittest
+
 from model_server.base.models import DummySemanticSegmentationModel
 from model_server.base.session import Session
+from model_server.base.workflows import WorkflowRunRecord
 
 
 class TestGetSessionObject(unittest.TestCase):
     def setUp(self) -> None:
-        pass
+        self.sesh = Session()
+
+    def tearDown(self) -> None:
+        print('Tearing down...')
+        Session._instances = {}
 
-    def test_single_session_instance(self):
-        sesh = Session()
-        self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')
+    def test_session_is_singleton(self):
+        Session._instances = {}
+        self.assertEqual(len(Session._instances), 0)
+        s = Session()
+        self.assertEqual(len(Session._instances), 1)
+        self.assertIs(s, Session())
+        self.assertEqual(len(Session._instances), 1)
 
-        from os.path import exists
-        self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place')
-        self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
+    def test_session_logfile_is_valid(self):
+        self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
 
     def test_changing_session_root_creates_new_directory(self):
         from model_server.conf.defaults import root
         from shutil import rmtree
-        sesh = Session()
-        old_paths = sesh.get_paths()
+        old_paths = self.sesh.get_paths()
         newroot = root / 'subdir'
-        sesh.restart(root=newroot)
-        new_paths = sesh.get_paths()
+        self.sesh.restart(root=newroot)
+        new_paths = self.sesh.get_paths()
         for k in old_paths.keys():
             self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
+
+        # this teardown/setup cycle is necessary because the logger itself is a singleton
+        self.tearDown()
+        self.setUp()
         rmtree(newroot)
         self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
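The fixtures above reset Session._instances directly, which implies the singleton is enforced by a metaclass that caches one instance per class, roughly along these lines (a sketch of the pattern the tests exercise, not the actual Session source):

    class Singleton(type):
        _instances = {}

        def __call__(cls, *args, **kwargs):
            # construct at most one instance per class; later calls return the cached object
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]

    class Session(metaclass=Singleton):  # illustrative stand-in for the real class
        ...

Assigning Session._instances = {} in tearDown shadows the metaclass attribute with an empty dict, so the next Session() call constructs a fresh object, which is what lets each test start from a clean session.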
 
     def test_change_session_subdirectory(self):
-        sesh = Session()
-        old_paths = sesh.get_paths()
+        old_paths = self.sesh.get_paths()
         print(old_paths)
-        sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
-        self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images'])
-
-
-    def test_restart_session(self):
-        sesh = Session()
-        logfile1 = sesh.session_log
-        sesh.restart()
-        logfile2 = sesh.session_log
+        self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
+        self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
+
+    def test_restarting_session_creates_new_logfile(self):
+        logfile1 = self.sesh.logfile
+        self.assertTrue(logfile1.exists())
+        self.sesh.restart()
+        logfile2 = self.sesh.logfile
+        self.assertTrue(logfile2.exists())
         self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
 
-    def test_session_records_workflow(self):
-        import json
-        from model_server.base.workflows import WorkflowRunRecord
-        sesh = Session()
-        di = WorkflowRunRecord(
-            model_id='test_model',
-            input_filepath='/test/input/directory',
-            output_filepath='/test/output/fi.le',
-            success=True,
-            timer_results={'start': 0.123},
-        )
-        sesh.record_workflow_run(di)
-        with open(sesh.manifest_json, 'r') as fh:
-            do = json.load(fh)
-        self.assertEqual(di.dict(), do, 'Manifest record is not correct')
-
+    def test_log_warning(self):
+        msg = 'A test warning'
+        self.sesh.log_info(msg)
+        with open(self.sesh.logfile, 'r') as fh:
+            log = fh.read()
+        self.assertTrue(msg in log)
+
+    def test_get_logs(self):
+        self.sesh.log_info('Info example 1')
+        self.sesh.log_warning('Example warning')
+        self.sesh.log_info('Info example 2')
+        logs = self.sesh.get_log_data()
+        self.assertEqual(len(logs), 4)
+        self.assertEqual(logs[1]['level'], 'WARNING')
+        self.assertEqual(logs[-1]['message'], 'Initialized session')
 
     def test_session_loads_model(self):
-        sesh = Session()
         MC = DummySemanticSegmentationModel
-        success = sesh.load_model(MC)
+        success = self.sesh.load_model(MC)
         self.assertTrue(success)
-        loaded_models = sesh.describe_loaded_models()
+        loaded_models = self.sesh.describe_loaded_models()
         self.assertTrue(
             (MC.__name__ + '_00') in loaded_models.keys()
         )
@@ -76,46 +87,56 @@ class TestGetSessionObject(unittest.TestCase):
         )
 
     def test_session_loads_second_instance_of_same_model(self):
-        sesh = Session()
         MC = DummySemanticSegmentationModel
-        sesh.load_model(MC)
-        sesh.load_model(MC)
-        self.assertIn(MC.__name__ + '_00', sesh.models.keys())
-        self.assertIn(MC.__name__ + '_01', sesh.models.keys())
-
+        self.sesh.load_model(MC)
+        self.sesh.load_model(MC)
+        self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
+        self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())
 
     def test_session_loads_model_with_params(self):
-        sesh = Session()
         MC = DummySemanticSegmentationModel
         p1 = {'p1': 'abc'}
-        success = sesh.load_model(MC, params=p1)
+        success = self.sesh.load_model(MC, params=p1)
         self.assertTrue(success)
-        loaded_models = sesh.describe_loaded_models()
+        loaded_models = self.sesh.describe_loaded_models()
         mid = MC.__name__ + '_00'
         self.assertEqual(loaded_models[mid]['params'], p1)
 
         # load a second model and confirm that the first is locatable by its param entry
         p2 = {'p2': 'def'}
-        sesh.load_model(MC, params=p2)
-        find_mid = sesh.find_param_in_loaded_models('p1', 'abc')
+        self.sesh.load_model(MC, params=p2)
+        find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc')
         self.assertEqual(mid, find_mid)
-        self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1)
+        self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)
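The '_00' / '_01' suffixes asserted above imply that model ids are derived from the class name plus a zero-padded count of instances of that class already loaded. A plausible sketch of that bookkeeping (the actual Session.load_model may differ):

    def next_model_id(loaded: dict, model_class: type) -> str:
        # count registered instances of this class and zero-pad the suffix,
        # e.g. 'DummySemanticSegmentationModel_01' when one is already loaded
        n = sum(1 for k in loaded if k.startswith(model_class.__name__))
        return f'{model_class.__name__}_{n:02d}'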
 
     def test_session_finds_existing_model_with_different_path_formats(self):
-        sesh = Session()
         MC = DummySemanticSegmentationModel
         p1 = {'path': 'c:\\windows\\dummy.pa'}
         p2 = {'path': 'c:/windows/dummy.pa'}
-        mid = sesh.load_model(MC, params=p1)
+        mid = self.sesh.load_model(MC, params=p1)
         assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
-        find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
+        find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
         self.assertEqual(mid, find_mid)
 
     def test_change_output_path(self):
-        import pathlib
-        sesh = Session()
-        pa = sesh.get_paths()['inbound_images']
+        pa = self.sesh.get_paths()['inbound_images']
         self.assertIsInstance(pa, pathlib.Path)
-        sesh.set_data_directory('outbound_images', pa.__str__())
-        self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images'])
-        self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path)
\ No newline at end of file
+        self.sesh.set_data_directory('outbound_images', pa.__str__())
+        self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
+        self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)
+
+    def test_make_table(self):
+        import pandas as pd
+        data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)]
+        self.sesh.write_to_table(
+            'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
+        )
+        self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
+        self.sesh.write_to_table(
+            'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
+        )
+
+        dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
+        self.assertEqual(len(dfv), len(data))
+        self.assertEqual(dfv.columns[0], 'X')
+        self.assertEqual(dfv.columns[1], 'Y')
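test_make_table pins down the observable contract of write_to_table: coordinate columns lead, rows accumulate across calls, and one CSV per table name persists at session.tables[name].path. A behavior sketch consistent with those assertions (an assumed implementation detail, not the actual session code):

    from pathlib import Path

    import pandas as pd

    def write_to_table(path: Path, coords: dict, df: pd.DataFrame):
        # prepend coordinate columns in the given order, then append to one CSV;
        # the header is written only when the file is first created
        df = df.copy()
        for k, v in reversed(coords.items()):
            df.insert(0, k, v)
        df.to_csv(path, mode='a', index=False, header=not path.exists())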