Commit 277d3494 authored by Constantin Pape's avatar Constantin Pape

Enable loading pretrained model in stardist-2d training

parent 8af4a479
import argparse
import os
from glob import glob
from shutil import copytree
import imageio
import numpy as np
......@@ -8,7 +9,7 @@ import numpy as np
from csbdeep.utils import normalize
from stardist import fill_label_holes, gputools_available
from stardist.models import Config2D, StarDist2D
from stardist_model_to_fiji import stardist_model_to_fiji
from .stardist_model_to_fiji import stardist_model_to_fiji
def check_training_data(train_images, train_labels):
......@@ -118,32 +119,44 @@ def augmenter(x, y):
# we leave n_rays at the default of 32, but may want to expose this as well
def train_model(x_train, y_train, x_val, y_val, save_path,
n_channels, patch_size, n_rays=32):
n_channels, patch_size,
pretrained_model_path=None, n_rays=32):
# make the model config
# copied from the stardist training notebook, this is a very weird line ...
use_gpu = False and gputools_available()
# predict on subsampled image for increased efficiency
grid = (2, 2)
config = Config2D(
n_rays=n_rays,
grid=grid,
use_gpu=use_gpu,
n_channel_in=n_channels,
train_patch_size=patch_size
)
save_root, save_name = os.path.split(save_path)
os.makedirs(save_root, exist_ok=True)
# if we don't have a pre-trained model path, make model with
# the vanilla config
if pretrained_model_path is None:
print("Training model from scratch")
# predict on subsampled image for increased efficiency
grid = (2, 2)
config = Config2D(
n_rays=n_rays,
grid=grid,
use_gpu=use_gpu,
n_channel_in=n_channels,
train_patch_size=patch_size
)
model = StarDist2D(config, name=save_name, basedir=save_root)
# otherwise load the pretrained model
else:
print("Training model pretrained on", pretrained_model_path)
# copy the pretrained model
copytree(pretrained_model_path, save_path)
model = StarDist2D(None, name=save_name, basedir=save_root)
if use_gpu:
print("Using a GPU for training")
print("Using GPU for data pre-processing")
# limit gpu memory
from csbdeep.utils.tf import limit_gpu_memory
limit_gpu_memory(0.8)
else:
print("GPU not found, using the CPU for training")
save_root, save_name = os.path.split(save_path)
os.makedirs(save_root, exist_ok=True)
model = StarDist2D(config, name=save_name, basedir=save_root)
print("Using CPU for data-preprocessing")
model.train(x_train, y_train, validation_data=(x_val, y_val), augmenter=augmenter)
optimal_parameters = model.optimize_thresholds(x_val, y_val)
......@@ -152,7 +165,7 @@ def train_model(x_train, y_train, x_val, y_val, save_path,
def train_stardist_model(root, model_save_path, image_folder, labels_folder, ext,
validation_fraction, patch_size, multichannel,
save_for_fiji):
save_for_fiji, pretrained_model_path):
print("Loading training data")
train_images, train_labels, n_channels = load_training_data(root,
image_folder, labels_folder,
......@@ -169,7 +182,8 @@ def train_stardist_model(root, model_save_path, image_folder, labels_folder, ext
print("Start model training ...")
print("You can connect to the tensorboard by typing 'tensorboaed --logdir=.' in the folder where the training runs")
model, opt_params = train_model(x_train, y_train, x_val, y_val, model_save_path,
n_channels, patch_size)
n_channels, patch_size=patch_size,
pretrained_model_path=pretrained_model_path)
print("The model has been trained and was saved to", model_save_path)
print("The following optimal parameters were found:", opt_params)
......@@ -179,7 +193,6 @@ def train_stardist_model(root, model_save_path, image_folder, labels_folder, ext
# use configarparse?
# TODO set batch size
# TODO enable fine-tuning on pre-trained
# TODO enable excluding images by name
def main():
parser = argparse.ArgumentParser(description="Train a 2d stardist model")
......@@ -200,13 +213,15 @@ def main():
help="Do we have multichannel images? Default: 0")
parser.add_argument('--save_for_fiji', type=int, default=0,
help="Save the model for FIJI, default: 0")
parser.add_argument('--pretrained_model', type=str, default=None,
help="Path to pretrained model that will be fine-tuned.")
args = parser.parse_args()
train_stardist_model(args.root, args.model_save_path,
args.image_folder, args.labels_folder,
args.ext, args.validation_fraction,
tuple(args.patch_size), bool(args.multichannel),
bool(args.save_for_fiji))
bool(args.save_for_fiji), args.pretrained_model)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment