Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
M
model_server
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Christopher Randolph Rhodes
model_server
Commits
76f7ceed
Commit
76f7ceed
authored
11 months ago
by
Christopher Randolph Rhodes
Browse files
Options
Downloads
Patches
Plain Diff
Added doc string to explain callables for derived channels
parent
e4eba50f
No related branches found
No related tags found
3 merge requests
!37
Release 2024.04.19
,
!34
Revert "Temporary error-handling for debug..."
,
!30
Accessor changes to support object classification
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
model_server/base/roiset.py
+19
-3
19 additions, 3 deletions
model_server/base/roiset.py
tests/test_roiset.py
+4
-4
4 additions, 4 deletions
tests/test_roiset.py
with
23 additions
and
7 deletions
model_server/base/roiset.py
+
19
−
3
View file @
76f7ceed
...
@@ -308,12 +308,28 @@ class RoiSet(object):
...
@@ -308,12 +308,28 @@ class RoiSet(object):
return
zi_st
return
zi_st
def
classify_by
(
self
,
name
:
str
,
channels
:
int
,
object_classification_model
:
InstanceSegmentationModel
,
derived_channel_functions
:
dict
[
callable
]
=
None
):
def
classify_by
(
self
,
name
:
str
,
channels
:
list
[
int
],
object_classification_model
:
InstanceSegmentationModel
,
derived_channel_functions
:
list
[
callable
]
=
None
):
"""
Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
specified inputs through an instance segmentation classifier. Optionally derive additional inputs for object
classification by passing a raw input channel through one or more functions.
:param name: name of column to insert
:param channels: list of nc raw input channels to send to classifier
:param object_classification_model: InstanceSegmentation model object
:param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and
return a single-channel PatchStack accessor of the same shape
:return: None
"""
raw_acc
=
self
.
get_patches_acc
(
channels
=
channels
,
expanded
=
False
,
pad_to
=
None
)
# all channels
raw_acc
=
self
.
get_patches_acc
(
channels
=
channels
,
expanded
=
False
,
pad_to
=
None
)
# all channels
if
derived_channel_functions
:
if
derived_channel_functions
is
not
None
:
mono_data
=
[
raw_acc
.
get_one_channel_data
(
c
).
data
for
c
in
range
(
0
,
raw_acc
.
chroma
)]
mono_data
=
[
raw_acc
.
get_one_channel_data
(
c
).
data
for
c
in
range
(
0
,
raw_acc
.
chroma
)]
for
k
,
fcn
in
derived_channel_functions
.
items
()
:
for
fcn
in
derived_channel_functions
:
der
=
fcn
(
raw_acc
)
# returns patch stack
der
=
fcn
(
raw_acc
)
# returns patch stack
assert
der
.
shape
==
mono_data
[
0
].
shape
assert
der
.
shape
==
mono_data
[
0
].
shape
mono_data
.
append
(
der
.
data
)
mono_data
.
append
(
der
.
data
)
...
...
This diff is collapsed.
Click to expand it.
tests/test_roiset.py
+
4
−
4
View file @
76f7ceed
...
@@ -217,10 +217,10 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
...
@@ -217,10 +217,10 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
'
multiple_input_model
'
,
'
multiple_input_model
'
,
[
0
,
1
],
[
0
,
1
],
ModelWithDerivedInputs
(),
ModelWithDerivedInputs
(),
derived_channel_functions
=
{
derived_channel_functions
=
[
'
der1
'
:
lambda
acc
:
PatchStack
(
2
*
acc
.
data
),
lambda
acc
:
PatchStack
(
2
*
acc
.
data
),
'
der2
'
:
lambda
acc
:
PatchStack
(
0.5
*
acc
.
data
)
lambda
acc
:
PatchStack
(
0.5
*
acc
.
data
)
}
]
)
)
self
.
assertTrue
(
all
(
roiset
.
get_df
()[
'
classify_by_multiple_input_model
'
].
unique
()
==
[
3
]))
self
.
assertTrue
(
all
(
roiset
.
get_df
()[
'
classify_by_multiple_input_model
'
].
unique
()
==
[
3
]))
self
.
assertTrue
(
all
(
np
.
unique
(
roiset
.
object_class_maps
[
'
multiple_input_model
'
].
data
)
==
[
0
,
3
]))
self
.
assertTrue
(
all
(
np
.
unique
(
roiset
.
object_class_maps
[
'
multiple_input_model
'
].
data
)
==
[
0
,
3
]))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment