diff --git a/baselines/run_GAReader.py b/baselines/run_GAReader.py index 36c713b..e88d547 100644 --- a/baselines/run_GAReader.py +++ b/baselines/run_GAReader.py @@ -249,7 +249,6 @@ def main(config, model_filename): if __name__ == "__main__": - model_name = "GAReader" data_dir = "./data/imperceptibility/training_data" embedding_folder = "./baselines/embeddings/" ## @@ -262,6 +261,7 @@ def main(config, model_filename): if model_name == "GAReader": from baselines.GAReader import args, GAReader + main( args.get_args(data_dir, cache_dir, embedding_folder, output_dir, log_dir), model_filename, diff --git a/baselines/utils/arc_embedding_utils.py b/baselines/utils/arc_embedding_utils.py index 321b73f..9b5e2f1 100644 --- a/baselines/utils/arc_embedding_utils.py +++ b/baselines/utils/arc_embedding_utils.py @@ -42,7 +42,10 @@ def load_data( ) print( - "the size of train: {}, dev:{},".format(len(train.examples), len(dev.examples),) + "the size of train: {}, dev:{},".format( + len(train.examples), + len(dev.examples), + ) ) word_field.build_vocab( diff --git a/eval.py b/eval.py index 2e3fa7c..58897e7 100644 --- a/eval.py +++ b/eval.py @@ -18,7 +18,9 @@ validation_dataset = ConcretenessDataset( - file_path=val_file_path, tokenizer=tokenizer, split="val", + file_path=val_file_path, + tokenizer=tokenizer, + split="val", ) val_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False) embeddings = GloveEmbedding( diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py index 652b694..ed0a8e1 100644 --- a/src/datasets/__init__.py +++ b/src/datasets/__init__.py @@ -1,2 +1,2 @@ from src.datasets.cloze_dataset import * -from src.datasets.max_cloze_dataset import * \ No newline at end of file +from src.datasets.max_cloze_dataset import * diff --git a/src/improvement_methods/improve.py b/src/improvement_methods/improve.py new file mode 100644 index 0000000..7d2af81 --- /dev/null +++ b/src/improvement_methods/improve.py @@ -0,0 +1,183 @@ +import json +import argparse 
"""Heuristic post-processing of cloze predictions using statistical embeddings.

The module offers three steps:

* ``generate_cloze_predictions`` runs a trained cloze model over a JSON-lines
  dataset and collects the per-option scores.
* ``improvement_methods`` optionally replaces the predicted option index by
  the runner-up, based on a feature-wise comparison of the two options'
  statistical embeddings.
* ``write_to_csv`` writes ``id,label`` rows.

Heavy / project-local dependencies are imported lazily so that the pure
post-processing helpers stay importable on their own.
"""
import argparse
import copy  # kept from the original import list
import heapq
import importlib
import json
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

# Directory of this script; the default config paths below are joined to it.
# NOTE(review): confirm that ./configs is expected to live next to this file.
dirname = os.path.dirname(os.path.abspath(__file__))


def _str_to_bool(value):
    """Parse a command-line boolean (``true``/``false``/``1``/``0``/...).

    ``type=str`` would make the string ``"False"`` truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser(
    prog="improve.py",
    description="Apply improvement approach to imperceptibility methods",
)

parser.add_argument(
    "--model",
    type=str,
    action="store",
    help="The configuration for model",
    default=os.path.join(dirname, "./configs/models/forty/default.yaml"),
)
parser.add_argument(
    "--data",
    type=str,
    action="store",
    help="The configuration for data",
    default=os.path.join(dirname, "./configs/datasets/forty/default.yaml"),
)
parser.add_argument(
    "--trained_model_path",
    type=str,
    help="Path of the trained model's path",
    default="/content/drive/MyDrive/SemEval/SemEval_final/distilbert_train_trial/ReCAM-final/ckpts_old/all_ckpts/3_5600.pth",
)
parser.add_argument(
    "--test_configuration",
    help="Whether test data is being used.",
    type=_str_to_bool,
    default=False,
)
parser.add_argument(
    "--improvement_method",
    help="Select between: Thresholding Method(threshold), Difference Method(difference), Second Highest Probability Method(second_highest)",
    type=str,
    default="threshold",
)

# Runtime state shared by the functions below. It is filled in by
# ``configure`` (called automatically when the file is run as a script) so
# that importing this module never parses ``sys.argv`` or touches the disk.
dataset_config = None  # Config built from ``--data``
model_config = None  # Config built from ``--model``
path = None  # checkpoint path from ``--trained_model_path``
test_flag = False  # True when datapoints carry an "id" to be written out
emb = None  # lazily created StatisticalEmbedding(normalise=False)

# Names accepted by ``improvement_methods``.
_IMPROVEMENT_METHODS = ("threshold", "difference", "second_highest")


def _default_embedder():
    """Return the shared ``StatisticalEmbedding``, creating it on first use."""
    global emb
    if emb is None:
        from src.improvement_methods import StatisticalEmbedding

        emb = StatisticalEmbedding(normalise=False)
    return emb


def configure(args):
    """Populate the module-level runtime state from parsed CLI arguments.

    Args:
        args: Namespace produced by ``parser.parse_args()``.
    """
    global dataset_config, model_config, path, test_flag
    from src.utils.configuration import Config

    dataset_config = Config(path=args.data)
    model_config = Config(path=args.model)
    path = args.trained_model_path
    test_flag = args.test_configuration
    _default_embedder()


def generate_cloze_predictions(dataset_path, generate_hyponyms=False):
    """Run the trained cloze model over a JSON-lines dataset.

    Args:
        dataset_path: Path of a file holding one JSON object per line.
        generate_hyponyms: When true the model is expected to return a
            ``(cloze_prediction, bert_output)`` pair and ``bert_output`` is
            stored alongside each prediction.

    Returns:
        A list with one dict per batch containing ``cloze_prediction`` (list
        of floats), ``cloze_prediction_label`` (arg-max index), ``datapoint``
        (the parsed JSON line), optionally ``bert_output`` and, when
        ``test_flag`` is set, ``id``.

    Raises:
        RuntimeError: If ``configure`` has not been called.
        TypeError: If the checkpoint entry used as the model is not an
            ``nn.Module``.
    """
    if model_config is None or dataset_config is None or path is None:
        raise RuntimeError(
            "improve.py is not configured; call configure(parser.parse_args()) first"
        )

    from transformers import AutoTokenizer
    from src.datasets.cloze_dataset import ClozeDataset

    # The original did ``from src.models import *`` at import time.
    # NOTE(review): presumably needed so the model classes are importable when
    # the checkpoint is unpickled -- confirm.
    importlib.import_module("src.models")

    with open(dataset_path) as f:
        # Blank lines (e.g. a trailing newline) are skipped instead of
        # crashing ``json.loads``.
        datapoints = [json.loads(line) for line in f.read().splitlines() if line.strip()]

    model_name = model_config.params["pretrained_model_name_or_path"]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cloze_dataset = ClozeDataset(dataset_config, tokenizer)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(path, map_location=device)
    model = checkpoint["model_state_dict"]
    # NOTE(review): the original code uses this checkpoint entry directly as
    # the model. If it is a plain state dict, the model has to be constructed
    # and ``load_state_dict`` called instead -- that construction is not
    # visible from this file, so fail loudly rather than with an obscure
    # AttributeError.
    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            "checkpoint['model_state_dict'] is not an nn.Module; build the "
            "model and load the state dict into it before predicting"
        )
    model.to(device)

    dataloader = DataLoader(
        cloze_dataset,
        collate_fn=cloze_dataset.custom_collate_fn,
        batch_size=1,
        shuffle=False,
    )

    final_datapoints = []
    model.eval()
    with torch.no_grad():
        # batch_size=1 and shuffle=False: batch ``index`` is paired with line
        # ``index`` of the file (assumes ClozeDataset keeps the file order).
        for index, batch in enumerate(dataloader):
            *inputs, _label = [
                torch.as_tensor(value, device=device) for value in batch
            ]
            datapoint = datapoints[index]

            bert_output = None
            if generate_hyponyms:
                cloze_prediction, bert_output = model(inputs)
            else:
                cloze_prediction = model(inputs)

            record = {
                "cloze_prediction": [
                    float(score) for score in cloze_prediction.cpu().numpy()[0]
                ],
                "cloze_prediction_label": int(torch.argmax(cloze_prediction).item()),
                "datapoint": datapoint,
            }
            if bert_output is not None:
                record["bert_output"] = bert_output.detach().cpu().tolist()
            if test_flag:
                record["id"] = datapoint["id"]
            final_datapoints.append(record)
    return final_datapoints


def _runner_up_preferred(embedder, top_option, runner_up_option):
    """Return True when the runner-up wins the feature-wise comparison.

    The runner-up is preferred when it is strictly greater than the top
    option on at least half of the features on which the two differ (which
    includes the case where no feature differs, as in the original code).
    """
    top_vec = np.array(embedder.get_embedding(top_option))
    runner_up_vec = np.array(embedder.get_embedding(runner_up_option))
    differing = runner_up_vec.shape[0] - (runner_up_vec == top_vec).sum()
    return bool((runner_up_vec > top_vec).sum() >= differing / 2)


def improvement_methods(p_value, final_datapoints, method="threshold", embedder=None):
    """Switch uncertain predictions to the runner-up option, in place.

    For every record the option scores are soft-maxed and the two best
    options are found. Depending on ``method`` the record is considered
    uncertain when

    * ``"threshold"``: the best probability is below ``p_value``;
    * ``"difference"``: best minus runner-up probability is below ``p_value``;
    * ``"second_highest"``: the runner-up probability is above ``p_value``.

    For uncertain records the statistical embeddings of the two option texts
    are compared, and ``cloze_prediction_label`` is set to the runner-up's
    *index* when the runner-up wins. (The original stored the option text,
    which broke ``write_to_csv``.)

    Args:
        p_value: Probability threshold, interpreted per ``method``.
        final_datapoints: Records as produced by ``generate_cloze_predictions``.
        method: One of ``"threshold"``, ``"difference"``, ``"second_highest"``.
        embedder: Object with ``get_embedding(text)``; defaults to the shared
            ``StatisticalEmbedding(normalise=False)``.

    Returns:
        The same ``final_datapoints`` list (records are mutated in place).

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method not in _IMPROVEMENT_METHODS:
        raise ValueError(
            f"unknown improvement method {method!r}; expected one of {_IMPROVEMENT_METHODS}"
        )
    if embedder is None:
        embedder = _default_embedder()

    for record in final_datapoints:
        scores = list(record["cloze_prediction"])
        if len(scores) < 2:
            # No runner-up to fall back on.
            continue
        probs = torch.softmax(torch.tensor(scores, dtype=torch.float32), dim=0)
        top, runner_up = heapq.nlargest(
            2, range(len(scores)), key=scores.__getitem__
        )

        if method == "threshold":
            uncertain = probs[top] < p_value
        elif method == "difference":
            uncertain = probs[top] - probs[runner_up] < p_value
        else:  # "second_highest"
            uncertain = probs[runner_up] > p_value
        if not bool(uncertain):
            continue

        options = record["datapoint"]
        top_option = options["option_" + str(top)]
        runner_up_option = options["option_" + str(runner_up)]
        if _runner_up_preferred(embedder, top_option, runner_up_option):
            record["cloze_prediction_label"] = runner_up
    return final_datapoints


def write_to_csv(final_datapoints, output_path="output.csv"):
    """Write one ``id,label`` line per record.

    The id is the record's ``"id"`` when ``test_flag`` is set, otherwise the
    record's position in ``final_datapoints``.

    Args:
        final_datapoints: Records holding ``cloze_prediction_label``.
        output_path: Destination file; defaults to ``output.csv`` in the
            current working directory, as before.
    """
    lines = []
    for index, record in enumerate(final_datapoints):
        row_id = record["id"] if test_flag else index
        lines.append(f"{row_id},{int(record['cloze_prediction_label'])}\n")
    with open(output_path, "w") as f:
        f.writelines(lines)


if __name__ == "__main__":
    args = parser.parse_args()
    configure(args)
class StatisticalEmbedding:
    """Hand-crafted lexical feature vector for a word or short phrase.

    ``get_embedding`` returns 12 features: longest token length, number of
    WordNet senses, hyponym counts (first sense / averaged over senses),
    hypernym-tree depth (first sense / averaged over senses) and
    SentiWordNet positive / negative / (1 - objective) scores (first sense /
    averaged over senses).

    Args:
        normalise: When true the feature vector is divided by its L2 norm.
        wordnet: Object exposing ``synsets(word)``; defaults to the
            module-level NLTK ``wn`` reader.
        sentiwordnet: Object exposing ``senti_synsets(word)``; defaults to the
            module-level NLTK ``swn`` reader.
    """

    # Depth features are reported as ``_DEPTH_OFFSET - depth``.
    # NOTE(review): presumably so that a larger value means a shallower
    # (more general) word -- confirm with the feature's author.
    _DEPTH_OFFSET = 100000

    def __init__(self, normalise=True, wordnet=None, sentiwordnet=None):
        # add word frequency later
        # try to fix number of senses and add it later
        # try to fix number of hyponyms and add it later
        self.normalise = normalise
        self._wordnet = wn if wordnet is None else wordnet
        self._sentiwordnet = swn if sentiwordnet is None else sentiwordnet

    def get_embedding(self, word):
        """Return the 12-feature vector for ``word``.

        Returns a plain list, or a NumPy array when ``normalise`` is set
        (the division by the norm converts it).
        """
        len_embedding = self.get_length_of_word(word)
        sense_embedding = self.get_number_of_senses(word)
        hyponym_embedding = self.get_no_of_hyponyms(word)
        avg_hyponym_embedding = self.get_avg_no_of_hyponyms(word)
        depth_hypernymy_embedding = self.get_depth_of_hypernymy_tree(word)
        avg_depth_hypernymy_embedding = self.get_avg_depth_of_hypernymy_tree(word)
        pos_neg_obj_score = self.get_pos_neg_obj_scores(word)
        avg_pos_neg_obj_score = self.get_avg_pos_neg_obj_scores(word)

        embedding = [
            len_embedding,
            sense_embedding,
            hyponym_embedding,
            avg_hyponym_embedding,
            depth_hypernymy_embedding,
            avg_depth_hypernymy_embedding,
            pos_neg_obj_score[0],
            pos_neg_obj_score[1],
            pos_neg_obj_score[2],
            avg_pos_neg_obj_score[0],
            avg_pos_neg_obj_score[1],
            avg_pos_neg_obj_score[2],
        ]
        if self.normalise:
            embedding = embedding / norm(embedding)
        return embedding

    def get_length_of_word(self, word):
        """Return the length of the longest space-separated token."""
        return max(len(token) for token in word.split(" "))

    def get_number_of_senses(self, word):
        """Return the number of synsets for ``word`` (looked up as given, not split)."""
        return len(self._wordnet.synsets(word))

    @staticmethod
    def _longest_hypernym_path(synset):
        """Return the length of the longest hypernym path of ``synset``."""
        return max(len(path) for path in synset.hypernym_paths())

    def get_depth_of_hypernymy_tree(self, word):
        """Return ``_DEPTH_OFFSET`` minus the deepest first-sense hypernym path.

        The maximum is taken over the space-separated tokens; tokens without
        synsets contribute a depth of 0.
        """
        deepest = 0
        for token in word.split(" "):
            synsets = self._wordnet.synsets(token)
            if synsets:
                deepest = max(deepest, self._longest_hypernym_path(synsets[0]))
        return self._DEPTH_OFFSET - deepest

    def get_avg_depth_of_hypernymy_tree(self, word):
        """Return ``_DEPTH_OFFSET`` minus the largest per-token average depth.

        For each token the longest hypernym path of every sense is averaged;
        the maximum over tokens is used, matching how the other per-token
        features are aggregated. Single-token input gives the same value as
        before; previously only one token of a multi-word input was ever
        taken into account. Tokens without synsets contribute 0, so a phrase
        with no known token yields ``_DEPTH_OFFSET``.
        """
        best_average = 0
        for token in word.split(" "):
            depths = [
                self._longest_hypernym_path(synset)
                for synset in self._wordnet.synsets(token)
            ]
            if depths:
                best_average = max(best_average, sum(depths) / len(depths))
        return self._DEPTH_OFFSET - best_average

    def get_pos_neg_obj_scores(self, word):
        """Return ``(max pos, max neg, 1 - max obj)`` over the tokens' first senses.

        Tokens without a SentiWordNet entry count as 0 for all three scores.
        """
        pos_scores = []
        neg_scores = []
        obj_scores = []
        for token in word.split(" "):
            senses = list(self._sentiwordnet.senti_synsets(token))
            if senses:
                first = senses[0]
                pos_scores.append(first.pos_score())
                neg_scores.append(first.neg_score())
                obj_scores.append(first.obj_score())
            else:
                pos_scores.append(0)
                neg_scores.append(0)
                obj_scores.append(0)
        return (max(pos_scores), max(neg_scores), 1 - max(obj_scores))

    def get_avg_pos_neg_obj_scores(self, word):
        """Return ``(max pos, max neg, 1 - max obj)`` of per-token sense averages.

        Each token's scores are averaged over all of its SentiWordNet senses;
        tokens without an entry count as 0.
        """
        pos_scores = []
        neg_scores = []
        obj_scores = []
        for token in word.split(" "):
            senses = list(self._sentiwordnet.senti_synsets(token))
            count = len(senses)
            if count > 0:
                pos_scores.append(sum(s.pos_score() for s in senses) / count)
                neg_scores.append(sum(s.neg_score() for s in senses) / count)
                obj_scores.append(sum(s.obj_score() for s in senses) / count)
            else:
                pos_scores.append(0)
                neg_scores.append(0)
                obj_scores.append(0)
        return (max(pos_scores), max(neg_scores), 1 - max(obj_scores))

    def get_no_of_hyponyms(self, word):
        """Return the number of lemma names among the first sense's hyponyms (0 if unknown)."""
        synsets = self._wordnet.synsets(word)
        if not synsets:
            return 0
        return sum(len(hyponym.lemma_names()) for hyponym in synsets[0].hyponyms())

    def get_avg_no_of_hyponyms(self, word):
        """Return the hyponym lemma-name count averaged over all senses (0 if unknown)."""
        synsets = self._wordnet.synsets(word)
        if not synsets:
            return 0
        total = sum(
            len(hyponym.lemma_names())
            for synset in synsets
            for hyponym in synset.hyponyms()
        )
        return total / len(synsets)
a/src/shelved_approaches/datasets/transformers_concreteness_dataset.py +++ b/src/shelved_approaches/datasets/transformers_concreteness_dataset.py @@ -27,7 +27,9 @@ class TransformersConcretenessDataset(Dataset): """ def __init__( - self, config, tokenizer, + self, + config, + tokenizer, ): """ diff --git a/src/shelved_approaches/models/bert_cloze_statistical_embeddings.py b/src/shelved_approaches/models/bert_cloze_statistical_embeddings.py index d5f6b13..19a31cf 100644 --- a/src/shelved_approaches/models/bert_cloze_statistical_embeddings.py +++ b/src/shelved_approaches/models/bert_cloze_statistical_embeddings.py @@ -12,83 +12,98 @@ from numpy.linalg import norm -nltk.download('punkt') +nltk.download("punkt") import spacy -nlp = spacy.load('en', disable=['parser', 'ner']) -nltk.download('sentiwordnet') +nlp = spacy.load("en", disable=["parser", "ner"]) + +nltk.download("sentiwordnet") from nltk.corpus import sentiwordnet as swn from itertools import chain -nltk.download('wordnet') + +nltk.download("wordnet") from nltk.corpus import wordnet as wn + class StatisticalEmbedding: - def __init__(self,normalise=True): + def __init__(self, normalise=True): # add word frequency later # try to fix number of senses and add it later # try to fix number of hyponyms and add it later self.normalise = normalise - def get_embedding(self,word): + def get_embedding(self, word): len_embedding = self.get_length_of_word(word) depth_hypernymy_embedding = self.get_depth_of_hypernymy_tree(word) avg_depth_hypernymy_embedding = self.get_avg_depth_of_hypernymy_tree(word) pos_neg_obj_score = self.get_pos_neg_obj_scores(word) avg_pos_neg_obj_score = self.get_avg_pos_neg_obj_scores(word) - embedding = [len_embedding,depth_hypernymy_embedding,avg_depth_hypernymy_embedding,pos_neg_obj_score[0],pos_neg_obj_score[1],pos_neg_obj_score[2],avg_pos_neg_obj_score[0],avg_pos_neg_obj_score[1],avg_pos_neg_obj_score[2]] - if(self.normalise): - embedding = embedding/norm(embedding) + embedding = [ + 
len_embedding, + depth_hypernymy_embedding, + avg_depth_hypernymy_embedding, + pos_neg_obj_score[0], + pos_neg_obj_score[1], + pos_neg_obj_score[2], + avg_pos_neg_obj_score[0], + avg_pos_neg_obj_score[1], + avg_pos_neg_obj_score[2], + ] + if self.normalise: + embedding = embedding / norm(embedding) return embedding - def get_length_of_word(self,word): - words = word.split(' ') + def get_length_of_word(self, word): + words = word.split(" ") lengths = [len(word) for word in words] max_len = max(lengths) return max_len - def get_depth_of_hypernymy_tree(self,word): + def get_depth_of_hypernymy_tree(self, word): max_len_paths = 0 - words = word.split(' ') + words = word.split(" ") for word_n in words: - if(len(wn.synsets(word_n))>0): + if len(wn.synsets(word_n)) > 0: j = wn.synsets(word_n)[0] paths_to_top = j.hypernym_paths() - max_len_paths = max(max_len_paths,len(max(paths_to_top, key = lambda i: len(i)))) + max_len_paths = max( + max_len_paths, len(max(paths_to_top, key=lambda i: len(i))) + ) return max_len_paths - def get_avg_depth_of_hypernymy_tree(self,word): - words = word.split(' ') + def get_avg_depth_of_hypernymy_tree(self, word): + words = word.split(" ") lst_avg_len_paths = [] for word_n in words: i = 0 avg_len_paths = 0 - + for j in wn.synsets(word_n): paths_to_top = j.hypernym_paths() - max_len_path = len(max(paths_to_top, key = lambda k: len(k))) + max_len_path = len(max(paths_to_top, key=lambda k: len(k))) avg_len_paths += max_len_path i += 1 - if(i>0): - return avg_len_paths/i + if i > 0: + return avg_len_paths / i else: return 0 - def get_pos_neg_obj_scores(self,word): - words = word.split(' ') + def get_pos_neg_obj_scores(self, word): + words = word.split(" ") pos_scores = [] neg_scores = [] obj_scores = [] - + for word_n in words: - if(len(list(swn.senti_synsets(word_n)))>0): + if len(list(swn.senti_synsets(word_n))) > 0: j = list(swn.senti_synsets(word_n))[0] - + pos_scores.append(j.pos_score()) neg_scores.append(j.neg_score()) 
obj_scores.append(j.obj_score()) @@ -96,14 +111,14 @@ def get_pos_neg_obj_scores(self,word): pos_scores.append(0) neg_scores.append(0) obj_scores.append(0) - return (max(pos_scores),max(neg_scores),max(obj_scores)) + return (max(pos_scores), max(neg_scores), max(obj_scores)) - def get_avg_pos_neg_obj_scores(self,word): - words = word.split(' ') + def get_avg_pos_neg_obj_scores(self, word): + words = word.split(" ") pos_scores = [] neg_scores = [] obj_scores = [] - + for word_n in words: ct = 0 avg_pos_score = 0 @@ -116,15 +131,15 @@ def get_avg_pos_neg_obj_scores(self,word): avg_obj_score += j.obj_score() ct += 1 - if(ct>0): - pos_scores.append(avg_pos_score/ct) - neg_scores.append(avg_neg_score/ct) - obj_scores.append(avg_obj_score/ct) + if ct > 0: + pos_scores.append(avg_pos_score / ct) + neg_scores.append(avg_neg_score / ct) + obj_scores.append(avg_obj_score / ct) else: pos_scores.append(0) neg_scores.append(0) obj_scores.append(0) - return (max(pos_scores),max(neg_scores),max(obj_scores)) + return (max(pos_scores), max(neg_scores), max(obj_scores)) def gelu(x): @@ -224,8 +239,8 @@ def __init__(self, config): self.vocab_size = self.bert.embeddings.word_embeddings.weight.size(0) self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path) self.emb = StatisticalEmbedding() - self.fc1 = nn.Linear(45,5) - self.fc2 = nn.Linear(10,5) + self.fc1 = nn.Linear(45, 5) + self.fc2 = nn.Linear(10, 5) def init_weights(self, module): @@ -250,7 +265,7 @@ def forward(self, x_input): """ articles, articles_mask, ops, question_pos = x_input option_tokens = ops.clone() - + ### BERT CLOZE bsz = ops.size(0) @@ -279,17 +294,23 @@ def forward(self, x_input): for batch_of_options in range(bsz): batch_op = [] for option_idx in range(option_tokens.shape[1]): - opt_str = self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(option_tokens[batch_of_options][option_idx].detach().tolist())) + opt_str = self.tokenizer.convert_tokens_to_string( + 
self.tokenizer.convert_ids_to_tokens( + option_tokens[batch_of_options][option_idx].detach().tolist() + ) + ) emb = self.emb.get_embedding(opt_str) batch_op.append(torch.tensor(emb)) - option_strings.append(torch.cat(batch_op,dim=0)) + option_strings.append(torch.cat(batch_op, dim=0)) - option_strings = torch.as_tensor(torch.stack(option_strings),dtype=torch.float32,device=out.device) + option_strings = torch.as_tensor( + torch.stack(option_strings), dtype=torch.float32, device=out.device + ) linguistic_output = self.fc1(option_strings) - + ### COMMON NETWORK - bert_cat_linguistic = torch.cat((out,linguistic_output), dim=1) + bert_cat_linguistic = torch.cat((out, linguistic_output), dim=1) final_output = self.fc2(bert_cat_linguistic) - return final_output \ No newline at end of file + return final_output diff --git a/src/trainers/cloze_trainer.py b/src/trainers/cloze_trainer.py index 35ca202..d60d686 100644 --- a/src/trainers/cloze_trainer.py +++ b/src/trainers/cloze_trainer.py @@ -182,8 +182,14 @@ def train(self, model, train_dataset, val_dataset=None, logger=None): train_scores = dict( zip( - [train_loss_name,] + metric_name_list, - [training_loss,] + metric_list, + [ + train_loss_name, + ] + + metric_name_list, + [ + training_loss, + ] + + metric_list, ) ) @@ -278,7 +284,7 @@ def train(self, model, train_dataset, val_dataset=None, logger=None): + str(epoch) + ".pth", ) - ''' + """ if epoch == max_epochs: print("\nEvaluating\n") val_scores = self.val( @@ -439,7 +445,8 @@ def train(self, model, train_dataset, val_dataset=None, logger=None): global_step, append_text=self.train_config.append_text, ) - ''' + """ + ## Need to check if we want same loggers of different loggers for train and eval ## Evaluate @@ -455,7 +462,18 @@ def log( append_text, ): - return_dic = dict(zip([loss_name,] + metric_name_list, [loss,] + metric_list,)) + return_dic = dict( + zip( + [ + loss_name, + ] + + metric_name_list, + [ + loss, + ] + + metric_list, + ) + ) loss_name = 
f"{append_text}_{self.log_label}_{loss_name}" if log_values["loss"]: @@ -544,14 +562,25 @@ def val( val_loss_name = self.train_config.criterion.type all_outputs = torch.argmax(all_outputs, axis=1) metric_list = [ - metric(all_outputs.detach().cpu(), all_labels.cpu(), **self.metrics[metric]) + metric( + all_outputs.detach().cpu(), all_labels.cpu(), **self.metrics[metric] + ) for metric in self.metrics ] metric_name_list = [ metric["type"] for metric in self._config.main_config.metrics ] return_dic = dict( - zip([val_loss_name,] + metric_name_list, [loss,] + metric_list,) + zip( + [ + val_loss_name, + ] + + metric_name_list, + [ + loss, + ] + + metric_list, + ) ) if log: val_scores = self.log( diff --git a/src/utils/integrated_gradients.py b/src/utils/integrated_gradients.py index 13827a3..54085fb 100644 --- a/src/utils/integrated_gradients.py +++ b/src/utils/integrated_gradients.py @@ -353,4 +353,4 @@ def get_extended_attention_mask(self, attention_mask, input_shape, dtype): extended_attention_mask = tf.cast(extended_attention_mask, dtype) extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - return extended_attention_mask \ No newline at end of file + return extended_attention_mask diff --git a/src/utils/misc.py b/src/utils/misc.py index 1d9fc18..5dd361a 100644 --- a/src/utils/misc.py +++ b/src/utils/misc.py @@ -96,15 +96,31 @@ def generate_grid_search_configs(main_config, grid_config, root="hyperparams"): continue if "log_label" in root.keys(): - log_label_path = copy.deepcopy(stack + ["log_label",]) + log_label_path = copy.deepcopy( + stack + + [ + "log_label", + ] + ) if "log_label" in root.keys(): - log_label_path = copy.deepcopy(stack + ["log_label",]) + log_label_path = copy.deepcopy( + stack + + [ + "log_label", + ] + ) parent = root ## Otherwise it has children for key in parent.keys(): ## For the children if ( - ".".join(stack + [key,]) not in visited + ".".join( + stack + + [ + key, + ] + ) + not in visited ): ## Check if I have visited 
these children flag = 1 ## If not, we need to repeat the process for this key stack.append(key) ## Append this key to the stack