Skip to content
Snippets Groups Projects
Commit 1af59120 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Rearranged methods, condensed into roiset package

parent 0ea7a118
No related branches found
No related tags found
No related merge requests found
...@@ -71,4 +71,17 @@ def rescale(nda, clip=0.0): ...@@ -71,4 +71,17 @@ def rescale(nda, clip=0.0):
clip_pct = (100.0 * clip, 100.0 * (1.0 - clip)) clip_pct = (100.0 * clip, 100.0 * (1.0 - clip))
cmin, cmax = np.percentile(nda, clip_pct) cmin, cmax = np.percentile(nda, clip_pct)
rescaled = rescale_intensity(nda, in_range=(cmin, cmax)) rescaled = rescale_intensity(nda, in_range=(cmin, cmax))
return rescaled return rescaled
\ No newline at end of file
def make_rgb(nda):
"""
Convert a YXCZ stack array to RGB, and error if more than three channels
:param nda: np.ndarray (YXCZ dimensions)
:return: np.ndarray of 3-channel stack
"""
h, w, c, nz = nda.shape
assert c <= 3
outdata = np.zeros((h, w, 3, nz), dtype=nda.dtype)
outdata[:, :, 0:c, :] = nda[:, :, :, :]
return outdata
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from skimage.io import imsave
from tifffile import imwrite
from base.process import make_rgb
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor
class MonoPatchStack(InMemoryDataAccessor): class MonoPatchStack(InMemoryDataAccessor):
...@@ -127,6 +130,33 @@ class Multichannel3dPatchStack(InMemoryDataAccessor): ...@@ -127,6 +130,33 @@ class Multichannel3dPatchStack(InMemoryDataAccessor):
def shape_dict(self): def shape_dict(self):
return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
def write_patch_to_file(where, fname, yxcz):
ext = fname.split('.')[-1].upper()
where.mkdir(parents=True, exist_ok=True)
if ext == 'PNG':
assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}'
assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs'
assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs'
if yxcz.shape[2] == 1:
outdata = yxcz[:, :, 0, 0]
elif yxcz.shape[2] == 2: # add a blank blue channel
outdata = make_rgb(yxcz)
else: # preserve RGB order
outdata = yxcz[:, :, :, 0]
imsave(where / fname, outdata, check_contrast=False)
return True
elif ext in ['TIF', 'TIFF']:
zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3])
imwrite(where / fname, zcyx, imagej=True)
return True
else:
raise Exception(f'Unsupported file extension: {ext}')
class Error(Exception): class Error(Exception):
pass pass
...@@ -135,3 +165,5 @@ class InvalidDataForPatchStackError(Error): ...@@ -135,3 +165,5 @@ class InvalidDataForPatchStackError(Error):
class FileNotFoundError(Error): class FileNotFoundError(Error):
pass pass
from math import floor, sqrt
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import moment
from skimage.filters import sobel
from skimage.io import imsave
from skimage.measure import find_contours, shannon_entropy
from tifffile import imwrite
from model_server.extensions.chaeo.accessors import MonoPatchStack
from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch
from model_server.base.accessors import InMemoryDataAccessor
from model_server.base.process import pad, rescale, resample_to_8bit
def _make_rgb(zs):
h, w, c, nz = zs.shape
assert c <= 3
outdata = np.zeros((h, w, 3, nz), dtype=zs.dtype)
outdata[:, :, 0:c, :] = zs[:, :, :, :]
return outdata
def _focus_metrics():
return {
'max_intensity': lambda x: np.max(x),
'stdev': lambda x: np.std(x),
'max_sobel': lambda x: np.max(sobel(x)),
'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)),
'entropy': lambda x: shannon_entropy(x),
'moment': lambda x: moment(x.flatten(), moment=2),
}
def _safe_add(a, g, b):
assert a.dtype == b.dtype
assert a.shape == b.shape
assert g >= 0.0
return np.clip(
a.astype('uint32') + g * b.astype('uint32'),
0,
np.iinfo(a.dtype).max
).astype(a.dtype)
def write_patch_to_file(where, fname, yxcz):
ext = fname.split('.')[-1].upper()
where.mkdir(parents=True, exist_ok=True)
if ext == 'PNG':
assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}'
assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs'
assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs'
if yxcz.shape[2] == 1:
outdata = yxcz[:, :, 0, 0]
elif yxcz.shape[2] == 2: # add a blank blue channel
outdata = _make_rgb(yxcz)
else: # preserve RGB order
outdata = yxcz[:, :, :, 0]
imsave(where / fname, outdata, check_contrast=False)
return True
elif ext in ['TIF', 'TIFF']:
zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3])
imwrite(where / fname, zcyx, imagej=True)
return True
else:
raise Exception(f'Unsupported file extension: {ext}')
def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
patches = []
for roi in roiset:
patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255
if pad_to:
patch = pad(patch, pad_to)
patches.append(patch)
return MonoPatchStack(patches)
def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list:
patches_acc = get_patch_masks(roiset, pad_to=pad_to)
exported = []
for i, roi in enumerate(roiset): # assumes index of patches_acc is same as dataframe
patch = patches_acc.iat_yxcz(i)
ext = 'png'
fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
write_patch_to_file(where, fname, patch)
exported.append(fname)
return exported
def get_roiset_patches(
roiset,
rescale_clip: float = 0.0,
pad_to: int = 256,
make_3d: bool = False,
focus_metric: str = None,
rgb_overlay_channels: list = None,
rgb_overlay_weights: list = [1.0, 1.0, 1.0],
white_channel: int = None,
**kwargs
) -> pd.DataFrame:
# arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
raw = roiset.acc_raw
if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)):
assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None])
assert len(rgb_overlay_channels) == 3
assert len(rgb_overlay_weights) == 3
if white_channel:
assert white_channel < raw.chroma
stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
else:
stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype)
for ii, ci in enumerate(rgb_overlay_channels):
if ci is None:
continue
assert isinstance(ci, int)
assert ci < raw.chroma
stack[:, :, ii, :] = _safe_add(
stack[:, :, ii, :], # either black or grayscale channel
rgb_overlay_weights[ii],
raw.data[:, :, ci, :]
)
else:
if white_channel: # interpret as just a single channel
assert white_channel < raw.chroma
annotate_rgb = False
for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']:
ca = kwargs.get(k)
if ca is None:
continue
assert(ca < raw.chroma)
if ca != white_channel:
annotate_rgb = True
break
if annotate_rgb: # make RGB patches anyway to include annotation color
stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
else: # make monochrome patches
stack = raw.data[:, :, [white_channel], :]
else:
stack = raw.data
def _make_patch(roi):
patch3d = stack[roi.slice]
ph, pw, pc, pz = patch3d.shape
subpatch = patch3d[roi.relative_slice]
# make a 3d patch
if make_3d:
patch = patch3d
# make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
elif focus_metric is not None:
foc = _focus_metrics()[focus_metric]
patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)
for ci in range(0, pc):
me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)]
zif = np.argmax(me)
patch[:, :, ci, 0] = patch3d[:, :, ci, zif]
# make a 2d patch from middle of z-stack
else:
zim = floor(pz / 2)
patch = patch3d[:, :, :, [zim]]
assert len(patch.shape) == 4
if rescale_clip is not None:
patch = rescale(patch, rescale_clip)
if kwargs.get('draw_bounding_box') is True:
bci = kwargs.get('bounding_box_channel', 0)
assert bci < 3
if bci > 0:
patch = _make_rgb(patch)
for zi in range(0, patch.shape[3]):
patch[:, :, bci, zi] = draw_box_on_patch(
patch[:, :, bci, zi],
((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)),
linewidth=kwargs.get('bounding_box_linewidth', 1)
)
if kwargs.get('draw_mask'):
mci = kwargs.get('mask_channel', 0)
mask = np.zeros(patch.shape[0:2], dtype=bool)
mask[roi.relative_slice[0:2]] = roi.mask
for zi in range(0, patch.shape[3]):
patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]
if kwargs.get('draw_contour'):
mci = kwargs.get('contour_channel', 0)
mask = np.zeros(patch.shape[0:2], dtype=bool)
mask[roi.relative_slice[0:2]] = roi.mask
for zi in range(0, patch.shape[3]):
patch[:, :, mci, zi] = draw_contours_on_patch(
patch[:, :, mci, zi],
find_contours(mask)
)
if pad_to:
patch = pad(patch, pad_to)
return patch
dfe = roiset.get_df()
dfe['patch'] = roiset.get_df().apply(lambda r: _make_patch(r), axis=1)
return dfe
def export_patches_from_zstack(
where: Path,
roiset,
prefix='patch',
**kwargs
):
make_3d = kwargs.get('make_3d', False)
patches_df = get_roiset_patches(roiset, **kwargs)
def _export_patch(roi):
patch = InMemoryDataAccessor(roi.patch)
ext = 'tif' if make_3d or patch.chroma > 3 else 'png'
fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
if patch.dtype is np.dtype('uint16'):
write_patch_to_file(where, fname, resample_to_8bit(patch.data))
else:
write_patch_to_file(where, fname, patch)
exported.append({
'df_index': roi.Index,
'patch_filename': fname,
'location': where.__str__(),
})
exported = []
for roi in patches_df.itertuples(): # just used for label info
_export_patch(roi)
return exported
\ No newline at end of file
from math import sqrt, floor
from pathlib import Path
from uuid import uuid4 from uuid import uuid4
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from scipy.stats import moment
from skimage.filters import sobel
from skimage.measure import label, regionprops_table from skimage.measure import label, regionprops_table, shannon_entropy, find_contours
from sklearn.preprocessing import PolynomialFeatures from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression from sklearn.linear_model import LinearRegression
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.models import InstanceSegmentationModel from model_server.base.models import InstanceSegmentationModel
from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb
from model_server.extensions.chaeo.annotators import draw_boxes_on_3d_image from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
from model_server.extensions.chaeo.products import export_patches_from_zstack from model_server.extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams
from extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams from model_server.extensions.chaeo.accessors import write_patch_to_file, MonoPatchStack, Multichannel3dPatchStack
from model_server.extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack
from model_server.extensions.chaeo.process import mask_largest_object from model_server.extensions.chaeo.process import mask_largest_object
from model_server.extensions.chaeo.products import get_roiset_patches, get_patch_masks, export_patch_masks
def get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor:
return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')) return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16'))
def _focus_metrics():
return {
'max_intensity': lambda x: np.max(x),
'stdev': lambda x: np.std(x),
'max_sobel': lambda x: np.max(sobel(x)),
'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)),
'entropy': lambda x: shannon_entropy(x),
'moment': lambda x: moment(x.flatten(), moment=2),
}
def _safe_add(a, g, b):
assert a.dtype == b.dtype
assert a.shape == b.shape
assert g >= 0.0
return np.clip(
a.astype('uint32') + g * b.astype('uint32'),
0,
np.iinfo(a.dtype).max
).astype(a.dtype)
class RoiSet(object): class RoiSet(object):
def __init__( def __init__(
self, self,
acc_obj_ids: GenericImageDataAccessor,
acc_raw: GenericImageDataAccessor, acc_raw: GenericImageDataAccessor,
acc_obj_ids: GenericImageDataAccessor,
params: RoiSetMetaParams = RoiSetMetaParams(), params: RoiSetMetaParams = RoiSetMetaParams(),
): ):
"""
A set of regions of interest, referenced by their positions and contours in the YXCZ space of stack acc_raw.
RoiSet contains their internal state, which may be exported as patches, maps, and other products by export methods.
:param acc_raw: accessor to a generally a multichannel z-stack
:param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity
labels its membership in a connected object
:param params: optional arguments that influence the definition and representation of ROIs
"""
assert acc_obj_ids.chroma == 1 assert acc_obj_ids.chroma == 1
assert acc_obj_ids.nz == 1 assert acc_obj_ids.nz == 1
self.acc_obj_ids = acc_obj_ids self.acc_obj_ids = acc_obj_ids
...@@ -147,17 +181,11 @@ class RoiSet(object): ...@@ -147,17 +181,11 @@ class RoiSet(object):
projected = self.acc_raw.data.max(axis=-1) projected = self.acc_raw.data.max(axis=-1)
return projected return projected
def get_patch_masks(self, **kwargs):
return get_patch_masks(self, **kwargs)
def export_patch_masks(self, where, **kwargs) -> list:
return export_patch_masks(self, where, **kwargs)
def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches
if channel: if channel:
patches_df = get_roiset_patches(self, white_channel=channel, pad_to=pad_to) patches_df = self.get_patches(white_channel=channel, pad_to=pad_to)
else: else:
patches_df = get_roiset_patches(self, pad_to=pad_to) patches_df = self.get_patches(pad_to=pad_to)
patches = list(patches_df['patch']) patches = list(patches_df['patch'])
if channel is not None or self.acc_raw.chroma == 1: if channel is not None or self.acc_raw.chroma == 1:
return MonoPatchStack(patches) return MonoPatchStack(patches)
...@@ -229,6 +257,179 @@ class RoiSet(object): ...@@ -229,6 +257,179 @@ class RoiSet(object):
om[self.acc_obj_ids.data == roi.label] = oc om[self.acc_obj_ids.data == roi.label] = oc
self.object_class_maps[name] = InMemoryDataAccessor(om) self.object_class_maps[name] = InMemoryDataAccessor(om)
def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list:
patches_acc = self.get_patch_masks(pad_to=pad_to)
exported = []
for i, roi in enumerate(self): # assumes index of patches_acc is same as dataframe
patch = patches_acc.iat_yxcz(i)
ext = 'png'
fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
write_patch_to_file(where, fname, patch)
exported.append(fname)
return exported
def export_patches(self, where: Path, prefix='patch', **kwargs):
make_3d = kwargs.get('make_3d', False)
patches_df = self.get_patches(**kwargs)
def _export_patch(roi):
patch = InMemoryDataAccessor(roi.patch)
ext = 'tif' if make_3d or patch.chroma > 3 else 'png'
fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
if patch.dtype is np.dtype('uint16'):
write_patch_to_file(where, fname, resample_to_8bit(patch.data))
else:
write_patch_to_file(where, fname, patch)
exported.append({
'df_index': roi.Index,
'patch_filename': fname,
'location': where.__str__(),
})
exported = []
for roi in patches_df.itertuples(): # just used for label info
_export_patch(roi)
return exported
def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack:
patches = []
for roi in self:
patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255
if pad_to:
patch = pad(patch, pad_to)
patches.append(patch)
return MonoPatchStack(patches)
def get_patches(
self,
rescale_clip: float = 0.0,
pad_to: int = 256,
make_3d: bool = False,
focus_metric: str = None,
rgb_overlay_channels: list = None,
rgb_overlay_weights: list = [1.0, 1.0, 1.0],
white_channel: int = None,
**kwargs
) -> pd.DataFrame:
# arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
raw = self.acc_raw
if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)):
assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None])
assert len(rgb_overlay_channels) == 3
assert len(rgb_overlay_weights) == 3
if white_channel:
assert white_channel < raw.chroma
stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
else:
stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype)
for ii, ci in enumerate(rgb_overlay_channels):
if ci is None:
continue
assert isinstance(ci, int)
assert ci < raw.chroma
stack[:, :, ii, :] = _safe_add(
stack[:, :, ii, :], # either black or grayscale channel
rgb_overlay_weights[ii],
raw.data[:, :, ci, :]
)
else:
if white_channel: # interpret as just a single channel
assert white_channel < raw.chroma
annotate_rgb = False
for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']:
ca = kwargs.get(k)
if ca is None:
continue
assert (ca < raw.chroma)
if ca != white_channel:
annotate_rgb = True
break
if annotate_rgb: # make RGB patches anyway to include annotation color
stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
else: # make monochrome patches
stack = raw.data[:, :, [white_channel], :]
else:
stack = raw.data
def _make_patch(roi):
patch3d = stack[roi.slice]
ph, pw, pc, pz = patch3d.shape
subpatch = patch3d[roi.relative_slice]
# make a 3d patch
if make_3d:
patch = patch3d
# make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
elif focus_metric is not None:
foc = _focus_metrics()[focus_metric]
patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)
for ci in range(0, pc):
me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)]
zif = np.argmax(me)
patch[:, :, ci, 0] = patch3d[:, :, ci, zif]
# make a 2d patch from middle of z-stack
else:
zim = floor(pz / 2)
patch = patch3d[:, :, :, [zim]]
assert len(patch.shape) == 4
if rescale_clip is not None:
patch = rescale(patch, rescale_clip)
if kwargs.get('draw_bounding_box') is True:
bci = kwargs.get('bounding_box_channel', 0)
assert bci < 3
if bci > 0:
patch = make_rgb(patch)
for zi in range(0, patch.shape[3]):
patch[:, :, bci, zi] = draw_box_on_patch(
patch[:, :, bci, zi],
((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)),
linewidth=kwargs.get('bounding_box_linewidth', 1)
)
if kwargs.get('draw_mask'):
mci = kwargs.get('mask_channel', 0)
mask = np.zeros(patch.shape[0:2], dtype=bool)
mask[roi.relative_slice[0:2]] = roi.mask
for zi in range(0, patch.shape[3]):
patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]
if kwargs.get('draw_contour'):
mci = kwargs.get('contour_channel', 0)
mask = np.zeros(patch.shape[0:2], dtype=bool)
mask[roi.relative_slice[0:2]] = roi.mask
for zi in range(0, patch.shape[3]):
patch[:, :, mci, zi] = draw_contours_on_patch(
patch[:, :, mci, zi],
find_contours(mask)
)
if pad_to:
patch = pad(patch, pad_to)
return patch
dfe = self._df
dfe['patch'] = self._df.apply(lambda r: _make_patch(r), axis=1)
return dfe
def run_exports(self, where, channel, prefix, params: RoiSetExportParams): def run_exports(self, where, channel, prefix, params: RoiSetExportParams):
if not self.count: if not self.count:
return return
...@@ -240,17 +441,17 @@ class RoiSet(object): ...@@ -240,17 +441,17 @@ class RoiSet(object):
if kp is None: if kp is None:
continue continue
if k == 'patches_3d': if k == 'patches_3d':
files = export_patches_from_zstack( files = self.export_patches(
subdir, self, white_channel=channel, prefix=pr, make_3d=True, **kp subdir, white_channel=channel, prefix=pr, make_3d=True, **kp
) )
if k == 'annotated_patches_2d': if k == 'annotated_patches_2d':
files = export_patches_from_zstack( files = self.export_patches(
subdir, self, prefix=pr, make_3d=False, white_channel=channel, subdir, prefix=pr, make_3d=False, white_channel=channel,
bounding_box_channel=1, bounding_box_linewidth=2, **kp, bounding_box_channel=1, bounding_box_linewidth=2, **kp,
) )
if k == 'patches_2d': if k == 'patches_2d':
files = export_patches_from_zstack( files = self.export_patches(
subdir, self, white_channel=channel, prefix=pr, make_3d=False, **kp subdir, white_channel=channel, prefix=pr, make_3d=False, **kp
) )
df_patches = pd.DataFrame(files) df_patches = pd.DataFrame(files)
self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
...@@ -313,3 +514,7 @@ def project_stack_from_focal_points( ...@@ -313,3 +514,7 @@ def project_stack_from_focal_points(
), ),
axis=3 axis=3
) )
...@@ -7,9 +7,8 @@ from model_server.conf.testing import output_path ...@@ -7,9 +7,8 @@ from model_server.conf.testing import output_path
from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params
from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
from model_server.extensions.chaeo.products import export_patches_from_zstack
from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack
from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet from model_server.extensions.chaeo.roiset import _get_label_ids, RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.extensions.ilastik.models import IlastikPixelClassifierModel from model_server.extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.base.models import DummyInstanceSegmentationModel from model_server.base.models import DummyInstanceSegmentationModel
...@@ -42,10 +41,10 @@ class BaseTestRoiSetMonoProducts(object): ...@@ -42,10 +41,10 @@ class BaseTestRoiSetMonoProducts(object):
class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def _make_roi_set(self, mask_type='boxes', **kwargs): def _make_roi_set(self, mask_type='boxes', **kwargs):
id_map = get_label_ids(self.seg_mask) id_map = _get_label_ids(self.seg_mask)
roiset = RoiSet( roiset = RoiSet(
id_map,
self.stack_ch_pa, self.stack_ch_pa,
id_map,
params=RoiSetMetaParams( params=RoiSetMetaParams(
mask_type=mask_type, mask_type=mask_type,
filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}), filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
...@@ -78,9 +77,9 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -78,9 +77,9 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_roiset_from_non_zstacks(self, **kwargs): def test_roiset_from_non_zstacks(self, **kwargs):
acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
self.assertEqual(acc_zstack_slice.nz, 1) self.assertEqual(acc_zstack_slice.nz, 1)
id_map = get_label_ids(self.seg_mask) id_map = _get_label_ids(self.seg_mask)
roiset = RoiSet(id_map, acc_zstack_slice, params=RoiSetMetaParams(mask_type='boxes')) roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes'))
zmask = roiset.get_zmask() zmask = roiset.get_zmask()
zmask_acc = InMemoryDataAccessor(zmask) zmask_acc = InMemoryDataAccessor(zmask)
...@@ -105,18 +104,16 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -105,18 +104,16 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_make_2d_patches(self): def test_make_2d_patches(self):
roiset = self._make_roi_set() roiset = self._make_roi_set()
files = export_patches_from_zstack( files = roiset.export_patches(
output_path / '2d_patches', output_path / '2d_patches',
roiset,
draw_bounding_box=True, draw_bounding_box=True,
) )
self.assertGreaterEqual(len(files), 1) self.assertGreaterEqual(len(files), 1)
def test_make_3d_patches(self): def test_make_3d_patches(self):
roiset = self._make_roi_set() roiset = self._make_roi_set()
files = export_patches_from_zstack( files = roiset.export_patches(
output_path / '3d_patches', output_path / '3d_patches',
roiset,
make_3d=True) make_3d=True)
self.assertGreaterEqual(len(files), 1) self.assertGreaterEqual(len(files), 1)
...@@ -129,12 +126,12 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -129,12 +126,12 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertEqual(result.shape, roiset.acc_raw.shape) self.assertEqual(result.shape, roiset.acc_raw.shape)
def test_flatten_image(self): def test_flatten_image(self):
id_map = get_label_ids(self.seg_mask) id_map = _get_label_ids(self.seg_mask)
roiset = RoiSet(id_map, self.stack_ch_pa, params=RoiSetMetaParams(mask_type='boxes')) roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes'))
df = roiset.get_df() df = roiset.get_df()
from model_server.extensions.chaeo.zmask import project_stack_from_focal_points from model_server.extensions.chaeo.roiset import project_stack_from_focal_points
img = project_stack_from_focal_points( img = project_stack_from_focal_points(
df['centroid-0'].to_numpy(), df['centroid-0'].to_numpy(),
...@@ -227,10 +224,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -227,10 +224,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
id_map = get_label_ids(self.seg_mask) id_map = _get_label_ids(self.seg_mask)
self.roiset = RoiSet( self.roiset = RoiSet(
id_map,
self.stack, self.stack,
id_map,
params=RoiSetMetaParams( params=RoiSetMetaParams(
expand_box_by=(128, 2), expand_box_by=(128, 2),
mask_type='boxes', mask_type='boxes',
...@@ -239,9 +236,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -239,9 +236,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
) )
def test_multichannel_to_mono_2d_patches(self): def test_multichannel_to_mono_2d_patches(self):
files = export_patches_from_zstack( files = self.roiset.export_patches(
output_path / 'multichannel' / 'mono_2d_patches', output_path / 'multichannel' / 'mono_2d_patches',
self.roiset,
white_channel=3, white_channel=3,
draw_bounding_box=True, draw_bounding_box=True,
) )
...@@ -249,9 +245,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -249,9 +245,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertEqual(result.chroma, 1) self.assertEqual(result.chroma, 1)
def test_multichannnel_to_mono_2d_patches_rgb_bbox(self): def test_multichannnel_to_mono_2d_patches_rgb_bbox(self):
files = export_patches_from_zstack( files = self.roiset.export_patches(
output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox', output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox',
self.roiset,
white_channel=3, white_channel=3,
draw_bounding_box=True, draw_bounding_box=True,
bounding_box_channel=1, bounding_box_channel=1,
...@@ -260,9 +255,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -260,9 +255,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertEqual(result.chroma, 3) self.assertEqual(result.chroma, 3)
def test_multichannnel_to_rgb_2d_patches_bbox(self): def test_multichannnel_to_rgb_2d_patches_bbox(self):
files = export_patches_from_zstack( files = self.roiset.export_patches(
output_path / 'multichannel' / 'rgb_2d_patches_bbox', output_path / 'multichannel' / 'rgb_2d_patches_bbox',
self.roiset,
white_channel=4, white_channel=4,
rgb_overlay_channels=(3, None, None), rgb_overlay_channels=(3, None, None),
draw_mask=True, draw_mask=True,
...@@ -273,9 +267,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -273,9 +267,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertEqual(result.chroma, 3) self.assertEqual(result.chroma, 3)
def test_multichannnel_to_rgb_2d_patches_contour(self): def test_multichannnel_to_rgb_2d_patches_contour(self):
files = export_patches_from_zstack( files = self.roiset.export_patches(
output_path / 'multichannel' / 'rgb_2d_patches_contour', output_path / 'multichannel' / 'rgb_2d_patches_contour',
self.roiset,
rgb_overlay_channels=(3, None, None), rgb_overlay_channels=(3, None, None),
draw_contour=True, draw_contour=True,
contour_channel=1, contour_channel=1,
...@@ -286,9 +279,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -286,9 +279,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black
def test_multichannel_to_multichannel_tif_patches(self): def test_multichannel_to_multichannel_tif_patches(self):
files = export_patches_from_zstack( files = self.roiset.export_patches(
output_path / 'multichannel' / 'multichannel_tif_patches', output_path / 'multichannel' / 'multichannel_tif_patches',
self.roiset,
) )
result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename']) result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
self.assertEqual(result.chroma, 5) self.assertEqual(result.chroma, 5)
......
...@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split ...@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
from model_server.extensions.chaeo.process import mask_largest_object from model_server.extensions.chaeo.process import mask_largest_object
from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet from model_server.extensions.chaeo.roiset import _get_label_ids, RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
...@@ -48,7 +48,7 @@ def infer_object_map_from_zstack( ...@@ -48,7 +48,7 @@ def infer_object_map_from_zstack(
ti.click('classify_pixels') ti.click('classify_pixels')
# make zmask # make zmask
rois = RoiSet(get_label_ids(mip_mask), stack, params=roi_params) rois = RoiSet(stack, _get_label_ids(mip_mask), params=roi_params)
ti.click('generate_zmasks') ti.click('generate_zmasks')
rois.classify_by( rois.classify_by(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment