diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index f68ac353ffa69e5832dfe9aed096f857583e764f..a75212d939dab2b336075cf94a563cea508da568 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -331,11 +331,17 @@ class PatchStack(InMemoryDataAccessor): tifffile.imwrite(fpath, tzcyx, imagej=True) def get_one_channel_data(self, channel: int, mip: bool = False): - c = int(channel) + return self.get_channels([channel], mip=mip) + + def get_channels(self, channels: list, mip: bool = False): + carr = [int(c) for c in channels] if mip: - return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :].max(axis=-1, keepdims=True)) + return PatchStack(self.pyxcz[:, :, :, carr, :].max(axis=-1, keepdims=True)) else: - return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :]) + return PatchStack(self.pyxcz[:, :, :, carr, :]) + + def gch(self, args): + return self.get_channels(args) @property def shape_dict(self): diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index f87bb8c8dcaffaddd64eb628f6bca2ab1192beb1..e8b3d04e128e7edb9993fc67ba62f43e75eb892c 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -308,11 +308,24 @@ class RoiSet(object): return zi_st - def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): + def classify_by(self, name: str, channels: int, object_classification_model: InstanceSegmentationModel, derived_channel_functions: dict[callable()] = None): + + raw_acc = self.get_patches_acc(expanded=False, pad_to=None) # all channels + if derived_channel_functions: + mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)] + for k, fcn in derived_channel_functions.keys(): + der = fcn(raw_acc).data + assert der.shape == mono_data[0].shape + mono_data.append(der) + # combine channels + input_acc = PatchStack(np.concatenate([mono_data], axis=3)) + + else: + input_acc = raw_acc # do this on a patch basis, i.e. only one object per frame obmap_patches = object_classification_model.label_patch_stack( - self.get_patches_acc(channel=channel, expaned=False, pad_to=None), + input_acc, self.get_patch_masks_acc(expanded=False, pad_to=None) )