From 8dbede736d7ed92eead21df0b3e555d47feed0a5 Mon Sep 17 00:00:00 2001 From: Christopher Randolph Rhodes <christopher.rhodes@embl.de> Date: Wed, 3 Apr 2024 12:44:07 +0000 Subject: [PATCH] Revert "Merge branch 'issue0027-serialize-roiset' into 'master'" This reverts merge request !16 --- model_server/base/accessors.py | 62 +---- model_server/base/api.py | 17 +- model_server/base/models.py | 12 +- model_server/base/roiset.py | 256 ++++++------------ model_server/base/session.py | 113 ++------ model_server/conf/defaults.py | 1 - model_server/conf/testing.py | 1 - model_server/extensions/ilastik/models.py | 150 ++++------ model_server/extensions/ilastik/router.py | 3 - .../extensions/ilastik/tests/test_ilastik.py | 76 +----- tests/test_accessors.py | 44 +-- tests/test_api.py | 9 +- tests/test_roiset.py | 244 +++-------------- tests/test_session.py | 145 +++++----- 14 files changed, 283 insertions(+), 850 deletions(-) diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 12807bf2..07eba484 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -162,14 +162,6 @@ class CziImageFileAccessor(GenericImageFileAccessor): except Exception: raise FileAccessorError(f'Unable to access CZI data in {fpath}') - try: - md = cf.metadata(raw=False) - compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod'] - except KeyError: - raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod') - if compmet.upper() != 'UNCOMPRESSED': - raise InvalidCziCompression(f'Unsupported compression method {compmet}') - sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes} if (sd.get('S') and (sd['S'] > 1)) or (sd.get('T') and (sd['T'] > 1)): raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}') @@ -292,29 +284,6 @@ class PatchStack(InMemoryDataAccessor): def count(self): return self.shape_dict['P'] - def export_pyxcz(self, fpath: Path): - tzcyx = np.moveaxis( - self.pyxcz, # yxcz - [0, 4, 3, 1, 2], - [0, 1, 2, 3, 4] - ) - - if self.is_mask(): - if self.dtype == 'bool': - data = (tzcyx * 255).astype('uint8') - else: - data = tzcyx.astype('uint8') - tifffile.imwrite(fpath, data, imagej=True) - else: - tifffile.imwrite(fpath, tzcyx, imagej=True) - - def get_one_channel_data(self, channel: int, mip: bool = False): - c = int(channel) - if mip: - return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :].max(axis=-1, keepdims=True)) - else: - return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :]) - @property def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) @@ -344,32 +313,16 @@ class PatchStack(InMemoryDataAccessor): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) -def make_patch_stack_from_file(fpath): # interpret t-dimension as patch position +def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position if not Path(fpath).exists(): raise FileNotFoundError(f'Could not find {fpath}') - try: - tf = tifffile.TiffFile(fpath) - except Exception: - raise FileAccessorError(f'Unable to access data in {fpath}') - - if len(tf.series) != 1: - raise DataShapeError(f'Expect only one series in {fpath}') - - se = tf.series[0] - - axs = [a for a in se.axes if a in [*'TZCYX']] - sd = dict(zip(axs, se.shape)) - for a in [*'TZC']: - if a not in axs: - sd[a] = 1 - tzcyx = se.asarray().reshape([sd[k] for k in [*'TZCYX']]) - - pyxcz = np.moveaxis( - tzcyx, - [0, 3, 4, 2, 1], - [0, 1, 2, 3, 4], + pyxc = np.moveaxis( + generate_file_accessor(fpath).data, # yxcz + [0, 1, 2, 3], + [1, 2, 3, 0] ) + pyxcz = np.expand_dims(pyxc, axis=3) return PatchStack(pyxcz) @@ -392,9 +345,6 @@ class FileWriteError(Error): class InvalidAxisKey(Error): pass -class InvalidCziCompression(Error): - pass - class InvalidDataShape(Error): pass diff --git a/model_server/base/api.py b/model_server/base/api.py index 752ed48a..8259a98c 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -46,7 +46,6 @@ def change_path(key, path): status_code=404, detail=e.__str__(), ) - session.log_info(f'Change {key} path to {path}') return session.get_paths() @app.put('/paths/watch_input') @@ -57,30 +56,22 @@ def watch_input_path(path: str): def watch_input_path(path: str): return change_path('outbound_images', path) -@app.get('/session/restart') +@app.get('/restart') def restart_session(root: str = None) -> dict: session.restart(root=root) return session.describe_loaded_models() -@app.get('/session/logs') -def list_session_log() -> list: - return session.get_log_data() - @app.get('/models') def list_active_models(): return session.describe_loaded_models() @app.put('/models/dummy_semantic/load/') def load_dummy_model() -> dict: - mid = session.load_model(DummySemanticSegmentationModel) - session.log_info(f'Loaded model {mid}') - return {'model_id': mid} + return {'model_id': session.load_model(DummySemanticSegmentationModel)} @app.put('/models/dummy_instance/load/') def load_dummy_model() -> dict: - mid = session.load_model(DummyInstanceSegmentationModel) - session.log_info(f'Loaded model {mid}') - return {'model_id': mid} + return {'model_id': session.load_model(DummyInstanceSegmentationModel)} @app.put('/workflows/segment') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: @@ -92,5 +83,5 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: session.paths['outbound_images'], channel=channel, ) - session.log_info(f'Completed segmentation of {input_filename}') + session.record_workflow_run(record) return record \ No newline at end of file diff --git a/model_server/base/models.py b/model_server/base/models.py index 4fa08490..8413f68d 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -3,7 +3,7 @@ from math import floor import numpy as np -from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor class Model(ABC): @@ -88,16 +88,6 @@ class InstanceSegmentationModel(ImageToImageModel): if not img.shape == mask.shape: raise InvalidInputImageError('Expect input image and mask to be the same shape') - def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs): - """ - Iterative over a patch stack, call inference on each patch - :return: PatchStack of same shape in input - """ - res_data = np.zeros(img.shape, dtype='uint16') - for i in range(0, img.count): # interpret as PYXCZ - res_data[i, :, :, :, :] = self.label_instance_class(img.iat(i), mask.iat(i), **kwargs).data - return PatchStack(res_data) - class DummySemanticSegmentationModel(SemanticSegmentationModel): diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index d537e92a..c44023db 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1,6 +1,5 @@ from math import sqrt, floor from pathlib import Path -import re from typing import List, Union from uuid import uuid4 @@ -18,7 +17,7 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc from model_server.base.models import InstanceSegmentationModel from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb from model_server.base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image -from model_server.base.accessors import generate_file_accessor, PatchStack +from model_server.base.accessors import PatchStack from model_server.base.process import mask_largest_object @@ -45,6 +44,7 @@ class RoiFilterRange(BaseModel): class RoiFilter(BaseModel): area: Union[RoiFilterRange, None] = None + solidity: Union[RoiFilterRange, None] = None class RoiSetMetaParams(BaseModel): @@ -57,39 +57,16 @@ class RoiSetExportParams(BaseModel): patches_3d: Union[PatchParams, None] = None annotated_patches_2d: Union[PatchParams, None] = None patches_2d: Union[PatchParams, None] = None + patch_masks: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False + dataframe: bool = False -def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: - """ - Convert binary segmentation mask into either a 2D or 3D object identities map - :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions - :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False - :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False - :return: object identities map - """ - if allow_3d and connect_3d: - nda_la = label( - acc_seg_mask.data[:, :, 0, :] - ).astype('uint16') - return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) - elif allow_3d and not connect_3d: - nla = 0 - la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') - for zi in range(0, acc_seg_mask.nz): - la_2d = label(acc_seg_mask.data[:, :, 0, zi]).astype('uint16') - la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla - nla = la_2d.max() - la_3d[:, :, 0, zi] = la_2d - return InMemoryDataAccessor(la_3d) - else: - return InMemoryDataAccessor( - label( - acc_seg_mask.data[:, :, 0, :].max(axis=-1) - ).astype('uint16') - ) + +def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: + return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')) def _focus_metrics(): @@ -132,6 +109,7 @@ class RoiSet(object): :param params: optional arguments that influence the definition and representation of ROIs """ assert acc_obj_ids.chroma == 1 + assert acc_obj_ids.nz == 1 self.acc_obj_ids = acc_obj_ids self.acc_raw = acc_raw self.params = params @@ -157,32 +135,26 @@ class RoiSet(object): :param acc_raw: accessor to raw image data :param acc_obj_ids: accessor to map of object IDs :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) - # :param deproject: assign object's z-position based on argmax of raw data if True :return: pd.DataFrame """ # build dataframe of objects, assign z index to each object - - if acc_obj_ids.nz == 1: # deproject objects' z-coordinates from argmax of raw image - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data[:, :, 0, 0], - intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'), - properties=('label', 'area', 'intensity_mean', 'bbox', 'centroid') - )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) - df['zi'] = df['intensity_mean'].round().astype('int') - - else: # objects' z-coordinates come from arg of max count in object identities map - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data[:, :, 0, :], - properties=('label', 'area', 'bbox', 'centroid') - )).rename(columns={ - 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' - }) - df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax()) + argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') + df = ( + pd.DataFrame( + regionprops_table( + acc_obj_ids.data[:, :, 0, 0], + intensity_image=argmax, + properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid') + ) + ) + .rename( + columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', } + ) + ) + df['zi'] = df['intensity_mean'].round().astype('int') # compute expanded bounding boxes h, w, c, nz = acc_raw.shape - df['h'] = df['y1'] - df['y0'] - df['w'] = df['x1'] - df['x0'] ebxy, ebz = expand_box_by df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) @@ -204,11 +176,6 @@ class RoiSet(object): assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0'])) df['slice'] = df.apply( - lambda r: - np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], - axis=1 - ) - df['expanded_slice'] = df.apply( lambda r: np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], axis=1 @@ -218,18 +185,19 @@ class RoiSet(object): np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], axis=1 ) - df['binary_mask'] = df.apply( - lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], + df['mask'] = df.apply( + lambda r: (acc_obj_ids.data == r.label)[r.y0: r.y1, r.x0: r.x1, 0, 0], axis=1 ) return df + @staticmethod def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: query_str = 'label > 0' # always true if filters is not None: # parse filters for k, val in filters.dict(exclude_unset=True).items(): - assert k in ('area') + assert k in ('area', 'solidity') vmin = val['min'] vmax = val['max'] assert vmin >= 0 @@ -258,18 +226,18 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected - def get_patches_acc(self, channel=None, **kwargs): # padded, un-annotated 2d patches + def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches if channel: - patches_df = self.get_patches(white_channel=channel, **kwargs) + patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) else: - patches_df = self.get_patches(**kwargs) - return PatchStack(list(patches_df.patch)) + patches_df = self.get_patches(pad_to=pad_to) + patches = list(patches_df['patch']) + return PatchStack(patches) - def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> Path: + def export_annotated_zstack(self, where, prefix='zstack', **kwargs): annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) - fp = where / (prefix + '.tif') - write_accessor_data_to_file(fp, annotated) - return fp + success = write_accessor_data_to_file(where / (prefix + '.tif'), annotated) + return {'location': where.__str__(), 'filename': prefix + '.tif'} def get_zmask(self, mask_type='boxes'): """ @@ -303,7 +271,7 @@ class RoiSet(object): elif mask_type == 'boxes': for roi in self: - zi_st[roi.slice] = True + zi_st[roi.relative_slice] = 1 return zi_st @@ -311,9 +279,9 @@ class RoiSet(object): def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): # do this on a patch basis, i.e. only one object per frame - obmap_patches = object_classification_model.label_patch_stack( - self.get_patches_acc(channel=channel, expaned=False, pad_to=None), - self.get_patch_masks_acc(expanded=False, pad_to=None) + obmap_patches = object_classification_model.label_instance_class( + self.get_raw_patches(channel=channel), + self.get_patch_masks() ) om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype) @@ -331,29 +299,20 @@ class RoiSet(object): om[self.acc_obj_ids.data == roi.label] = oc self.object_class_maps[name] = InMemoryDataAccessor(om) - def export_dataframe(self, csv_path: Path): - csv_path.parent.mkdir(parents=True, exist_ok=True) - self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1).to_csv(csv_path, index=False) - return csv_path - - - def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> list: - patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded) + def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list: + patches_acc = self.get_patch_masks(pad_to=pad_to) exported = [] - def _export_patch_mask(roi): - patch = InMemoryDataAccessor(roi.patch_mask) + for i, roi in enumerate(self): # assumes index of patches_acc is same as dataframe + patch = patches_acc.iat_yxcz(i) ext = 'png' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' write_accessor_data_to_file(where / fname, patch) exported.append(fname) - - for roi in patches_df.itertuples(): # just used for label info - _export_patch_mask(roi) return exported - def export_patches(self, where: Path, prefix='patch', **kwargs) -> list: + def export_patches(self, where: Path, prefix='patch', **kwargs): make_3d = kwargs.get('make_3d', False) patches_df = self.get_patches(**kwargs) @@ -368,7 +327,11 @@ class RoiSet(object): else: write_accessor_data_to_file(where / fname, patch) - exported.append(where / fname) + exported.append({ + 'df_index': roi.Index, + 'patch_filename': fname, + 'location': where.__str__(), + }) exported = [] for roi in patches_df.itertuples(): # just used for label info @@ -376,36 +339,28 @@ class RoiSet(object): return exported - def get_patch_masks(self, pad_to: int = None, expanded: bool = False) -> pd.DataFrame: - def _make_patch_mask(roi): - if expanded: - patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') - patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255 - else: - patch = np.zeros((roi.y1 - roi.y0, roi.x1 - roi.x0, 1, 1), dtype='uint8') - patch[:, :, 0, 0] = roi.binary_mask * 255 + def get_patch_masks(self, pad_to: int = 256) -> PatchStack: + patches = [] + for roi in self: + patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') + patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 if pad_to: patch = pad(patch, pad_to) - return patch - dfe = self._df.copy() - dfe['patch_mask'] = dfe.apply(lambda r: _make_patch_mask(r), axis=1) - return dfe + patches.append(patch) + return PatchStack(patches) - def get_patch_masks_acc(self, **kwargs) -> PatchStack: - return PatchStack(list(self.get_patch_masks(**kwargs).patch_mask)) def get_patches( self, rescale_clip: float = 0.0, - pad_to: int = None, + pad_to: int = 256, make_3d: bool = False, focus_metric: str = None, rgb_overlay_channels: list = None, rgb_overlay_weights: list = [1.0, 1.0, 1.0], white_channel: int = None, - expanded=False, **kwargs ) -> pd.DataFrame: @@ -452,17 +407,12 @@ class RoiSet(object): stack = raw.data def _make_patch(roi): - if expanded: - patch3d = stack[roi.expanded_slice] - subpatch = patch3d[roi.relative_slice] - else: - patch3d = stack[roi.slice] - subpatch = patch3d - + patch3d = stack[roi.slice] ph, pw, pc, pz = patch3d.shape + subpatch = patch3d[roi.relative_slice] # make a 3d patch - if make_3d or not expanded: + if make_3d: patch = patch3d # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately @@ -483,16 +433,10 @@ class RoiSet(object): assert len(patch.shape) == 4 - mask = np.zeros(patch3d.shape[0:2], dtype=bool) - if expanded: - mask[roi.relative_slice[0:2]] = roi.binary_mask - else: - mask = roi.binary_mask - if rescale_clip is not None: patch = rescale(patch, rescale_clip) - if kwargs.get('draw_bounding_box') is True and expanded: + if kwargs.get('draw_bounding_box') is True: bci = kwargs.get('bounding_box_channel', 0) assert bci < 3 if bci > 0: @@ -506,15 +450,15 @@ class RoiSet(object): if kwargs.get('draw_mask'): mci = kwargs.get('mask_channel', 0) - # mask = np.zeros(patch.shape[0:2], dtype=bool) - # mask[roi.relative_slice[0:2]] = roi.binary_mask + mask = np.zeros(patch.shape[0:2], dtype=bool) + mask[roi.relative_slice[0:2]] = roi.mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] if kwargs.get('draw_contour'): mci = kwargs.get('contour_channel', 0) - # mask = np.zeros(patch.shape[0:2], dtype=bool) - # mask[roi.relative_slice[0:2]] = roi.binary_mask + mask = np.zeros(patch.shape[0:2], dtype=bool) + mask[roi.relative_slice[0:2]] = roi.mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = draw_contours_on_patch( @@ -522,12 +466,12 @@ class RoiSet(object): find_contours(mask) ) - if pad_to and expanded: + if pad_to: patch = pad(patch, pad_to) return patch - dfe = self._df.copy() - dfe['patch'] = dfe.apply(lambda r: _make_patch(r), axis=1) + dfe = self._df + dfe['patch'] = self._df.apply(lambda r: _make_patch(r), axis=1) return dfe def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict: @@ -537,15 +481,12 @@ class RoiSet(object): :param channel: color channel of products to export :param prefix: prefix of the name of each product's file or subfolder :param params: RoiSetExportParams object describing which products to export and with which parameters - :return: nested dict of Path objects describing the location of export products + :return: dict of Path objects describing the location of single-file export products """ record = {} if not self.count: return - - # export dataframe and patch masks - record = self.serialize(where, prefix=prefix) - + raw_ch = self.acc_raw.get_one_channel_data(channel) for k in params.dict().keys(): subdir = where / k pr = prefix @@ -553,11 +494,11 @@ class RoiSet(object): if kp is None: continue if k == 'patches_3d': - record[k] = self.export_patches( - subdir, white_channel=channel, prefix=pr, make_3d=True, expanded=True, **kp + files = self.export_patches( + subdir, white_channel=channel, prefix=pr, make_3d=True, **kp ) if k == 'annotated_patches_2d': - record[k] = self.export_patches( + files = self.export_patches( subdir, prefix=pr, make_3d=False, white_channel=channel, bounding_box_channel=1, bounding_box_linewidth=2, **kp, ) @@ -568,53 +509,22 @@ class RoiSet(object): df_patches = pd.DataFrame(files) self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1) - record[k] = files + if k == 'patch_masks': + self.export_patch_masks(subdir, prefix=pr, **kp) if k == 'annotated_zstacks': - record[k] = self.export_annotated_zstack(subdir, prefix=pr, **kp) + self.export_annotated_zstack(subdir, prefix=pr, **kp) if k == 'object_classes': for kc, acc in self.object_class_maps.items(): fp = subdir / kc / (pr + '.tif') write_accessor_data_to_file(fp, acc) record[f'{k}_{kc}'] = fp - - return record - - def serialize(self, where: Path, prefix='') -> dict: - """ - Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks - :param where: path of directory in which to write files - :param prefix: (optional) prefix - :return: nested dict of Path objects describing the locations of export products - """ - record = {} - record['dataframe'] = self.export_dataframe(where / 'dataframe' / (prefix + '.csv')) - record['tight_patch_masks'] = self.export_patch_masks( - where / 'tight_patch_masks', - prefix=prefix, - pad_to=None, - expanded=False - ) + if k == 'dataframe': + dfpa = subdir / (pr + '.csv') + dfpa.parent.mkdir(parents=True, exist_ok=True) + self._df.to_csv(dfpa, index=False) + record[k] = dfpa return record - @staticmethod - def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''): - df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] - - id_mask = np.zeros((*acc_raw.hw, 1, acc_raw.nz), dtype='uint16') - def _label_obj(r): - sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1] - ext = 'png' - fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' - try: - ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname) - bool_mask = ma_acc.data / np.iinfo(ma_acc.data.dtype).max - id_mask[sl] = r.label * bool_mask - except Exception as e: - raise DeserializeRoiSet(e) - - df.apply(_label_obj, axis=1) - return RoiSet(acc_raw, InMemoryDataAccessor(id_mask)) - def project_stack_from_focal_points( xx: np.ndarray, @@ -663,9 +573,3 @@ def project_stack_from_focal_points( ) - -class Error(Exception): - pass - -class DeserializeRoiSet(Error): - pass \ No newline at end of file diff --git a/model_server/base/session.py b/model_server/base/session.py index 3d5b789d..1b91b12a 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -1,82 +1,36 @@ -import logging +import json import os from pathlib import Path from time import strftime, localtime from typing import Dict -import pandas as pd - import model_server.conf.defaults from model_server.base.models import Model +from model_server.base.workflows import WorkflowRunRecord -logger = logging.getLogger(__name__) - - -class Singleton(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] - -class CsvTable(object): - def __init__(self, fpath: Path): - self.path = fpath - self.empty = True - - def append(self, coords: dict, data: pd.DataFrame) -> bool: - assert isinstance(data, pd.DataFrame) - for c in reversed(coords.keys()): - data.insert(0, c, coords[c]) - if self.empty: - data.to_csv(self.path, index=False, mode='w', header=True) - else: - data.to_csv(self.path, index=False, mode='a', header=False) - self.empty = False - return True +def create_manifest_json(): + pass -class Session(object, metaclass=Singleton): +class Session(object): """ Singleton class for a server session that persists data between API calls """ - log_format = '%(asctime)s - %(levelname)s - %(message)s' + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(Session, cls).__new__(cls) + return cls.instance def __init__(self, root: str = None): print('Initializing session') self.models = {} # model_id : model object + self.manifest = [] # paths to data as well as other metadata from each inference run self.paths = self.make_paths(root) - - self.logfile = self.paths['logs'] / f'session.log' - logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format) - - self.log_info('Initialized session') - self.tables = {} - - def write_to_table(self, name: str, coords: dict, data: pd.DataFrame): - """ - Write data to a named data table, initializing if it does not yet exist. - :param name: name of the table to persist through session - :param coords: dictionary of coordinates to associate with all rows in this method call - :param data: DataFrame containing data - :return: True if successful - """ - try: - if name in self.tables.keys(): - table = self.tables.get(name) - else: - table = CsvTable(self.paths['tables'] / (name + '.csv')) - self.tables[name] = table - except Exception: - raise CouldNotCreateTable(f'Unable to create table named {name}') - - try: - table.append(coords, data) - return True - except Exception: - raise CouldNotAppendToTable(f'Unable to append data to table named {name}') + self.session_log = self.paths['logs'] / f'session.log' + self.log_event('Initialized session') + self.manifest_json = self.paths['logs'] / f'manifest.json' + open(self.manifest_json, 'w').close() # instantiate empty json file def get_paths(self): return self.paths @@ -101,7 +55,7 @@ class Session(object, metaclass=Singleton): root_path = Path(root) sid = Session.create_session_id(root_path) paths = {'root': root_path} - for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']: + for pk in ['inbound_images', 'outbound_images', 'logs']: pa = root_path / sid / model_server.conf.defaults.subdirectories[pk] paths[pk] = pa try: @@ -121,23 +75,22 @@ class Session(object, metaclass=Singleton): idx += 1 return f'{yyyymmdd}-{idx:04d}' - def get_log_data(self) -> list: - log = [] - with open(self.logfile, 'r') as fh: - for line in fh: - k = ['datatime', 'level', 'message'] - v = line.strip().split(' - ')[0:3] - log.insert(0, dict(zip(k, v))) - return log - - def log_info(self, msg): - logger.info(msg) + def log_event(self, event: str): + """ + Write an event string to this session's log file. + """ + timestamp = strftime('%m/%d/%Y, %H:%M:%S', localtime()) + with open(self.session_log, 'a') as fh: + fh.write(f'{timestamp} -- {event}\n') - def log_warning(self, msg): - logger.warning(msg) + def record_workflow_run(self, record: WorkflowRunRecord or None): + """ + Append a JSON describing inference data to this session's manifest + """ + self.log_event(f'Ran model {record.model_id} on {record.input_filepath} to infer {record.output_filepath}') + with open(self.manifest_json, 'w+') as fh: + json.dump(record.dict(), fh) - def log_error(self, msg): - logger.error(msg) def load_model(self, ModelClass: Model, params: Dict[str, str] = None) -> dict: """ @@ -161,7 +114,7 @@ class Session(object, metaclass=Singleton): 'object': mi, 'params': params } - self.log_info(f'Loaded model {key}') + self.log_event(f'Loaded model {key}') return key def describe_loaded_models(self) -> dict: @@ -204,11 +157,5 @@ class CouldNotInstantiateModelError(Error): class CouldNotCreateDirectory(Error): pass -class CouldNotCreateTable(Error): - pass - -class CouldNotAppendToTable(Error): - pass - class InvalidPathError(Error): pass \ No newline at end of file diff --git a/model_server/conf/defaults.py b/model_server/conf/defaults.py index bdf7cfd0..55b114f2 100644 --- a/model_server/conf/defaults.py +++ b/model_server/conf/defaults.py @@ -6,7 +6,6 @@ subdirectories = { 'logs': 'logs', 'inbound_images': 'images/inbound', 'outbound_images': 'images/outbound', - 'tables': 'tables', } server_conf = { diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 97a13aaf..3d07931e 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -67,7 +67,6 @@ roiset_test_data = { 'c': 5, 'z': 7, 'mask_path': root / 'zmask-test-stack-mask.tif', - 'mask_path_3d': root / 'zmask-test-stack-mask-3d.tif', }, 'pipeline_params': { 'segmentation_channel': 0, diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 7229b43e..de25566a 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -1,4 +1,3 @@ -import json import os from pathlib import Path @@ -13,17 +12,8 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat class IlastikModel(Model): - def __init__(self, params, autoload=True, enforce_embedded=True): - """ - Base class for models that run via ilastik shell API - :param params: - project_file: path to ilastik project file - :param autoload: automatically load model into memory if true - :param enforce_embedded: - raise an error if all input data are not embedded in the project file, i.e. on the filesystem - """ + def __init__(self, params, autoload=True): self.project_file = Path(params['project_file']) - self.enforce_embedded = enforce_embedded params['project_file'] = self.project_file.__str__() if self.project_file.is_absolute(): pap = self.project_file @@ -52,15 +42,6 @@ class IlastikModel(Model): args.project = self.project_file_abspath.__str__() shell = app.main(args, init_logging=False) - # validate if inputs are embedded in project file - input_groups = shell.projectManager.currentProjectFile['Input Data']['infos'] - lanes = input_groups.keys() - for ll in lanes: - input_types = input_groups[ll] - for tt in input_types: - ds_loc = input_groups[ll][tt].get('location', False) - if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem': - raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem') if not isinstance(shell.workflow, self.get_workflow()): raise ParameterExpectedError( f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' @@ -74,35 +55,12 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' operations = ['segment', ] - @property - def model_shape_dict(self): - raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] - ax = raw_info['axistags'][()] - ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']] - shape = raw_info['shape'][()] - dd = dict(zip(ax_keys, shape)) - for ci in 'TCZ': - if ci not in dd.keys(): - dd[ci] = 1 - return dd - - @property - def model_chroma(self): - return self.model_shape_dict['C'] - - @property - def model_3d(self): - return self.model_shape_dict['Z'] > 1 - @staticmethod def get_workflow(): from ilastik.workflows import PixelClassificationWorkflow return PixelClassificationWorkflow def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict): - if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): - raise IlastikInputShapeError() - tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') dsi = [ { @@ -129,33 +87,22 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): model_id = 'ilastik_object_classification_from_segmentation' - @staticmethod - def _make_8bit_mask(nda): - if nda.dtype == 'bool': - return 255 * nda.astype('uint8') - else: - return nda - @staticmethod def get_workflow(): from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') assert segmentation_img.is_mask() - if isinstance(input_img, PatchStack): - assert isinstance(segmentation_img, PatchStack) - tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx') + if segmentation_img.dtype == 'bool': + seg = 255 * segmentation_img.data.astype('uint8') tagged_seg_data = vigra.taggedView( - self._make_8bit_mask(segmentation_img.pczyx), - 'tczyx' - ) - else: - tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_seg_data = vigra.taggedView( - self._make_8bit_mask(segmentation_img.data), + 255 * segmentation_img.data.astype('uint8'), 'yxcz' ) + else: + tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') dsi = [ { @@ -168,21 +115,12 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment assert len(obmaps) == 1, 'ilastik generated more than one object map' - - if isinstance(input_img, PatchStack): - pyxcz = np.moveaxis( - obmaps[0], - [0, 1, 2, 3, 4], - [0, 4, 1, 2, 3] - ) - return PatchStack(data=pyxcz), {'success': True} - else: - yxcz = np.moveaxis( - obmaps[0], - [1, 2, 3, 0], - [0, 1, 2, 3] - ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=yxcz), {'success': True} def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) @@ -234,40 +172,48 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag """ if not img.shape == pxmap.shape: raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') + # TODO: check that pxmap is in-range pxch = kwargs.get('pixel_classification_channel', 0) - pxtr = kwargs.get('pixel_classification_threshold', 0.5) + pxtr = kwargs('pixel_classification_threshold', 0.5) mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) + # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) obmap, _ = self.infer(img, mask) return obmap - def make_instance_segmentation_model(self, px_ch: int): - """ - Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a - second input. - :param px_ch: channel of pixel probability map to use - :return: - InstanceSegmentationModel object - """ - class _Mod(self.__class__, InstanceSegmentationModel): - def label_instance_class( - self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs - ) -> GenericImageDataAccessor: - if mask.dtype == 'bool': - norm_mask = 1.0 * mask.data - else: - norm_mask = mask.data / np.iinfo(mask.dtype).max - norm_mask_acc = InMemoryDataAccessor(norm_mask.astype('float32')) - return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch) - return _Mod(params={'project_file': self.project_file}) +class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): + """ + Wrap ilastik object classification for inputs comprising single-object series of raw images and binary + segmentation masks. + """ + + def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): + assert segmentation_acc.is_mask() + if not input_acc.chroma == 1: + raise InvalidInputImageError('Object classifier expects only monochrome patches') + if not input_acc.nz == 1: + raise InvalidInputImageError('Object classifier expects only 2d patches') + + tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') + tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') + + dsi = [ + { + 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), + 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), + } + ] + obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] -class Error(Exception): - pass + assert len(obmaps) == 1, 'ilastik generated more than one object map' -class IlastikInputEmbedding(Error): - pass + # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C + assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) -class IlastikInputShapeError(Error): - """Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape""" - pass \ No newline at end of file + return PatchStack(data=pyxcz), {'success': True} \ No newline at end of file diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 151e3791..3411ea7c 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -25,11 +25,9 @@ def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplica if not duplicate: existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) if existing_model_id is not None: - session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') return {'model_id': existing_model_id} try: result = session.load_model(model_class, {'project_file': project_file}) - session.log_info(f'Loaded ilastik model {result} from {project_file}') except (FileNotFoundError, ParameterExpectedError): raise HTTPException( status_code=404, @@ -62,7 +60,6 @@ def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: st channel=channel, mip=mip, ) - session.log_info(f'Completed pixel and object classification of {input_filename}') except AssertionError: raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') return record \ No newline at end of file diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index d042e571..32dd1376 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -11,9 +11,6 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels from tests.test_api import TestServerBaseClass -def _random_int(*args): - return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') - class TestIlastikPixelClassification(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) @@ -86,40 +83,6 @@ class TestIlastikPixelClassification(unittest.TestCase): self.mono_image = mono_image self.mask = mask - def test_pixel_classifier_enforces_input_shape(self): - model = ilm.IlastikPixelClassifierModel( - {'project_file': ilastik_classifiers['px']} - ) - self.assertEqual(model.model_chroma, 1) - self.assertEqual(model.model_3d, False) - - # correct data - self.assertIsInstance( - model.label_pixel_class( - InMemoryDataAccessor( - _random_int(512, 256, 1, 1) - ) - ), - InMemoryDataAccessor - ) - - # raise except with input of multiple channels - with self.assertRaises(ilm.IlastikInputShapeError): - mask = model.label_pixel_class( - InMemoryDataAccessor( - _random_int(512, 256, 3, 1) - ) - ) - - # raise except with input of multiple channels - with self.assertRaises(ilm.IlastikInputShapeError): - mask = model.label_pixel_class( - InMemoryDataAccessor( - _random_int(512, 256, 1, 15) - ) - ) - - def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] @@ -134,24 +97,7 @@ class TestIlastikPixelClassification(unittest.TestCase): objmap, ) ) - self.assertEqual(objmap.data.max(), 2) - - def test_make_seg_obj_model_from_pxmap_obj(self): - self.test_run_pixel_classifier() - fp = czifile['path'] - pxmap_model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( - {'project_file': ilastik_classifiers['pxmap_to_obj']} - ) - seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0) - objmap = seg_model.label_instance_class(self.mono_image, self.mask) - - self.assertTrue( - write_accessor_data_to_file( - output_path / f'obmap_seg_from_pxmap_{fp.stem}.tif', - objmap, - ) - ) - self.assertEqual(objmap.data.max(), 2) + self.assertEqual(objmap.data.max(), 3) def test_run_object_classifier_from_segmentation(self): self.test_run_pixel_classifier() @@ -167,7 +113,7 @@ class TestIlastikPixelClassification(unittest.TestCase): objmap, ) ) - self.assertEqual(objmap.data.max(), 2) + self.assertEqual(objmap.data.max(), 3) def test_ilastik_pixel_classification_as_workflow(self): result = classify_pixels( @@ -222,7 +168,7 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) def test_no_duplicate_model_with_different_path_formats(self): - self._get('session/restart') + self._get('restart') resp_list_1 = self._get('models').json() self.assertEqual(len(resp_list_1), 0) ilp = ilastik_classifiers['px'] @@ -323,18 +269,16 @@ class TestIlastikObjectClassification(unittest.TestCase): ) ) - self.object_classifier = ilm.IlastikObjectClassifierFromSegmentationModel( + self.object_classifier = ilm.PatchStackObjectClassifier( params={'project_file': ilastik_classifiers['seg_to_obj']} ) - def test_classify_patches(self): - raw_patches = self.roiset.get_patches_acc() - patch_masks = self.roiset.get_patch_masks_acc() - res_patches = self.object_classifier.label_instance_class(raw_patches, patch_masks) + raw_patches = self.roiset.get_raw_patches() + patch_masks = self.roiset.get_patch_masks() + res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks) self.assertEqual(res_patches.count, self.roiset.count) - res_patches.export_pyxcz(output_path / 'res_patches.tif') for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch - la, ct = np.unique(res_patches.iat(pi).data, return_counts=True) - self.assertEqual(np.sum(ct > 1), 2) # exclude single-pixel anomaly - self.assertEqual(la[0], 0) + unique = np.unique(res_patches.iat(pi).data) + self.assertEqual(len(unique), 2) + self.assertEqual(unique[0], 0) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index e84788d2..bc5b4065 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -7,9 +7,6 @@ from model_server.base.accessors import PatchStack, make_patch_stack_from_file, from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor -def _random_int(*args): - return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') - class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: @@ -43,7 +40,7 @@ class TestCziImageFileAccess(unittest.TestCase): nc = 4 nz = 11 c = 3 - cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) + cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) sc = cf.get_one_channel_data(c) self.assertEqual(sc.shape, (h, w, 1, nz)) @@ -73,7 +70,7 @@ class TestCziImageFileAccess(unittest.TestCase): def test_conform_data_shorter_than_xycz(self): h = 256 w = 512 - data = _random_int(h, w, 1) + data = np.random.rand(h, w, 1) acc = InMemoryDataAccessor(data) self.assertEqual( InMemoryDataAccessor.conform_data(data).shape, @@ -85,7 +82,7 @@ class TestCziImageFileAccess(unittest.TestCase): ) def test_conform_data_longer_than_xycz(self): - data = _random_int(256, 512, 12, 8, 3) + data = np.random.rand(256, 512, 12, 8, 3) with self.assertRaises(DataShapeError): acc = InMemoryDataAccessor(data) @@ -96,7 +93,7 @@ class TestCziImageFileAccess(unittest.TestCase): c = 3 nz = 10 - yxcz = _random_int(h, w, c, nz) + yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8') acc = InMemoryDataAccessor(yxcz) fp = output_path / f'rand3d.tif' self.assertTrue( @@ -141,7 +138,7 @@ class TestPatchStackAccessor(unittest.TestCase): w = 256 h = 512 n = 4 - acc = PatchStack(_random_int(n, h, w, 1, 1)) + acc = PatchStack(np.random.rand(n, h, w, 1, 1)) self.assertEqual(acc.count, n) self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) @@ -150,7 +147,7 @@ class TestPatchStackAccessor(unittest.TestCase): w = 256 h = 512 n = 4 - acc = PatchStack([_random_int(h, w, 1, 1) for _ in range(0, n)]) + acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)]) self.assertEqual(acc.count, n) self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) @@ -179,8 +176,8 @@ class TestPatchStackAccessor(unittest.TestCase): nz = 5 n = 4 - patches = [_random_int(h, w, c, nz) for _ in range(0, n)] - patches.append(_random_int(h, 2 * w, c, nz)) + patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)] + patches.append(np.random.rand(h, 2 * w, c, nz)) acc = PatchStack(patches) self.assertEqual(acc.count, n + 1) self.assertEqual(acc.hw, (h, 2 * w)) @@ -194,30 +191,7 @@ class TestPatchStackAccessor(unittest.TestCase): n = 4 nz = 15 nc = 2 - acc = PatchStack(_random_int(n, h, w, nc, nz)) + acc = PatchStack(np.random.rand(n, h, w, nc, nz)) self.assertEqual(acc.count, n) self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w)) self.assertEqual(acc.hw, (h, w)) - return acc - - def test_get_one_channel(self): - acc = self.test_pczyx() - mono = acc.get_one_channel_data(channel=1) - for a in 'PXYZ': - self.assertEqual(mono.shape_dict[a], acc.shape_dict[a]) - self.assertEqual(mono.shape_dict['C'], 1) - - def test_get_one_channel_mip(self): - acc = self.test_pczyx() - mono_mip = acc.get_one_channel_data(channel=1, mip=True) - for a in 'PXY': - self.assertEqual(mono_mip.shape_dict[a], acc.shape_dict[a]) - for a in 'CZ': - self.assertEqual(mono_mip.shape_dict[a], 1) - - def test_export_pczyx_patch_hyperstack(self): - acc = self.test_pczyx() - fp = output_path / 'patch_hyperstack.tif' - acc.export_pyxcz(fp) - acc2 = make_patch_stack_from_file(fp) - self.assertEqual(acc.shape, acc2.shape) \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index aa302338..1b201440 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -136,7 +136,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp_list_0.status_code, 200) rj0 = resp_list_0.json() self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}') - resp_restart = self._get('session/restart') + resp_restart = self._get('restart') resp_list_1 = self._get('models') rj1 = resp_list_1.json() self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}') @@ -171,9 +171,4 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ) self.assertEqual(resp_change.status_code, 200) resp_check = self._get('paths') - self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images']) - - def test_get_logs(self): - resp = self._get('session/logs') - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.json()[0]['message'], 'Initialized session') \ No newline at end of file + self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images']) \ No newline at end of file diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 0d507b3f..3040e6d2 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -1,12 +1,8 @@ -import os -import re import unittest import numpy as np from pathlib import Path -import pandas as pd - from model_server.conf.testing import output_path, roiset_test_data from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams @@ -33,36 +29,31 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): params=RoiSetMetaParams( mask_type=mask_type, filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), - expand_box_by=(128, 2) + expand_box_by=(64, 2) ) ) return roiset def test_roi_mask_shape(self, **kwargs): roiset = self._make_roi_set(**kwargs) - - # all masks' bounding boxes are at least as big as ROI area - for roi in roiset.get_df().itertuples(): - self.assertEqual(roi.binary_mask.dtype, 'bool') - sh = roi.binary_mask.shape - self.assertEqual(sh, (roi.h, roi.w)) - self.assertGreaterEqual(sh[0] * sh[1], roi.area) - - def test_roi_zmask(self, **kwargs): - roiset = self._make_roi_set(**kwargs) zmask = roiset.get_zmask() zmask_acc = InMemoryDataAccessor(zmask) self.assertTrue(zmask_acc.is_mask()) # assert dimensionality of zmask - self.assertEqual(zmask_acc.nz, roiset.acc_raw.nz) - self.assertEqual(zmask_acc.chroma, 1) + self.assertGreater(zmask_acc.shape_dict['Z'], 1) + self.assertEqual(zmask_acc.shape_dict['C'], 1) write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc) # mask values are not just all True or all False self.assertTrue(np.any(zmask)) self.assertFalse(np.all(zmask)) + # assert non-trivial meta info in boxes + self.assertGreater(roiset.count, 1) + sh = roiset.get_df().iloc[1]['mask'].shape + ar = roiset.get_df().iloc[1]['area'] + self.assertGreaterEqual(sh[0] * sh[1], ar) def test_roiset_from_non_zstacks(self, **kwargs): acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) @@ -82,75 +73,37 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) - def test_dataframe_and_mask_array_in_iterator(self): - roiset = self._make_roi_set() - for roi in roiset: - ma = roi.binary_mask - self.assertEqual(ma.dtype, 'bool') - self.assertEqual(ma.shape, (roi.h, roi.w)) - def test_rel_slices_are_valid(self): roiset = self._make_roi_set() for roi in roiset: - ebb = roiset.acc_raw.data[roi.expanded_slice] + ebb = roiset.acc_raw.data[roi.slice] self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) rbb = ebb[roi.relative_slice] self.assertEqual(len(rbb.shape), 4) self.assertTrue(np.all([si >= 1 for si in rbb.shape])) - - def test_make_expanded_2d_patches(self): - roiset = self._make_roi_set() - files = roiset.export_patches( - output_path / 'expanded_2d_patches', - draw_bounding_box=True, - expanded=True, - pad_to=256, - ) - df = roiset.get_df() - for f in files: - acc = generate_file_accessor(f) - la = int(re.search(r'la([\d]+)', str(f)).group(1)) - roi_q = df.loc[df.label == la, :] - self.assertEqual(len(roi_q), 1) - self.assertEqual((256, 256), acc.hw) - - def test_make_tight_2d_patches(self): + def test_make_2d_patches(self): roiset = self._make_roi_set() files = roiset.export_patches( - output_path / 'tight_2d_patches', + output_path / '2d_patches', draw_bounding_box=True, - expanded=False ) - df = roiset.get_df() - for f in files: # all exported files are same shape as bounding boxes in RoiSet's datatable - acc = generate_file_accessor(f) - la = int(re.search(r'la([\d]+)', str(f)).group(1)) - roi_q = df.loc[df.label == la, :] - self.assertEqual(len(roi_q), 1) - roi = roi_q.iloc[0] - self.assertEqual((roi.h, roi.w), acc.hw) - - def test_make_expanded_3d_patches(self): + self.assertGreaterEqual(len(files), 1) + + def test_make_3d_patches(self): roiset = self._make_roi_set() files = roiset.export_patches( output_path / '3d_patches', - make_3d=True, - expanded=True - ) + make_3d=True) self.assertGreaterEqual(len(files), 1) - for f in files: - acc = generate_file_accessor(f) - self.assertGreater(acc.nz, 1) - def test_export_annotated_zstack(self): roiset = self._make_roi_set() file = roiset.export_annotated_zstack( output_path / 'annotated_zstack', ) - result = generate_file_accessor(file) + result = generate_file_accessor(Path(file['location']) / file['filename']) self.assertEqual(result.shape, roiset.acc_raw.shape) def test_flatten_image(self): @@ -179,15 +132,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_binary_masks(self): roiset = self._make_roi_set() files = roiset.export_patch_masks(output_path / '2d_mask_patches', ) - - df = roiset.get_df() - for f in files: # all exported files are same shape as bounding boxes in RoiSet's datatable - acc = generate_file_accessor(output_path / '2d_mask_patches' / f) - la = int(re.search(r'la([\d]+)', str(f)).group(1)) - roi_q = df.loc[df.label == la, :] - self.assertEqual(len(roi_q), 1) - roi = roi_q.iloc[0] - self.assertEqual((roi.h, roi.w), acc.hw) + self.assertGreaterEqual(len(files), 1) def test_classify_by(self): roiset = self._make_roi_set() @@ -211,21 +156,17 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_raw_patches_are_correct_shape(self): roiset = self._make_roi_set() - patches = roiset.get_patches_acc() + patches = roiset.get_raw_patches() np, h, w, nc, nz = patches.shape self.assertEqual(np, roiset.count) self.assertEqual(nc, roiset.acc_raw.chroma) - self.assertEqual(nz, 1) def test_patch_masks_are_correct_shape(self): roiset = self._make_roi_set() - df_patch_masks = roiset.get_patch_masks() - for roi in df_patch_masks.itertuples(): - h, w, nc, nz = roi.patch_mask.shape - self.assertEqual(nc, 1) - self.assertEqual(nz, 1) - self.assertEqual(h, roi.h) - self.assertEqual(w, roi.w) + patch_masks = roiset.get_patch_masks() + np, h, w, nc, nz = patch_masks.shape + self.assertEqual(np, roiset.count) + self.assertEqual(nc, 1) class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): @@ -248,10 +189,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa output_path / 'multichannel' / 'mono_2d_patches', white_channel=3, draw_bounding_box=True, - expanded=True, - pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) self.assertEqual(result.chroma, 1) def test_multichannnel_to_mono_2d_patches_rgb_bbox(self): @@ -260,10 +199,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa white_channel=3, draw_bounding_box=True, bounding_box_channel=1, - expanded=True, - pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_bbox(self): @@ -271,28 +208,11 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa output_path / 'multichannel' / 'rgb_2d_patches_bbox', white_channel=4, rgb_overlay_channels=(3, None, None), - draw_mask=False, - draw_bounding_box=True, - bounding_box_channel=1, - rgb_overlay_weights=(0.1, 1.0, 1.0), - expanded=True, - pad_to=256, - ) - result = generate_file_accessor(files[0]) - self.assertEqual(result.chroma, 3) - - def test_multichannnel_to_rgb_2d_patches_mask(self): - files = self.roiset.export_patches( - output_path / 'multichannel' / 'rgb_2d_patches_mask', - white_channel=4, - rgb_overlay_channels=(3, None, None), draw_mask=True, mask_channel=0, - rgb_overlay_weights=(0.1, 1.0, 1.0), - expanded=True, - pad_to=256, + rgb_overlay_weights=(0.1, 1.0, 1.0) ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_contour(self): @@ -301,32 +221,25 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa rgb_overlay_channels=(3, None, None), draw_contour=True, contour_channel=1, - rgb_overlay_weights=(0.1, 1.0, 1.0), - expanded=True, - pad_to=256, + rgb_overlay_weights=(0.1, 1.0, 1.0) ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) self.assertEqual(result.chroma, 3) self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black def test_multichannel_to_multichannel_tif_patches(self): files = self.roiset.export_patches( output_path / 'multichannel' / 'multichannel_tif_patches', - expanded=True, - pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) self.assertEqual(result.chroma, 5) - self.assertEqual(result.nz, 1) def test_multichannel_annotated_zstack(self): file = self.roiset.export_annotated_zstack( output_path / 'multichannel' / 'annotated_zstack', 'test_multichannel_annotated_zstack', - expanded=True, - pad_to=256, ) - result = generate_file_accessor(file) + result = generate_file_accessor(Path(file['location']) / file['filename']) self.assertEqual(result.chroma, self.stack.chroma) self.assertEqual(result.nz, self.stack.nz) @@ -334,104 +247,9 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa file = self.roiset.export_annotated_zstack( output_path / 'annotated_zstack', channel=3, - expanded=True, - pad_to=256, ) - result = generate_file_accessor(file) + result = generate_file_accessor(Path(file['location']) / file['filename']) self.assertEqual(result.hw, self.roiset.acc_raw.hw) self.assertEqual(result.nz, self.roiset.acc_raw.nz) self.assertEqual(result.chroma, 1) -class TestRoiSetFromZmask(unittest.TestCase): - - def setUp(self) -> None: - # set up test raw data and segmentation from file - self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) - self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['segmentation_channel']) - self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d']) - - @staticmethod - def _label_is_2d(id_map, la): # single label's zmask has same counts as its MIP - mask_3d = (id_map == la) - mask_mip = mask_3d.max(axis=-1) - return mask_3d.sum() == mask_mip.sum() - - def test_id_map_connects_z(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True) - labels = np.unique(id_map.data)[1:] - is_2d = all([self._label_is_2d(id_map.data, la) for la in labels]) - self.assertFalse(is_2d) - - def test_id_map_disconnects_z(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) - labels = np.unique(id_map.data)[1:] - is_2d = all([self._label_is_2d(id_map.data, la) for la in labels]) - self.assertTrue(is_2d) - - def test_create_roiset_from_3d_obj_ids(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) - self.assertEqual(self.stack_ch_pa.shape, id_map.shape) - - roiset = RoiSet( - self.stack_ch_pa, - id_map, - params=RoiSetMetaParams(mask_type='contours') - ) - self.assertEqual(roiset.count, id_map.data.max()) - self.assertGreater(len(roiset.get_df()['zi'].unique()), 1) - - def test_create_roiset_from_2d_obj_ids(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False) - self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3]) - self.assertEqual(id_map.nz, 1) - - roiset = RoiSet( - self.stack_ch_pa, - id_map, - params=RoiSetMetaParams(mask_type='contours') - ) - self.assertEqual(roiset.count, id_map.data.max()) - self.assertGreater(len(roiset.get_df()['zi'].unique()), 1) - return roiset - - def test_create_roiset_from_df_and_patch_masks(self): - ref_roiset = self.test_create_roiset_from_2d_obj_ids() - where_ser = output_path / 'serialize' - ref_roiset.serialize(where_ser, prefix='ref') - where_df = where_ser / 'dataframe' / 'ref.csv' - self.assertTrue(where_df.exists()) - df_test = pd.read_csv(where_df) - - # check that patches are correct size - where_patch_masks = where_ser / 'tight_patch_masks' - patch_filenames = [] - for pmf in where_patch_masks.iterdir(): - self.assertTrue(pmf.suffix.upper() == '.PNG') - la = int(re.search(r'la([\d]+)', str(pmf)).group(1)) - roi_q = df_test.loc[df_test.label == la, :] - self.assertEqual(len(roi_q), 1) - roi = roi_q.iloc[0] - m_acc = generate_file_accessor(pmf) - self.assertEqual((roi.h, roi.w), m_acc.hw) - patch_filenames.append(pmf.name) - - # make another RoiSet from just the data table, raw images, and (tight) patch masks - test_roiset = RoiSet.deserialize(self.stack_ch_pa, where_ser, prefix='ref') - self.assertEqual(ref_roiset.get_zmask().shape, test_roiset.get_zmask().shape,) - self.assertTrue((ref_roiset.get_zmask() == test_roiset.get_zmask()).all()) - self.assertTrue(np.all(test_roiset.get_df().label == ref_roiset.get_df().label)) - cols = ['label', 'y1', 'y0', 'x1', 'x0', 'zi'] - self.assertTrue((test_roiset.get_df()[cols] == ref_roiset.get_df()[cols]).all().all()) - - # re-serialize and check that patch masks are the same - where_dser = output_path / 'deserialize' - test_roiset.serialize(where_dser, prefix='test') - for fr in patch_filenames: - pr = (where_ser / 'tight_patch_masks' / fr) - self.assertTrue(pr.exists()) - pt = (where_dser / 'tight_patch_masks' / fr.replace('ref', 'test')) - self.assertTrue(pt.exists()) - r_acc = generate_file_accessor(pr) - t_acc = generate_file_accessor(pt) - self.assertTrue(np.all(r_acc.data == t_acc.data)) - diff --git a/tests/test_session.py b/tests/test_session.py index aafda3c2..9679aad6 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,83 +1,72 @@ -import json -from os.path import exists import pathlib import unittest - from model_server.base.models import DummySemanticSegmentationModel from model_server.base.session import Session -from model_server.base.workflows import WorkflowRunRecord class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.sesh = Session() - - def tearDown(self) -> None: - print('Tearing down...') - Session._instances = {} + pass - def test_session_is_singleton(self): - Session._instances = {} - self.assertEqual(len(Session._instances), 0) - s = Session() - self.assertEqual(len(Session._instances), 1) - self.assertIs(s, Session()) - self.assertEqual(len(Session._instances), 1) + def test_single_session_instance(self): + sesh = Session() + self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object') - def test_session_logfile_is_valid(self): - self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place') + from os.path import exists + self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place') + self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place') def test_changing_session_root_creates_new_directory(self): from model_server.conf.defaults import root from shutil import rmtree - old_paths = self.sesh.get_paths() + sesh = Session() + old_paths = sesh.get_paths() newroot = root / 'subdir' - self.sesh.restart(root=newroot) - new_paths = self.sesh.get_paths() + sesh.restart(root=newroot) + new_paths = sesh.get_paths() for k in old_paths.keys(): self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__())) - - # this is necessary because logger itself is a singleton class - self.tearDown() - self.setUp() rmtree(newroot) self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory') def test_change_session_subdirectory(self): - old_paths = self.sesh.get_paths() + sesh = Session() + old_paths = sesh.get_paths() print(old_paths) - self.sesh.set_data_directory('outbound_images', old_paths['inbound_images']) - self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images']) - - def test_restarting_session_creates_new_logfile(self): - logfile1 = self.sesh.logfile - self.assertTrue(logfile1.exists()) - self.sesh.restart() - logfile2 = self.sesh.logfile - self.assertTrue(logfile2.exists()) + sesh.set_data_directory('outbound_images', old_paths['inbound_images']) + self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images']) + + + def test_restart_session(self): + sesh = Session() + logfile1 = sesh.session_log + sesh.restart() + logfile2 = sesh.session_log self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile') - def test_log_warning(self): - msg = 'A test warning' - self.sesh.log_info(msg) - with open(self.sesh.logfile, 'r') as fh: - log = fh.read() - self.assertTrue(msg in log) - - def test_get_logs(self): - self.sesh.log_info('Info example 1') - self.sesh.log_warning('Example warning') - self.sesh.log_info('Info example 2') - logs = self.sesh.get_log_data() - self.assertEqual(len(logs), 4) - self.assertEqual(logs[1]['level'], 'WARNING') - self.assertEqual(logs[-1]['message'], 'Initialized session') + def test_session_records_workflow(self): + import json + from model_server.base.workflows import WorkflowRunRecord + sesh = Session() + di = WorkflowRunRecord( + model_id='test_model', + input_filepath='/test/input/directory', + output_filepath='/test/output/fi.le', + success=True, + timer_results={'start': 0.123}, + ) + sesh.record_workflow_run(di) + with open(sesh.manifest_json, 'r') as fh: + do = json.load(fh) + self.assertEqual(di.dict(), do, 'Manifest record is not correct') + def test_session_loads_model(self): + sesh = Session() MC = DummySemanticSegmentationModel - success = self.sesh.load_model(MC) + success = sesh.load_model(MC) self.assertTrue(success) - loaded_models = self.sesh.describe_loaded_models() + loaded_models = sesh.describe_loaded_models() self.assertTrue( (MC.__name__ + '_00') in loaded_models.keys() ) @@ -87,56 +76,46 @@ class TestGetSessionObject(unittest.TestCase): ) def test_session_loads_second_instance_of_same_model(self): + sesh = Session() MC = DummySemanticSegmentationModel - self.sesh.load_model(MC) - self.sesh.load_model(MC) - self.assertIn(MC.__name__ + '_00', self.sesh.models.keys()) - self.assertIn(MC.__name__ + '_01', self.sesh.models.keys()) + sesh.load_model(MC) + sesh.load_model(MC) + self.assertIn(MC.__name__ + '_00', sesh.models.keys()) + self.assertIn(MC.__name__ + '_01', sesh.models.keys()) + def test_session_loads_model_with_params(self): + sesh = Session() MC = DummySemanticSegmentationModel p1 = {'p1': 'abc'} - success = self.sesh.load_model(MC, params=p1) + success = sesh.load_model(MC, params=p1) self.assertTrue(success) - loaded_models = self.sesh.describe_loaded_models() + loaded_models = sesh.describe_loaded_models() mid = MC.__name__ + '_00' self.assertEqual(loaded_models[mid]['params'], p1) # load a second model and confirm that the first is locatable by its param entry p2 = {'p2': 'def'} - self.sesh.load_model(MC, params=p2) - find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc') + sesh.load_model(MC, params=p2) + find_mid = sesh.find_param_in_loaded_models('p1', 'abc') self.assertEqual(mid, find_mid) - self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1) + self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1) def test_session_finds_existing_model_with_different_path_formats(self): + sesh = Session() MC = DummySemanticSegmentationModel p1 = {'path': 'c:\\windows\\dummy.pa'} p2 = {'path': 'c:/windows/dummy.pa'} - mid = self.sesh.load_model(MC, params=p1) + mid = sesh.load_model(MC, params=p1) assert pathlib.Path(p1['path']) == pathlib.Path(p2['path']) - find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) + find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) self.assertEqual(mid, find_mid) def test_change_output_path(self): - pa = self.sesh.get_paths()['inbound_images'] + import pathlib + sesh = Session() + pa = sesh.get_paths()['inbound_images'] self.assertIsInstance(pa, pathlib.Path) - self.sesh.set_data_directory('outbound_images', pa.__str__()) - self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images']) - self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path) - - def test_make_table(self): - import pandas as pd - data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)] - self.sesh.write_to_table( - 'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4]) - ) - self.assertTrue(self.sesh.tables['test_numbers'].path.exists()) - self.sesh.write_to_table( - 'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8]) - ) - - dfv = pd.read_csv(self.sesh.tables['test_numbers'].path) - self.assertEqual(len(dfv), len(data)) - self.assertEqual(dfv.columns[0], 'X') - self.assertEqual(dfv.columns[1], 'Y') + sesh.set_data_directory('outbound_images', pa.__str__()) + self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images']) + self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path) \ No newline at end of file -- GitLab