Skip to content
Snippets Groups Projects
Commit ef2f74ad authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Added accessor method to get multiple channels

parent a6fbeea4
No related branches found
No related tags found
No related merge requests found
...@@ -331,11 +331,17 @@ class PatchStack(InMemoryDataAccessor): ...@@ -331,11 +331,17 @@ class PatchStack(InMemoryDataAccessor):
tifffile.imwrite(fpath, tzcyx, imagej=True) tifffile.imwrite(fpath, tzcyx, imagej=True)
def get_one_channel_data(self, channel: int, mip: bool = False): def get_one_channel_data(self, channel: int, mip: bool = False):
c = int(channel) return self.get_channels([channel], mip=mip)
def get_channels(self, channels: list, mip: bool = False):
carr = [int(c) for c in channels]
if mip: if mip:
return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :].max(axis=-1, keepdims=True)) return PatchStack(self.pyxcz[:, :, :, carr, :].max(axis=-1, keepdims=True))
else: else:
return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :]) return PatchStack(self.pyxcz[:, :, :, carr, :])
def gch(self, args):
return self.get_channels(args)
@property @property
def shape_dict(self): def shape_dict(self):
......
...@@ -308,11 +308,24 @@ class RoiSet(object): ...@@ -308,11 +308,24 @@ class RoiSet(object):
return zi_st return zi_st
def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): def classify_by(self, name: str, channels: int, object_classification_model: InstanceSegmentationModel, derived_channel_functions: dict[callable()] = None):
raw_acc = self.get_patches_acc(expanded=False, pad_to=None) # all channels
if derived_channel_functions:
mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)]
for k, fcn in derived_channel_functions.keys():
der = fcn(raw_acc).data
assert der.shape == mono_data[0].shape
mono_data.append(der)
# combine channels
input_acc = PatchStack(np.concatenate([mono_data], axis=3))
else:
input_acc = raw_acc
# do this on a patch basis, i.e. only one object per frame # do this on a patch basis, i.e. only one object per frame
obmap_patches = object_classification_model.label_patch_stack( obmap_patches = object_classification_model.label_patch_stack(
self.get_patches_acc(channel=channel, expaned=False, pad_to=None), input_acc,
self.get_patch_masks_acc(expanded=False, pad_to=None) self.get_patch_masks_acc(expanded=False, pad_to=None)
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment