Skip to content
Snippets Groups Projects
Commit 76f7ceed authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Added doc string to explain callables for derived channels

parent e4eba50f
No related branches found
No related tags found
3 merge requests!37Release 2024.04.19,!34Revert "Temporary error-handling for debug...",!30Accessor changes to support object classification
...@@ -308,12 +308,28 @@ class RoiSet(object): ...@@ -308,12 +308,28 @@ class RoiSet(object):
return zi_st return zi_st
def classify_by(self, name: str, channels: int, object_classification_model: InstanceSegmentationModel, derived_channel_functions: dict[callable] = None): def classify_by(
self, name: str, channels: list[int],
object_classification_model: InstanceSegmentationModel,
derived_channel_functions: list[callable] = None
):
"""
Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
specified inputs through an instance segmentation classifier. Optionally derive additional inputs for object
classification by passing a raw input channel through one or more functions.
:param name: name of column to insert
:param channels: list of nc raw input channels to send to classifier
:param object_classification_model: InstanceSegmentation model object
:param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and
return a single-channel PatchStack accessor of the same shape
:return: None
"""
raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels
if derived_channel_functions: if derived_channel_functions is not None:
mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)] mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)]
for k, fcn in derived_channel_functions.items(): for fcn in derived_channel_functions:
der = fcn(raw_acc) # returns patch stack der = fcn(raw_acc) # returns patch stack
assert der.shape == mono_data[0].shape assert der.shape == mono_data[0].shape
mono_data.append(der.data) mono_data.append(der.data)
......
...@@ -217,10 +217,10 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -217,10 +217,10 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
'multiple_input_model', 'multiple_input_model',
[0, 1], [0, 1],
ModelWithDerivedInputs(), ModelWithDerivedInputs(),
derived_channel_functions={ derived_channel_functions=[
'der1': lambda acc: PatchStack(2 * acc.data), lambda acc: PatchStack(2 * acc.data),
'der2': lambda acc: PatchStack(0.5 * acc.data) lambda acc: PatchStack(0.5 * acc.data)
} ]
) )
self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [3])) self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [3]))
self.assertTrue(all(np.unique(roiset.object_class_maps['multiple_input_model'].data) == [0, 3])) self.assertTrue(all(np.unique(roiset.object_class_maps['multiple_input_model'].data) == [0, 3]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment