model_trainer – BNN Training Module
This script initializes and trains a BNN model on a strong lensing image dataset.
Examples
python -m model_trainer configs/t1.json
-
ovejero.model_trainer.config_checker(cfg)
Check that the configuration file meets ovejero's requirements. Raises an error if the configuration file is invalid.
Parameters: cfg (dict) – The dictionary obtained from reading the JSON config file.
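The kind of check described above can be sketched as follows. This is a minimal illustration, not ovejero's implementation: the helper `check_config`, the `REQUIRED` mapping, and every section and key name in it are hypothetical, and the real requirements are defined by ovejero itself.

```python
def check_config(cfg, required):
    """Raise ValueError if a required section or key is missing from cfg.

    `required` maps a section name to the list of keys that section must
    contain. The function and all names used with it are hypothetical.
    """
    for section, keys in required.items():
        if section not in cfg:
            raise ValueError(f"Missing config section: {section}")
        missing = [key for key in keys if key not in cfg[section]]
        if missing:
            raise ValueError(f"Section {section} is missing keys: {missing}")


# Hypothetical requirements, for illustration only.
REQUIRED = {
    "training_params": ["batch_size", "n_epochs"],
    "dataset_params": ["shift_params"],
}

good_cfg = {
    "training_params": {"batch_size": 10, "n_epochs": 5},
    "dataset_params": {"shift_params": ["lens_mass_center_x"]},
}

# A complete config passes silently; an incomplete one raises ValueError.
check_config(good_cfg, REQUIRED)
```

Failing early with a message that names the missing section or key tends to be easier to debug than a KeyError raised partway through training.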
-
ovejero.model_trainer.get_normed_pixel_scale(cfg, pixel_scale)
Return a dictionary with the pixel scale normalized according to the normalization of each shift parameter.
Parameters: - cfg (dict) – The dictionary obtained from reading the JSON config file.
- pixel_scale (float) – The pixel scale used for the original images.
Returns: A dictionary of the pixel scales renormalized in the same way as the shift parameters.
Return type: (dict)
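To illustrate the idea, the sketch below assumes each shift parameter was standardized as (x - mean) / std. Under that assumption, a shift of one pixel_scale in the original units becomes pixel_scale / std in normalized units, because the mean cancels in a difference. The helper `normed_pixel_scale`, the parameter names, and the standard deviations are all hypothetical; the actual function takes the config dictionary, and how it obtains the normalization constants is not shown here.

```python
def normed_pixel_scale(pixel_scale, shift_param_stds):
    """Express a one-pixel shift in the normalized units of each parameter.

    Assumes each parameter was standardized as (x - mean) / std, so an
    offset of `pixel_scale` maps to `pixel_scale / std`.
    """
    return {name: pixel_scale / std for name, std in shift_param_stds.items()}


# Hypothetical standard deviations for two shift parameters.
stds = {"lens_mass_center_x": 0.5, "lens_mass_center_y": 0.25}

scales = normed_pixel_scale(0.04, stds)
for name, value in scales.items():
    print(name, value)
```

A parameter with a smaller standard deviation gets a larger normalized pixel scale, since the same physical shift covers more of its normalized range.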
-
ovejero.model_trainer.load_config(config_path)
Load a configuration file from the path and check that it meets the requirements.
Parameters: config_path (str) – The path to the config file to be loaded.
Returns: A dictionary object with the contents of the config file.
Return type: (dict)
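The load-then-validate pattern described above might look like the following sketch, which uses only the standard library. The helper `load_json_config` is hypothetical and performs just a basic type check; the checks ovejero applies are not reproduced here.

```python
import json
import os
import tempfile


def load_json_config(config_path):
    """Read a JSON file and return its contents as a dictionary.

    Hypothetical helper: it only verifies that the top-level JSON value
    is an object, which stands in for a fuller requirements check.
    """
    with open(config_path, "r") as f:
        cfg = json.load(f)
    if not isinstance(cfg, dict):
        raise ValueError("Config file must contain a JSON object")
    return cfg


# Round trip through a temporary file so the example is self-contained.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example.json")
    with open(path, "w") as f:
        json.dump({"training_params": {"batch_size": 10}}, f)
    loaded = load_json_config(path)

print(loaded)
```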
-
ovejero.model_trainer.main()
Initializes and trains a BNN. The path to the config file is read from the command-line arguments.
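Reading a config path from the command line is commonly done with argparse, as in the sketch below. The helper `parse_args` and the argument name `config` are assumptions made for illustration; the parser that ovejero actually defines may differ.

```python
import argparse


def parse_args(argv=None):
    """Parse one positional argument: the path to the JSON config.

    Hypothetical helper. Passing `argv` explicitly lets the parser be
    exercised without touching sys.argv.
    """
    parser = argparse.ArgumentParser(
        description="Train a BNN from a JSON config file"
    )
    parser.add_argument("config", help="Path to the JSON config file")
    return parser.parse_args(argv)


# Equivalent to running the module with configs/t1.json as its argument.
args = parse_args(["configs/t1.json"])
print(args.config)
```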
-
ovejero.model_trainer.model_loss_builder(cfg, verbose=False)
Build a model according to the specifications in the configuration dictionary and return both the initialized model and the loss function.
Parameters: - cfg (dict) – The dictionary obtained from reading the JSON config file.
- verbose (bool) – If True, will be verbose as model is built.
Returns: A BNN model of the type specified in the config and a callable function to construct the TensorFlow graph for the loss.
Return type: (tf.keras.Model, function)
-
ovejero.model_trainer.prepare_tf_record(cfg, root_path, tf_record_path, final_params, train_or_test)
Prepare the TFRecord using the config file values.
Parameters: - cfg (dict) – The dictionary obtained from reading the JSON config file.
- root_path (str) – The root path that will contain all of the data including the lens parameters, the npy files, and the TFRecord.
- tf_record_path (str) – The path where the TFRecord will be saved.
- final_params ([str,..]) – The parameters we expect to be in the final set of lens parameters.
- train_or_test (str) – If test, the normalizations will be saved. If train, the training normalizations will be used.