Source code for ovejero.model_trainer

# -*- coding: utf-8 -*-
"""
This script will initialize and train a BNN model on a strong lensing image
dataset.

Examples:
	python -m model_trainer configs/t1.json

"""

# Import some backend stuff
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
import argparse, json, os
import pandas as pd
import copy, glob

# Import the code to construct the bnn and the data pipeline
from ovejero import bnn_alexnet, data_tools


[docs]def config_checker(cfg): """ Check that configuration file meets ovejero requirements. Throw an error if configuration file is invalid. Parameters: cfg: The dictionary attained from reading the json config. """ def recursive_key_checker(dict_check,dict_ref): """ Check that dictionary has all of the keys in a reference dictionary, and that the same is true for any sub-dictionaries. Raise an error if not identical. Parameters: dict_check (dict): The dictionary to check dict_ref (dict): The reference dictionary """ for key in dict_ref: if key not in dict_check: raise RuntimeError('Input config does not contain %s'%(key)) if isinstance(dict_ref[key],dict): recursive_key_checker(dict_check[key],dict_ref[key]) # Load the check json file root_path = os.path.dirname(os.path.abspath(__file__)) with open(os.path.join(root_path,'check.json'),'r') as json_f: cfg_ref = json.load(json_f) recursive_key_checker(cfg,cfg_ref)
[docs]def load_config(config_path): """ Load a configuration file from the path and check that it meets the requirements. Parameters: config_path (str): The path to the config file to be loaded Returns: (dict): A dictionary object with the config file. """ # Load the config with open(config_path,'r') as json_f: cfg = json.load(json_f) # Check that it's up to snuff config_checker(cfg) # Return it return cfg
[docs]def prepare_tf_record(cfg,root_path,tf_record_path,final_params,train_or_test): """ Perpare the tf record using the config file values. Parameters: cfg (dict): The dictionary attained from reading the json config file. root_path (str): The root path that will contain all of the data including the lens parameters, the npy files, and the TFRecord. tf_record_path (str): The path where the TFRecord will be saved. final_params ([str,...]): The parameters we expect to be in the final set of lens parameters. train_or_test (string): If test, the normalizations will be saved. If train, the training normalizations will be used. """ # Path to csv containing lens parameters. lens_params_path = os.path.join(root_path, cfg['dataset_params']['lens_params_path']) # The list of lens parameters that should be trained on. We will # append to this, so we want to make a copy. lens_params = copy.copy(cfg['dataset_params']['lens_params']) # Where to save the lens parameters to after the preprocessing # transformations new_param_path = os.path.join(root_path, cfg['dataset_params']['new_param_path']) # Where to save the normalization constants to. Note that we take the # root path associated with the training_params here, even if validation # params root path was passed in. This is because we always want to use # the training norms! normalization_constants_path = os.path.join( cfg['training_params']['root_path'], cfg['dataset_params']['normalization_constants_path']) # Parameters to convert to log space if 'lens_params_log' in cfg['dataset_params']: lens_params_log = cfg['dataset_params']['lens_params_log'] else: lens_params_log = None # Parameters to convert from ratio and ang to excentricities if 'gampsi' in cfg['dataset_params']: cfg_gampsi = cfg['dataset_params']['gampsi'] # New prefix for those parameters gampsi_parameter_prefixes = cfg_gampsi['gampsi_parameter_prefixes'] # The parameter names of the ratios gampsi_params_rat = cfg_gampsi['gampsi_params_rat'] # The parameter names of the angles gampsi_params_ang = cfg_gampsi['gampsi_params_ang'] else: gampsi_parameter_prefixes=None gampsi_params_rat=None gampsi_params_ang=None # First write desired parameters in log space. if lens_params_log is not None: data_tools.write_parameters_in_log_space(lens_params_log, lens_params_path,new_param_path) # Add log version of parameter. for lens_param_log in lens_params_log: lens_params.append(lens_param_log+'_log') # Parameters should be read from this path from now on. lens_params_path = new_param_path # Now convert ratio and angle parameters to excentricities. if gampsi_parameter_prefixes is not None: for gampsii in range(len(gampsi_parameter_prefixes)): data_tools.gampsi_2_g1g2(gampsi_params_rat[gampsii], gampsi_params_ang[gampsii],lens_params_path,new_param_path, gampsi_parameter_prefixes[gampsii]) # Update lens_params lens_params.append(gampsi_parameter_prefixes[gampsii]+'_g1') lens_params.append(gampsi_parameter_prefixes[gampsii]+'_g2') # Parameters should be read from this path from now on. lens_params_path = new_param_path # Now normalize all of the lens parameters data_tools.normalize_lens_parameters(lens_params,lens_params_path, new_param_path,normalization_constants_path, train_or_test=train_or_test) # Quickly check that all the desired lens_params ended up in the final # csv file. for final_param in final_params: if final_param not in lens_params: raise RuntimeError('Desired lens parameters and lens parameters in'+ ' final csv do not match') # Finally, generate the TFRecord data_tools.generate_tf_record(root_path,lens_params,new_param_path, tf_record_path)
[docs]def get_normed_pixel_scale(cfg,pixel_scale): """ Return a dictionary with the pixel scale normalized according to the normalization of each shift parameter. Parameters: cfg (dict): The dictionary attained from reading the json config file. pixel_scale (float): The pixel scale used for the original images. Returns: (dict): A dictionary of the pixel scales renormalized in the same way as the shift parameters. """ # Get the parameters we need to read the normalization from shift_params = cfg['training_params']['shift_params'] # Adjust the pixel scale by the normalization normalization_constants_path = os.path.join( cfg['training_params']['root_path'], cfg['dataset_params']['normalization_constants_path']) norm_const_dict = pd.read_csv(normalization_constants_path, index_col=None) # Set the normed pixel scale for each parameter normed_pixel_scale = {} for shift_param in shift_params[0]: normed_pixel_scale[shift_param] = pixel_scale/norm_const_dict[ shift_param][1] for shift_param in shift_params[1]: normed_pixel_scale[shift_param] = pixel_scale/norm_const_dict[ shift_param][1] return normed_pixel_scale
[docs]def model_loss_builder(cfg, verbose=False): """ Build a model according to the specifications in configuration dictionary and return both the initialized model and the loss function. Parameters: cfg (dict): The dictionary attained from reading the json config file. verbose (bool): If True, will be verbose as model is built. Returns: (tf.keras.model, function): A bnn model of the type specified in config and a callable function to construct the tesnorflow graph for the loss. """ # Load the parameters we need from the config file. Some of these will # be repeats from the main script. # The final parameters that need to be in tf_record_path final_params = cfg['training_params']['final_params'] num_params = len(final_params) # The learning rate learning_rate = cfg['training_params']['learning_rate'] # The decay rate for Adam decay = cfg['training_params']['decay'] # Image dimensions img_dim = cfg['training_params']['img_dim'] # Weight and dropout regularization parameters for the concrete dropout # model. kr = cfg['training_params']['kernel_regularizer'] dr = cfg['training_params']['dropout_regularizer'] dropout_rate = cfg['training_params']['dropout_rate'] # If the any of the parameters contain excentricities then the e1/e2 # pair should be included in the flip list for the correct loss function # behavior. See the example config files. flip_pairs = cfg['training_params']['flip_pairs'] # The type of BNN output (either diag, full, or gmm). bnn_type = cfg['training_params']['bnn_type'] dropout_type = cfg['training_params']['dropout_type'] # The path to the model weights. If they already exist they will be loaded model_weights = cfg['training_params']['model_weights'] # Finally set the random seed we will use for training random_seed = cfg['training_params']['random_seed'] # Initialize the log function according to bnn_type loss_class = bnn_alexnet.LensingLossFunctions(flip_pairs,num_params) if bnn_type == 'diag': loss = loss_class.diagonal_covariance_loss num_outputs = num_params*2 elif bnn_type == 'full': loss = loss_class.full_covariance_loss num_outputs = num_params + int(num_params*(num_params+1)/2) elif bnn_type == 'gmm': loss = loss_class.gm_full_covariance_loss num_outputs = 2*(num_params + int(num_params*(num_params+1)/2))+1 else: raise RuntimeError('BNN type %s does not exist'%(bnn_type)) # The mse loss doesn't depend on model type. mse_loss = loss_class.mse_loss adam = Adam(lr=learning_rate,amsgrad=False) if dropout_type == 'concrete': model = bnn_alexnet.concrete_alexnet((img_dim, img_dim, 1), num_outputs, kernel_regularizer=kr,dropout_regularizer=dr,random_seed=random_seed) # The final metric here is a hack to be able to track the average dropout # value. model.compile(loss=loss, optimizer=adam, metrics=[loss,mse_loss, bnn_alexnet.p_value(model)]) elif dropout_type == 'standard': model = bnn_alexnet.dropout_alexnet((img_dim, img_dim, 1), num_outputs, kernel_regularizer=kr,dropout_rate=dropout_rate) model.compile(loss=loss, optimizer=adam, metrics=[loss,mse_loss]) else: raise ValueError('dropout type %s is not an option.'%(dropout_type) + ' Either standard or concrete') if verbose: print('Is model built: ' + str(model.built)) try: model.load_weights(model_weights) if verbose: print('Loaded weights %s'%(model_weights)) except: if verbose: print('No weights found. Saving new weights to %s'%(model_weights)) return model, loss
[docs]def main(): """ Initializes and trains a BNN network. Path to config file are read from command line arguments. """ # Initialize argument parser to pull neccesary paths parser = argparse.ArgumentParser() parser.add_argument('config',help='json config file containing BNN type ' + 'and data/model paths') args = parser.parse_args() cfg = load_config(args.config) # Extract neccesary parameters from the json config # The batch size used for training batch_size = cfg['training_params']['batch_size'] # The number of epochs of training n_epochs = cfg['training_params']['n_epochs'] # The root path that will contain all of the training data including the lens # parameters, the npy files, and the TFRecord for training. root_path_t = cfg['training_params']['root_path'] # The same but for validation root_path_v = cfg['validation_params']['root_path'] # The filename of the TFRecord for training data tf_record_path_t = os.path.join(root_path_t, cfg['training_params']['tf_record_path']) # The same but for validation tf_record_path_v = os.path.join(root_path_v, cfg['validation_params']['tf_record_path']) # The final parameters that need to be in tf_record_path final_params = cfg['training_params']['final_params'] # The path to the model weights. If they already exist they will be loaded model_weights = cfg['training_params']['model_weights'] # The path for the Tensorboard logs tensorboard_log_dir = cfg['training_params']['tensorboard_log_dir'] # The path to the baobab config file that will be used to add noise baobab_config_path = cfg['training_params']['baobab_config_path'] # The parameters govern the augmentation of the data # Whether or not the images should be normalzied to have standard # deviation 1 norm_images = cfg['training_params']['norm_images'] # The number of pixels to uniformly shift the images by and the # parameters that need to be rescaled to account for this shift shift_pixels = cfg['training_params']['shift_pixels'] shift_params = cfg['training_params']['shift_params'] # What the pixel_scale of the images is. This will be adjusted for the # normalization. pixel_scale = cfg['training_params']['pixel_scale'] # Finally set the random seed we will use for training random_seed = cfg['training_params']['random_seed'] tf.random.set_seed(random_seed) # Number of steps per epoch is number of examples over the batch size npy_file_list = glob.glob(os.path.join(root_path_t,'X*.npy')) steps_per_epoch = len(npy_file_list)//batch_size print('Checking for training data.') if not os.path.exists(tf_record_path_t): print('Generating new TFRecord at %s'%(tf_record_path_t)) prepare_tf_record(cfg,root_path_t,tf_record_path_t,final_params, train_or_test='train') else: print('TFRecord found at %s'%(tf_record_path_t)) print('Checking for validation data.') if not os.path.exists(tf_record_path_v): print('Generating new TFRecord at %s'%(tf_record_path_v)) prepare_tf_record(cfg,root_path_v,tf_record_path_v,final_params, train_or_test='test') else: print('TFRecord found at %s'%(tf_record_path_v)) # Get the normalzied pixel scale (will fail if tf_record has not been # correctly created.) normed_pixel_scale = get_normed_pixel_scale(cfg,pixel_scale) # We let keras deal with epochs instead of the tf dataset object. tf_dataset_t = data_tools.build_tf_dataset(tf_record_path_t,final_params, batch_size,n_epochs,baobab_config_path,norm_images=norm_images, shift_pixels=shift_pixels,shift_params=shift_params, normed_pixel_scale=normed_pixel_scale) # Validation dataset will, by default, have no augmentation but will have # the images normalized if requested. tf_dataset_v = data_tools.build_tf_dataset(tf_record_path_v,final_params, batch_size,1,baobab_config_path,norm_images=norm_images) print('Initializing the model') model, loss = model_loss_builder(cfg,verbose=True) tensorboard = TensorBoard(log_dir=tensorboard_log_dir,update_freq='batch') modelcheckpoint = ModelCheckpoint(model_weights,monitor='val_loss', save_best_only=True,save_freq='epoch') # TODO add validation data. model.fit(tf_dataset_t,callbacks=[tensorboard, modelcheckpoint], epochs=n_epochs, steps_per_epoch=steps_per_epoch, validation_data=tf_dataset_v)
if __name__ == '__main__': main()