Commit 206f19d8 authored by Constantin Pape's avatar Constantin Pape
Browse files

Add flag to save models to fiji

parent 5ab5c806
......@@ -11,10 +11,6 @@ dependencies:
- pip
- scikit-image
- tqdm
- tensorflow-gpu 1.15
- pip:
# I think this is the gpu version, it appears that tensorflow changed it's naming convention at some point:
# tensorflow < 2 gpu version is called tensorflow
# tensorflow >= 2 gpu version is called tensorflow-gpu
# In summary, tensorflow is a hot mess...
- tensorflow==1.15
- stardist
......@@ -146,28 +146,36 @@ def train_model(x_train, y_train, x_val, y_val, save_path,
model.train(x_train, y_train, validation_data=(x_val, y_val), augmenter=augmenter)
optimal_parameters = model.optimize_thresholds(x_val, y_val)
return optimal_parameters
return model, optimal_parameters
def train_stardist_model(root, model_save_path, image_folder, labels_folder, ext,
validation_fraction, patch_size, multichannel):
validation_fraction, patch_size, multichannel,
save_for_fiji):
print("Loading training data")
train_images, train_labels, n_channels = load_training_data(root, image_folder, labels_folder,
train_images, train_labels, n_channels = load_training_data(root,
image_folder, labels_folder,
ext, multichannel)
print("Found", len(train_images), "images and label masks for training")
x_train, y_train, x_val, y_val = make_train_val_split(train_images, train_labels,
validation_fraction)
print("Made train validation split with validation fraction", validation_fraction, "resulting in")
print("Made train validation split with validation fraction",
validation_fraction, "resulting in")
print(len(x_train), "training images")
print(len(y_train), "validation images")
print("Start model training ...")
print("You can connect to the tensorboard by typing 'tensorboaed --logdir=.' in the folder where the training runs")
optimal_parameters = train_model(x_train, y_train, x_val, y_val, model_save_path,
n_channels, patch_size)
model, opt_params = train_model(x_train, y_train, x_val, y_val, model_save_path,
n_channels, patch_size)
print("The model has been trained and was saved to", model_save_path)
print("The following optimal parameters were found:", optimal_parameters)
print("The following optimal parameters were found:", opt_params)
if save_for_fiji:
fiji_save_path = os.path.join(model_save_path, 'TF_SavedModel.zip')
print("Saving model for fiji", fiji_save_path)
model.export_TF()
# use configarparse?
......@@ -175,25 +183,30 @@ def train_stardist_model(root, model_save_path, image_folder, labels_folder, ext
# TODO enable excluding images by name
def main():
parser = argparse.ArgumentParser(description="Train a 2d stardist model")
parser.add_argument('root', type=str, help="Root folder with folders for the training images and labels.")
parser.add_argument('root', type=str,
help="Root folder with folders for the training images and labels.")
parser.add_argument('model_save_path', type=str, help="Where to save the model.")
parser.add_argument('--image_folder', type=str, default='images',
help="Name of the folder with the training images, default: images.")
parser.add_argument('--labels_folder', type=str, default='labels',
help="Name of the folder with the training labels, default: labels.")
parser.add_argument('--ext', type=str, default='.tif', help="Image file extension, default: .tif")
parser.add_argument('--ext', type=str, default='.tif',
help="Image file extension, default: .tif")
parser.add_argument('--validation_fraction', type=float, default=.1,
help="The fraction of available data that is used for validation, default: .1")
parser.add_argument('--patch_size', type=int, nargs=2, default=[256, 256],
help="Size of the image patches used to train the network, default: 256, 256")
parser.add_argument('--multichannel', type=int, default=0,
help="Do we have multichannel images? Default: 0")
parser.add_argument('--save_for_fiji', type=int, default=0,
help="Save the model for FIJI, default: 0")
args = parser.parse_args()
train_stardist_model(args.root, args.model_save_path,
args.image_folder, args.labels_folder,
args.ext, args.validation_fraction,
tuple(args.patch_size), bool(args.multichannel))
tuple(args.patch_size), bool(args.multichannel),
bool(args.save_for_fiji))
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment