Skip to content
Snippets Groups Projects
Commit 49fb7101 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Instantiate models with keyword arguments now, no longer via "params"...

Instantiate models with keyword arguments now, no longer via "params" dictionary; although API still uses this name when passing Pydantic parameters models.  API describes model via its info property now, which passes instance property, not its init arguments.
parent 3c4aab06
No related branches found
No related tags found
No related merge requests found
......@@ -7,17 +7,15 @@ from .accessors import GenericImageDataAccessor, PatchStack
class Model(ABC):
def __init__(self, autoload=True, params: dict = None):
def __init__(self, autoload: bool = True, info: dict = None):
"""
Abstract base class for an inference model that uses image data as an input.
:param autoload: automatically load model and dependencies into memory if True
:param params: (optional) BaseModel of model parameters e.g. configuration files required to load model
:param info: optional dictionary of JSON-serializable information to report to API
"""
self.autoload = autoload
if params:
self.params = params
self.loaded = False
self._info = info
if not autoload:
return None
if self.load():
......@@ -26,6 +24,10 @@ class Model(ABC):
raise CouldNotLoadModelError()
return None
@property
def info(self):
return self._info
@abstractmethod
def load(self):
"""
......@@ -87,13 +89,14 @@ class SemanticSegmentationModel(ImageToImageModel):
class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
"""
Trivial but functional model that labels all pixels above an intensity threshold as class 1
"""
def __init__(self, params=None):
self.tr = params.get('tr', 0.5)
self.channel = params.get('channel', 0)
def __init__(self, tr: float = 0.5, channel: int = 0):
"""
Model that labels all pixels as class 1 if the intensity in specified channel exceeds a threshold.
:param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range
:param channel: channel to use for thresholding
"""
self.tr = tr
self.channel = channel
self.loaded = self.load()
def infer(self, acc: GenericImageDataAccessor) -> GenericImageDataAccessor:
......@@ -155,15 +158,15 @@ class InstanceMaskSegmentationModel(ImageToImageModel):
class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
"""
Model that labels all objects as class 1 if the intensity in a given channel exceeds threshold;
threshold = 0.0 means that all objects are returned class 1.
"""
def __init__(self, params=None):
# TODO: make model params kwargs to constructor, makes debugging much easier
self.tr = params.get('tr', 0.5)
self.channel = params.get('channel', 0)
def __init__(self, tr: float = 0.5, channel: int = 0):
"""
Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all
objects as class 1 if threshold = 0.0
:param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range
:param channel: channel to use for thresholding
"""
self.tr = tr
self.channel = channel
self.loaded = self.load()
def load(self):
......
......@@ -329,10 +329,13 @@ class _Session(object):
Load an instance of a given model class and attach to this session's model registry
:param ModelClass: subclass of Model
:param key: unique identifier of model, or autogenerate if None
:param params: optional parameters that are passed to the model's construct
:param params: optional parameters that are passed to the model's constructor
:return: model_id of loaded model
"""
mi = ModelClass(params=params.dict() if params else None)
if params:
mi = ModelClass(**params.dict())
else:
mi = ModelClass()
assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
ii = 0
......@@ -349,7 +352,7 @@ class _Session(object):
self.models[key] = {
'object': mi,
'params': getattr(mi, 'params', None)
'info': getattr(mi, 'info', None)
}
self.log_info(f'Loaded model {key}')
return key
......@@ -358,7 +361,7 @@ class _Session(object):
return {
k: {
'class': self.models[k]['object'].__class__.__name__,
'params': self.models[k]['params'],
'info': self.models[k]['info'],
}
for k in self.models.keys()
}
......
......@@ -15,16 +15,15 @@ from ...base.models import Model, ImageToImageModel, InstanceMaskSegmentationMod
class IlastikModel(Model):
def __init__(self, params: dict, autoload=True, enforce_embedded=True):
def __init__(self, project_file, autoload=True, enforce_embedded=True):
"""
Base class for models that run via ilastik shell API
:param params:
project_file: path to ilastik project file
:param: project_file: path to ilastik project file
:param autoload: automatically load model into memory if true
:param enforce_embedded:
raise an error if all input data are not embedded in the project file, i.e. on the filesystem
"""
pf = Path(params['project_file'])
pf = Path(project_file)
self.enforce_embedded = enforce_embedded
if pf.is_absolute():
pap = pf
......@@ -37,7 +36,7 @@ class IlastikModel(Model):
raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
self.shell = None
super().__init__(autoload, params)
super().__init__(autoload, info={'project_file': project_file})
def load(self):
# suppress warnings when loading ilastik app
......@@ -98,8 +97,8 @@ class IlastikModel(Model):
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
operations = ['segment', ]
def __init__(self, params: dict, **kwargs):
super(IlastikPixelClassifierModel, self).__init__(params, **kwargs)
def __init__(self, **kwargs):
super(IlastikPixelClassifierModel, self).__init__(**kwargs)
@staticmethod
def get_workflow():
......@@ -158,9 +157,9 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs):
pxmap = self.infer(img)
mask = pxmap.get_mono(
self.params['px_class']
self.px_class,
).apply(
lambda x: x > self.params['px_prob_threshold']
lambda x: x > self.px_prob_threshold,
)
return mask
......
......@@ -15,7 +15,7 @@ class TestCziImageFileAccess(unittest.TestCase):
self.cf = CziImageFileAccessor(czifile['path'])
def test_instantiate_model(self):
model = DummySemanticSegmentationModel(params=None)
model = DummySemanticSegmentationModel()
self.assertTrue(model.loaded)
def test_instantiate_model_with_nondefault_kwarg(self):
......@@ -58,7 +58,7 @@ class TestCziImageFileAccess(unittest.TestCase):
return img, mask
def test_binary_segmentation(self):
model = BinaryThresholdSegmentationModel({'tr': 3e4})
model = BinaryThresholdSegmentationModel(tr=3e4)
res = model.label_pixel_class(self.cf)
self.assertTrue(res.is_mask())
......@@ -71,6 +71,6 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_permissive_instance_segmentation(self):
img, mask = self.test_dummy_pixel_segmentation()
model = IntensityThresholdInstanceMaskSegmentationModel(params={})
model = IntensityThresholdInstanceMaskSegmentationModel()
obmap = model.label_instance_class(img, mask)
self.assertTrue(np.all(mask.data == 255 * obmap.data))
......@@ -62,11 +62,11 @@ class BaseTestRoiSetMonoProducts(object):
return {
'pixel_classifier_segmentation': {
'name': 'min_px_mod',
'model': BinaryThresholdSegmentationModel({'tr': 1e4}),
'model': BinaryThresholdSegmentationModel(tr=1e4),
},
'object_classifier': {
'name': 'min_ob_mod',
'model': IntensityThresholdInstanceMaskSegmentationModel({}),
'model': IntensityThresholdInstanceMaskSegmentationModel(),
},
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment