# many lines omitted above


def _experiment_path(experiment_dir, filename):
    """Return the path of ``filename`` inside ``experiment_dir``.

    ``os.path.join`` gives the same result as the old string concatenation
    when ``experiment_dir`` ends with a separator. Without a trailing
    separator, plain concatenation produced paths like ``runs/exp1config.txt``.
    """
    import os
    return os.path.join(experiment_dir, filename)


def make_log(experiment_dir, X_train, X_test, Y_test, model, hist, custom_model):
    """Write a human-readable summary of a training run to ``run_info.txt``.

    The summary contains the run timestamp, the current git commit, whether a
    custom model was used, dataset sizes, a precision/recall table over a
    fixed set of cutpoints, and the per-epoch training history.

    Args:
        experiment_dir: Folder of the run; ``run_info.txt`` is written there.
        X_train: Training inputs. Only ``shape[0]`` (examples) and
            ``shape[2]`` (reported as the feature count) are read.
        X_test, Y_test: Held-out data passed to ``precision_recall``.
        model: The trained model, passed to ``precision_recall``.
        hist: Object with a ``.history`` mapping (as returned by ``fit``),
            tabulated with pandas.
        custom_model: Flag recorded in the log (``True`` if the experiment's
            own ``custom_model.py`` built the model).

    Raises:
        subprocess.CalledProcessError: If ``git rev-parse HEAD`` fails, e.g.
            when run outside a git checkout.
    """
    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # check_output returns bytes; decode so the log shows the bare hash
    # instead of a "b'...'" repr.
    commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('utf-8', 'replace').strip()

    # Precision and recall at a range of cutpoints.
    cutoffs = [0.01, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60]
    precrecs = [precision_recall(X_test, Y_test, model, cutoff) for cutoff in cutoffs]
    precrec_table = str(pd.DataFrame(
        precrecs,
        columns=['cutpoint', 'precision', 'recall', 'fscore', 'div_zero_perc']))
    hist_table = str(pd.DataFrame(hist.history))

    # NOTE(review): shape[2] assumes X_train is at least 3-D with features on
    # axis 2 -- confirm against make_CNN_matrix.
    run_info = (
        "Model ran at {0} using code with latest commit {1}. Custom model: {2}\n"
        " Total training data was {3} examples, with {4} features. "
        "Evaluated on {5} evaluation examples.\n"
        " Precision/recall metrics:\n{6}\n"
        " Training history:\n {7}"
    ).format(now, commit, str(custom_model), X_train.shape[0], X_train.shape[2],
             X_test.shape[0], precrec_table, hist_table)

    with open(_experiment_path(experiment_dir, "run_info.txt"), "w", encoding="utf-8") as f:
        f.write(run_info)


def _confirm_clean_worktree():
    """Ask for confirmation if ``event_model.py`` shows up in ``git status``.

    The user must type ``testing`` to proceed. Any other answer exits the
    process with status 0.
    """
    git_status = subprocess.check_output(['git', 'status']).decode('utf-8', 'replace')
    if "event_model.py" in git_status:
        print("You have uncommitted changes to `event_model.py`. Please commit them before proceeding to ensure reproducibility.")
        answer = input("Type 'testing' to continue or anything else to quit: ")
        if answer != "testing":
            print("Bye!")
            sys.exit(0)


def _load_formatted_data(config):
    """Return formatted data, either from the pickle cache or rebuilt from source.

    ``[Data] use_cache`` is parsed as a boolean. configparser values are
    strings, so testing the raw value treated ``use_cache = False`` as true.
    """
    data_config = config['Data']
    if data_config.getboolean('use_cache', fallback=False):
        print("Using cached formatted data. This is much faster, but changes to the feature factory won't appear if you do this!")
        # pickle can execute arbitrary code on load: only point cache_loc at
        # files you produced yourself.
        with open(data_config['cache_loc'], "rb") as f:
            return pickle.load(f)

    print("Regenerating formatted data from scratch.")
    try:
        spacy.load(str(config['Model']['nlp_model']))
    except Exception:  # best-effort; no longer swallows KeyboardInterrupt
        print("Tried to load custom spaCy model but failed. Falling back to en_core_web_sm")
    # NOTE(review): the loaded spaCy model is not passed to import_data(), and
    # no en_core_web_sm fallback is actually loaded despite the message above.
    # As written, this only reports whether the configured model can be
    # loaded. Confirm the intent.
    return import_data(minerva_dir=data_config['minerva_dir'],
                       prodigy_dir=data_config['prodigy_dir'])


def _build_model(experiment_dir, X, Y, model_config):
    """Build the CNN, preferring ``make_CNN_model`` from the run folder's ``custom_model.py``.

    Returns:
        ``(model, used_custom_model)``.
    """
    hyperparams = dict(
        filter_size=int(model_config['filter_size']),
        conv_dropout=float(model_config['conv_dropout']),
        conv_activation=str(model_config['conv_activation']),
        dense_units=int(model_config['dense_units']),
        dense_dropout=float(model_config['dense_dropout']),
        dense_activation=str(model_config['dense_activation']),
    )
    # Put the run folder first so its custom_model.py wins over any
    # same-named module elsewhere on sys.path.
    if experiment_dir not in sys.path:
        sys.path.insert(0, experiment_dir)
    try:
        import custom_model as custom_module
    except ImportError as err:
        # Only fall back when custom_model itself is absent. An ImportError
        # raised *inside* a custom model is a real error and must surface.
        if err.name != "custom_model":
            raise
        return make_CNN_model(X, Y, **hyperparams), False
    model = custom_module.make_CNN_model(X, Y, **hyperparams)
    print("Using a custom model.")
    return model, True


def _training_schedule(model_config):
    """Return ``(epochs, callbacks)`` for ``model.fit`` from the ``[Model]`` section.

    With early stopping enabled, training is capped at 40 epochs and stops
    once ``val_categorical_accuracy`` has not improved for ``patience``
    epochs. ``patience`` defaults to 4 when missing or not an integer.
    Otherwise ``epochs`` is read from the config and no callbacks are used.
    """
    if model_config.getboolean('early_stopping', fallback=False):
        print("Using early stopping.")
        try:
            patience = int(model_config['patience'])
        except (KeyError, ValueError):
            patience = 4
        return 40, [EarlyStopping(monitor='val_categorical_accuracy', patience=patience)]
    epochs = int(model_config['epochs'])
    print("Using {0} epochs".format(epochs))
    return epochs, []


@plac.annotations(
    experiment_dir=("Location of the run's folder with a config file", "option", "i", str))
def main(experiment_dir):
    """Train and evaluate a CNN event model for one experiment folder.

    Reads ``config.txt`` from ``experiment_dir``, loads or rebuilds the data,
    builds the model (custom or stock), and trains it with a 20% validation
    split. It then writes ``run_info.txt``, ``model_diagram.pdf`` and
    ``CNN_event.h5`` back into the same folder.
    """
    if not experiment_dir:
        sys.exit("Please pass the experiment folder (option -i).")

    _confirm_clean_worktree()

    config = ConfigParser()
    config_path = _experiment_path(experiment_dir, "config.txt")
    # ConfigParser.read silently skips unreadable files; fail loudly instead
    # of with a later KeyError on a missing section.
    if not config.read(config_path):
        sys.exit("Could not read config file {0}".format(config_path))

    print("Importing data...")
    formatted = _load_formatted_data(config)
    encoder = Encoder()
    X, _, Y = make_CNN_matrix(formatted, encoder)
    X_train, Y_train, X_test, Y_test = train_test_split(X, Y)

    model, used_custom_model = _build_model(experiment_dir, X, Y, config['Model'])
    epochs, callbacks = _training_schedule(config['Model'])

    hist = model.fit(X_train, Y_train,
                     epochs=epochs,
                     batch_size=int(config['Model']['batch_size']),
                     validation_split=0.2,
                     callbacks=callbacks)

    make_log(experiment_dir, X_train, X_test, Y_test, model, hist, used_custom_model)
    plot_model(model, show_shapes=True,
               to_file=_experiment_path(experiment_dir, "model_diagram.pdf"))
    model_file = _experiment_path(experiment_dir, "CNN_event.h5")
    model.save(model_file)
    print("Completed run. Wrote results out to ", experiment_dir)


if __name__ == '__main__':
    plac.call(main)