Skip to content
Snippets Groups Projects
Commit 7d7f0552 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Pass notebook filename and not path to load an ilastik model; assumes project...

Pass notebook filename and not path to load an ilastik model; assumes project files are in conf.server.paths['ilastik']
parent 82b77d5e
No related branches found
No related tags found
No related merge requests found
...@@ -9,8 +9,8 @@ paths = { ...@@ -9,8 +9,8 @@ paths = {
'inbound': root / 'images' / 'inbound', 'inbound': root / 'images' / 'inbound',
'outbound': root / 'images' / 'outbound', 'outbound': root / 'images' / 'outbound',
}, },
'ilastik' : { 'ilastik': {
'projects' : root / 'ilastik' 'projects': root / 'ilastik'
} }
} }
......
...@@ -11,9 +11,14 @@ czifile = { ...@@ -11,9 +11,14 @@ czifile = {
'z': 1, 'z': 1,
} }
# ilastik = {
# 'pixel_classifier': root / 'testdata' / 'ilastik' / 'demo_px.ilp',
# 'object_classifier': root / 'testdata' / 'ilastik' / 'demo_obj.ilp',
# }
ilastik = { ilastik = {
'pixel_classifier': root / 'testdata' / 'ilastik' / 'demo_px.ilp', 'pixel_classifier': 'demo_px.ilp',
'object_classifier': root / 'testdata' / 'ilastik' / 'demo_obj.ilp', 'object_classifier': 'demo_obj.ilp',
} }
output_path = root / 'testing_output' output_path = root / 'testing_output'
......
import os import os
import pathlib
import numpy as np import numpy as np
import vigra import vigra
import conf.server
from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.model import ImageToImageModel, ParameterExpectedError from model_server.model import ImageToImageModel, ParameterExpectedError
...@@ -10,9 +12,13 @@ from model_server.model import ImageToImageModel, ParameterExpectedError ...@@ -10,9 +12,13 @@ from model_server.model import ImageToImageModel, ParameterExpectedError
class IlastikImageToImageModel(ImageToImageModel): class IlastikImageToImageModel(ImageToImageModel):
def __init__(self, params, autoload=True): def __init__(self, params, autoload=True):
if 'project_file' not in params or not os.path.exists(params['project_file']):
raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
self.project_file = str(params['project_file']) self.project_file = str(params['project_file'])
self.project_file_abspath = pathlib.Path(
conf.server.paths['ilastik']['projects'] / self.project_file,
)
if 'project_file' not in params or not self.project_file_abspath.exists():
raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
self.shell = None self.shell = None
super().__init__(autoload, params) super().__init__(autoload, params)
...@@ -28,7 +34,7 @@ class IlastikImageToImageModel(ImageToImageModel): ...@@ -28,7 +34,7 @@ class IlastikImageToImageModel(ImageToImageModel):
args = app.parse_args([]) args = app.parse_args([])
args.headless = True args.headless = True
args.project = self.project_file args.project = self.project_file_abspath.__str__()
shell = app.main(args) shell = app.main(args)
if not isinstance(shell.workflow, self.get_workflow()): if not isinstance(shell.workflow, self.get_workflow()):
......
...@@ -3,7 +3,7 @@ import unittest ...@@ -3,7 +3,7 @@ import unittest
import numpy as np import numpy as np
from conf.testing import czifile, ilastik, output_path import conf.testing
from model_server.image import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.image import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.ilastik import IlastikObjectClassifierModel, IlastikPixelClassifierModel from model_server.ilastik import IlastikObjectClassifierModel, IlastikPixelClassifierModel
from model_server.model import Model from model_server.model import Model
...@@ -12,7 +12,7 @@ from tests.test_api import TestServerBaseClass ...@@ -12,7 +12,7 @@ from tests.test_api import TestServerBaseClass
class TestIlastikPixelClassification(unittest.TestCase): class TestIlastikPixelClassification(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path']) self.cf = CziImageFileAccessor(conf.testing.czifile['path'])
def test_faulthandler(self): # recreate error that is messing up ilastik def test_faulthandler(self): # recreate error that is messing up ilastik
...@@ -25,18 +25,23 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -25,18 +25,23 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_raise_error_if_autoload_disabled(self): def test_raise_error_if_autoload_disabled(self):
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}, autoload=False) model = IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']},
autoload=False
)
w = 512 w = 512
h = 256 h = 256
input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1)) input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1))
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
pxmap , _= model.infer(input_img) pxmap, _ = model.infer(input_img)
def test_run_pixel_classifier_on_random_data(self): def test_run_pixel_classifier_on_random_data(self):
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) model = IlastikPixelClassifierModel(
{'project_file': conf.testing.ilastik['pixel_classifier']},
)
w = 512 w = 512
h = 256 h = 256
...@@ -48,12 +53,16 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -48,12 +53,16 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_pixel_classifier(self): def test_run_pixel_classifier(self):
channel = 0 channel = 0
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) model = IlastikPixelClassifierModel(
cf = CziImageFileAccessor(czifile['path']) {'project_file': conf.testing.ilastik['pixel_classifier']}
)
cf = CziImageFileAccessor(
conf.testing.czifile['path']
)
mono_image = cf.get_one_channel_data(channel) mono_image = cf.get_one_channel_data(channel)
self.assertEqual(mono_image.shape_dict['X'], czifile['w']) self.assertEqual(mono_image.shape_dict['X'], conf.testing.czifile['w'])
self.assertEqual(mono_image.shape_dict['Y'], czifile['h']) self.assertEqual(mono_image.shape_dict['Y'], conf.testing.czifile['h'])
self.assertEqual(mono_image.shape_dict['C'], 1) self.assertEqual(mono_image.shape_dict['C'], 1)
self.assertEqual(mono_image.shape_dict['Z'], 1) self.assertEqual(mono_image.shape_dict['Z'], 1)
...@@ -64,7 +73,7 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -64,7 +73,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertEqual(pxmap.shape_dict['Z'], 1) self.assertEqual(pxmap.shape_dict['Z'], 1)
self.assertTrue( self.assertTrue(
write_accessor_data_to_file( write_accessor_data_to_file(
output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', conf.testing.output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
pxmap pxmap
) )
) )
...@@ -74,13 +83,15 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -74,13 +83,15 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_run_object_classifier(self): def test_run_object_classifier(self):
self.test_run_pixel_classifier() self.test_run_pixel_classifier()
fp = czifile['path'] fp = conf.testing.czifile['path']
model = IlastikObjectClassifierModel({'project_file': ilastik['object_classifier']}) model = IlastikObjectClassifierModel(
{'project_file': conf.testing.ilastik['object_classifier']}
)
objmap, _ = model.infer(self.mono_image, self.pxmap) objmap, _ = model.infer(self.mono_image, self.pxmap)
self.assertTrue( self.assertTrue(
write_accessor_data_to_file( write_accessor_data_to_file(
output_path / f'obmap_{fp.stem}.tif', conf.testing.output_path / f'obmap_{fp.stem}.tif',
objmap, objmap,
) )
) )
...@@ -88,9 +99,11 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -88,9 +99,11 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_ilastik_pixel_classification_as_workflow(self): def test_ilastik_pixel_classification_as_workflow(self):
result = infer_image_to_image( result = infer_image_to_image(
czifile['path'], conf.testing.czifile['path'],
IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}), IlastikPixelClassifierModel(
output_path, {'project_file': conf.testing.ilastik['pixel_classifier']}
),
conf.testing.output_path,
channel=0, channel=0,
) )
self.assertTrue(result.success) self.assertTrue(result.success)
...@@ -100,7 +113,7 @@ class TestIlastikOverApi(TestServerBaseClass): ...@@ -100,7 +113,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_pixel_model(self): def test_load_ilastik_pixel_model(self):
resp_load = requests.put( resp_load = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/', self.uri + 'models/ilastik/pixel_classification/load/',
params={'project_file': str(ilastik['pixel_classifier'])}, params={'project_file': str(conf.testing.ilastik['pixel_classifier'])},
) )
model_id = resp_load.json()['model_id'] model_id = resp_load.json()['model_id']
...@@ -116,7 +129,7 @@ class TestIlastikOverApi(TestServerBaseClass): ...@@ -116,7 +129,7 @@ class TestIlastikOverApi(TestServerBaseClass):
def test_load_ilastik_object_model(self): def test_load_ilastik_object_model(self):
resp_load = requests.put( resp_load = requests.put(
self.uri + 'models/ilastik/object_classification/load/', self.uri + 'models/ilastik/object_classification/load/',
params={'project_file': str(ilastik['object_classifier'])}, params={'project_file': str(conf.testing.ilastik['object_classifier'])},
) )
model_id = resp_load.json()['model_id'] model_id = resp_load.json()['model_id']
...@@ -134,8 +147,8 @@ class TestIlastikOverApi(TestServerBaseClass): ...@@ -134,8 +147,8 @@ class TestIlastikOverApi(TestServerBaseClass):
self.uri + f'infer/from_image_file', self.uri + f'infer/from_image_file',
params={ params={
'model_id': model_id, 'model_id': model_id,
'input_filename': czifile['filename'], 'input_filename': conf.testing.czifile['filename'],
'channel': 2, 'channel': 0,
}, },
) )
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment