Skip to content
Snippets Groups Projects
Commit ddb9b416 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Pass function handles to add derived classes to object classification

parent da6b184a
No related branches found
No related tags found
3 merge requests!37Release 2024.04.19,!34Revert "Temporary error-handling for debug...",!30Accessor changes to support object classification
...@@ -318,7 +318,7 @@ class RoiSet(object): ...@@ -318,7 +318,7 @@ class RoiSet(object):
assert der.shape == mono_data[0].shape assert der.shape == mono_data[0].shape
mono_data.append(der.data) mono_data.append(der.data)
# combine channels # combine channels
input_acc = PatchStack(np.concatenate([mono_data], axis=3)) input_acc = PatchStack(np.concatenate(mono_data, axis=raw_acc._ga('C')))
else: else:
input_acc = raw_acc input_acc = raw_acc
......
...@@ -208,15 +208,24 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -208,15 +208,24 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
return roiset return roiset
def test_classify_by_with_derived_channel(self): def test_classify_by_with_derived_channel(self):
class ModelWithDerivedInputs(DummyInstanceSegmentationModel):
def infer(self, img, mask):
return PatchStack(super().infer(img, mask).data * img.chroma)
roiset = self._make_roi_set() roiset = self._make_roi_set()
roiset.classify_by( roiset.classify_by(
'dummy_class', 'multiple_input_model',
[0, 1], [0, 1],
DummyInstanceSegmentationModel(), ModelWithDerivedInputs(),
derived_channel_functions={'der1': lambda acc: PatchStack(2*acc.data)} derived_channel_functions={
'der1': lambda acc: PatchStack(2 * acc.data),
'der2': lambda acc: PatchStack(0.5 * acc.data)
}
) )
self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [3]))
self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) print(roiset.get_df().columns)
print(roiset.get_df()['classify_by_multiple_input_model'])
self.assertTrue(all(np.unique(roiset.object_class_maps['multiple_input_model'].data) == [0, 3]))
return roiset return roiset
def test_export_object_classes(self): def test_export_object_classes(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment