Skip to content
Snippets Groups Projects
Commit 6f12d7e9 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Object classification and object map creation are now separate functions; no...

Object classification and object map creation are now separate functions; no longer persist object maps.
parent 2c7493d6
No related branches found
No related tags found
2 merge requests!65Release 2024.10.01,!57RoiSet facilitates object detection models
This commit is part of merge request !57. Comments created here will be created in the context of that merge request.
......@@ -278,7 +278,7 @@ class RoiSet(object):
self._df = df
self.count = len(self._df)
self.object_class_maps = {} # classification results
# self.object_class_maps = {} # classification results
def __iter__(self):
"""Expose ROI meta information via the Pandas.DataFrame API"""
......@@ -415,7 +415,7 @@ class RoiSet(object):
:param channels: list of nc raw input channels to send to classifier
:param object_classification_model: InstanceSegmentation model object
:param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and
return a single-channel PatchStack accessor of the same shape
that return a single-channel PatchStack accessor of the same shape
:return: None
"""
......@@ -448,12 +448,8 @@ class RoiSet(object):
self.get_patch_masks_acc(expanded=False, pad_to=None)
)
om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype)
self._df['classify_by_' + name] = pd.Series(dtype='Int64')
# TODO: separate method to get object map
# assign labels to object map:
for i, roi in enumerate(self):
oc = np.unique(
mask_largest_object(
......@@ -461,9 +457,23 @@ class RoiSet(object):
)
)[-1]
self._df.loc[roi.Index, 'classify_by_' + name] = oc
om[self.acc_obj_ids.data == roi.label] = oc
self.object_class_maps[name] = InMemoryDataAccessor(om)
def get_object_class_map(self, name: str) -> InMemoryDataAccessor:
"""
For a given classification result, return a map where object IDs are replaced by each object's class
:param name: name of the classification result, same as passed to RoiSet.classify_by()
:return: accessor of object class map
"""
colname = ('classify_by_' + name)
assert colname in self._df.columns
obj_ids = self.acc_obj_ids
om = np.zeros(obj_ids.shape, obj_ids.dtype)
def _label_object_class(roi):
om[self.acc_obj_ids.data == roi.label] = roi[colname]
self._df.apply(_label_object_class, axis=1)
return InMemoryDataAccessor(om)
def export_dataframe(self, csv_path: Path) -> str:
csv_path.parent.mkdir(parents=True, exist_ok=True)
......@@ -510,7 +520,6 @@ class RoiSet(object):
patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255
else:
# TODO: shape issues here
patch = (roi.binary_mask * 255).astype('uint8')
if pad_to:
patch = pad(patch, pad_to)
......@@ -696,10 +705,13 @@ class RoiSet(object):
if k == 'annotated_zstacks':
record[k] = str(Path(k) / self.export_annotated_zstack(subdir, prefix=pr, **kp))
if k == 'object_classes':
for kc, acc in self.object_class_maps.items():
fp = subdir / kc / (pr + '.tif')
write_accessor_data_to_file(fp, acc)
record[f'{k}_{kc}'] = str(fp)
# for kc, acc in self.object_class_maps.items():
pr = 'classify_by_'
cnames = [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)]
for n in cnames:
fp = subdir / n / (pr + '.tif')
write_accessor_data_to_file(fp, self.get_object_class_map(n))
record[f'{k}_{n}'] = str(fp)
if k == 'derived_channels':
record[k] = []
for di, dacc in enumerate(self.accs_derived):
......
......@@ -240,14 +240,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
roiset = self._make_roi_set()
roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel())
self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
return roiset
def test_classify_by_multiple_channels(self):
roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask)
roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
return roiset
def test_classify_by_with_derived_channel(self):
......@@ -272,7 +272,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
]
)
self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4]))
self.assertTrue(all(np.unique(roiset.object_class_maps['multiple_input_model'].data) == [0, 4]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4]))
self.assertEqual(len(roiset.accs_derived), 2)
for di in roiset.accs_derived:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment