From 1e327b170e04dfce2b6ac474ce964552922c8ffb Mon Sep 17 00:00:00 2001 From: markus583 Date: Sun, 5 May 2024 08:38:57 +0000 Subject: [PATCH] handle short-seqs --- wtpsplit/evaluation/intrinsic.py | 311 +++++++++++++++++++++++-------- 1 file changed, 234 insertions(+), 77 deletions(-) diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py index 234211ad..4a4dacf2 100644 --- a/wtpsplit/evaluation/intrinsic.py +++ b/wtpsplit/evaluation/intrinsic.py @@ -44,7 +44,7 @@ class Args: # } # } # TODO: for songs/etc., maybe feed in each sample separately? - eval_data_path: str = "data/all_data_02_05.pth" + eval_data_path: str = "data/all_data_04_05.pth" valid_text_path: str = None # "data/sentence/valid.parquet" device: str = "cpu" block_size: int = 512 @@ -53,9 +53,10 @@ class Args: include_langs: List[str] = None custom_language_list: str = None threshold: float = 0.01 - max_n_train_sentences: int = 1_000 + max_n_train_sentences: int = 1000 max_n_test_sentences: int = sys.maxsize save_suffix: str = "" + # XXX: these are not used in the current implementation! done within data.pth already. do_lowercase: bool = False do_remove_punct: bool = False keep_logits: bool = False @@ -101,33 +102,60 @@ def preprocess_zh_sentence(text, n=0): def process_logits(text, model, lang_code, args): # Extract necessary data - logits, offsets_mapping, tokenizer = extract( - [text], - model, - lang_code=lang_code, - stride=args.stride, - block_size=args.block_size, - batch_size=args.batch_size, - pad_last_batch=True, - verbose=False, - ) - logits = logits[0] - if offsets_mapping is not None: - offsets_mapping = offsets_mapping[0] + if isinstance(text, list): + logits = [] + for short_seq in tqdm(text, desc="Short sequences", disable=False): + current_logits, current_offsets_mapping, tokenizer = extract( + [short_seq], + model, + lang_code=lang_code, + stride=args.stride, + block_size=args.block_size, + batch_size=args.batch_size, + pad_last_batch=True, + verbose=False, + ) + current_logits = current_logits[0] + if current_offsets_mapping is not None: + current_offsets_mapping = current_offsets_mapping[0] + + if "xlm" in model.config.model_type: + tokens = tokenizer.tokenize(short_seq, verbose=False) - if "xlm" in model.config.model_type: - tokens = tokenizer.tokenize(text, verbose=False) + char_probs = token_to_char_probs(short_seq, tokens, current_logits, tokenizer, current_offsets_mapping) - # Use the vectorized function to convert token probabilities to character probabilities for the entire array - char_probs = token_to_char_probs(text, tokens, logits, tokenizer, offsets_mapping) + current_logits = char_probs + # TODO: extra treatment for Canine necessary? 
- logits = char_probs + logits.append(current_logits) + else: + logits, offsets_mapping, tokenizer = extract( + [text], + model, + lang_code=lang_code, + stride=args.stride, + block_size=args.block_size, + batch_size=args.batch_size, + pad_last_batch=True, + verbose=False, + ) + logits = logits[0] + if offsets_mapping is not None: + offsets_mapping = offsets_mapping[0] - if len(model.model.config.id2label) == 2: - # Igor's models: take winning logit - logits = np.expand_dims(logits.argmax(axis=1), axis=1) - # we apply sigmoid later; convert to fake logits - logits = np.log((logits + 1e-8) / (1 - logits + 1e-8)) + if "xlm" in model.config.model_type: + tokens = tokenizer.tokenize(text, verbose=False) + + # Use the vectorized function to convert token probabilities to character probabilities for the entire array + char_probs = token_to_char_probs(text, tokens, logits, tokenizer, offsets_mapping) + + logits = char_probs + + if len(model.model.config.id2label) == 2: + # Igor's models: take winning logit + logits = np.expand_dims(logits.argmax(axis=1), axis=1) + # we apply sigmoid later; convert to fake logits + logits = np.log((logits + 1e-8) / (1 - logits + 1e-8)) return logits @@ -164,11 +192,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st valid_sentences = [sample["text"].strip() for sample in valid_data if sample["lang"] == lang_code] assert len(valid_sentences) > 0 - valid_sentences = [ - corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct) - for sentence in valid_sentences - ] - separator = Constants.SEPARATORS[lang_code] + # valid_sentences = [ + # corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct) + # for sentence in valid_sentences + # ] + separator = Constants.SEPARATORS.get(lang_code, " ") valid_text = separator.join(valid_sentences) valid_logits = process_logits(valid_text, model, lang_code, args) @@ -211,39 +239,81 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st if "test_logits" not in dset_group: test_sentences = dataset["data"][: args.max_n_test_sentences] # if list of lists: flatten + # if isinstance(test_sentences[0], list): + # test_sentences = [item for sublist in test_sentences for item in sublist] + # test_sentences = [ + # corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct) + # for sentence in test_sentences + # ] + # test_sentences = [preprocess_zh_sentence(sentence, args.zh_window) for sentence in test_sentences] if isinstance(test_sentences[0], list): - test_sentences = [item for sublist in test_sentences for item in sublist] - test_sentences = [ - corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct) - for sentence in test_sentences - ] - test_sentences = [preprocess_zh_sentence(sentence, args.zh_window) for sentence in test_sentences] - test_text = Constants.SEPARATORS[lang_code].join(test_sentences) + # short-seq eval: list of lists + test_text = [ + Constants.SEPARATORS.get(lang_code, " ").join(sentence) for sentence in test_sentences + ] + else: + test_text = Constants.SEPARATORS.get(lang_code, " ").join(test_sentences) start_time = time.time() # Start timing for test logits processing test_logits = process_logits(test_text, model, lang_code, args) end_time = time.time() # End timing for test logits processing total_test_time += end_time - start_time # Accumulate test processing time - - test_labels = get_labels(lang_code, test_sentences, 
after_space=False) + if isinstance(test_sentences[0], list): + test_logit_lengths = [] + # store start and end indices for each pair, used later to slice the logits + # (h5py does not like different length np arrays as list elements) + all_logit_lengths = np.append(0, np.cumsum([len(logits) for logits in test_logits])) + # append tuple of start and end indices for each pair + for i in range(len(test_logits)): + test_logit_lengths.append((all_logit_lengths[i], all_logit_lengths[i + 1] - 1)) + test_logits = np.concatenate(test_logits) + # NOTE: handled differently than in intrinsic_pairwise.py + # here, we keep the label at the end + # in intrinsic_pairwise.py, we only consider the labels in the middle. + test_labels = [ + get_labels(lang_code, short_seq, after_space=False)[:-1] for short_seq in test_sentences + ] + + # flatten; append 0 eos to account for later indexing/slicing + test_labels = np.append(np.concatenate(test_labels), 1) + assert len(test_labels) == len(test_logits) + 1 + dset_group.create_dataset("test_logit_lengths", data=test_logit_lengths) + else: + test_labels = get_labels(lang_code, test_sentences, after_space=False) dset_group.create_dataset("test_logits", data=test_logits) dset_group.create_dataset("test_labels", data=test_labels) train_sentences = dataset["meta"].get("train_data") if train_sentences is not None and "train_logits" not in dset_group and not args.skip_adaptation: - if isinstance(train_sentences[0], list): - train_sentences = [item for sublist in train_sentences for item in sublist] - train_sentences = [ - corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct) - for sentence in train_sentences - ] - train_sentences = [preprocess_zh_sentence(sentence, args.zh_window) for sentence in train_sentences] + # if isinstance(train_sentences[0], list): + # train_sentences = [item for sublist in train_sentences for item in sublist] + # train_sentences = [ + # corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct) + # for sentence in train_sentences + # ] + # train_sentences = [preprocess_zh_sentence(sentence, args.zh_window) for sentence in train_sentences] train_sentences = train_sentences[: args.max_n_train_sentences] - train_text = Constants.SEPARATORS[lang_code].join(train_sentences) + if isinstance(train_sentences[0], list): + # short-seq eval: list of lists + train_text = [ + Constants.SEPARATORS.get(lang_code, " ").join(sentence) for sentence in train_sentences + ] + else: + train_text = Constants.SEPARATORS.get(lang_code, " ").join(train_sentences) train_logits = process_logits(train_text, model, lang_code, args) - train_labels = get_labels(lang_code, train_sentences, after_space=False) + if isinstance(train_sentences[0], list): + train_logits = np.concatenate(train_logits) + train_labels = [ + get_labels(lang_code, short_seq, after_space=False)[:-1] for short_seq in train_sentences + ] + + # flatten; append 0 eos to account for later indexing/slicing + train_labels = np.append(np.concatenate(train_labels), 1) + assert len(train_labels) == len(train_logits) + 1 + else: + train_labels = get_labels(lang_code, train_sentences, after_space=False) dset_group.create_dataset("train_logits", data=train_logits) dset_group.create_dataset("train_labels", data=train_labels) @@ -341,13 +411,13 @@ def main(args): for dataset_name, dataset in dsets["sentence"].items(): sentences = dataset["data"][: args.max_n_test_sentences] - if isinstance(sentences[0], list): - sentences = [item for sublist in sentences 
for item in sublist] - sentences = [ - corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct) - for sentence in sentences - ] - sentences = [preprocess_zh_sentence(sentence, args.zh_window) for sentence in sentences] + # if isinstance(sentences[0], list): + # sentences = [item for sublist in sentences for item in sublist] + # sentences = [ + # corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct) + # for sentence in sentences + # ] + # sentences = [preprocess_zh_sentence(sentence, args.zh_window) for sentence in sentences] # check if f[lang_code][dataset_name] exists if lang_code not in f or dataset_name not in f[lang_code]: continue @@ -364,13 +434,45 @@ def main(args): if clf[0] is not None: print(clf) - score_t, score_punct, _, t_indices, punct_indices = evaluate_mixture( - lang_code, - f[lang_code][dataset_name]["test_logits"][:], - sentences, - args.return_indices, - *clf, - ) + if isinstance(sentences[0], list): + acc_t, acc_punct = [], [] + score_t, score_punct = [], [] + t_indices, punct_indices = [], [] + for i, short_seq in enumerate(sentences): + start, end = f[lang_code][dataset_name]["test_logit_lengths"][i] + single_score_t, single_score_punct, info, cur_t_indices, cur_punct_indices = evaluate_mixture( + lang_code, + f[lang_code][dataset_name]["test_logits"][:][start:end], + list(short_seq), + args.return_indices, + *clf, + ) + score_t.append(single_score_t) + score_punct.append(single_score_punct) + acc_t.append(info["info_newline"]["correct_pairwise"] if info["info_newline"] else None) + acc_punct.append( + info["info_transformed"]["correct_pairwise"] if info["info_transformed"] else None + ) + # indices: accumulate from start + t_indices.extend( + [idx + start for idx in cur_t_indices["pred_indices"]] + if cur_t_indices and cur_t_indices["pred_indices"] + else [] + ) + punct_indices.extend( + [idx + start for idx in cur_punct_indices["pred_indices"]] + if cur_punct_indices and cur_punct_indices["pred_indices"] + else [] + ) + + else: + score_t, score_punct, _, t_indices, punct_indices = evaluate_mixture( + lang_code, + f[lang_code][dataset_name]["test_logits"][:], + sentences, + args.return_indices, + *clf, + ) clfs[lang_code][dataset_name] = clf @@ -381,25 +483,79 @@ def main(args): clf = [None, None, None, args.threshold] t_indices, punct_indices = None, None - score_u, _, _, u_indices, _ = evaluate_mixture( - lang_code, f[lang_code][dataset_name]["test_logits"][:], sentences, args.return_indices, *clf - ) + if isinstance(sentences[0], list): + acc_u = [] + score_u = [] + u_indices, true_indices = [], [] + length = 0 + for i, short_seq in enumerate(sentences): + start, end = f[lang_code][dataset_name]["test_logit_lengths"][i] + single_score_u, _, info, cur_u_indices, _ = evaluate_mixture( + lang_code, + f[lang_code][dataset_name]["test_logits"][:][start:end], + list(short_seq), + args.return_indices, + *clf, + ) + score_u.append(single_score_u) + acc_u.append(info["info_newline"]["correct_pairwise"]) + # indices: accumulate from start + u_indices.extend( + [idx + start for idx in cur_u_indices["pred_indices"]] if cur_u_indices["pred_indices"] else [] + ) + true_indices.extend( + [idx + start for idx in cur_u_indices["true_indices"]] if cur_u_indices["true_indices"] else [] + ) + length += cur_u_indices["length"] - 1 - results[lang_code][dataset_name] = { - "u": score_u, - "t": score_t, - "punct": score_punct, - } + else: + score_u, _, _, u_indices, _ = evaluate_mixture( + lang_code, 
f[lang_code][dataset_name]["test_logits"][:], sentences, args.return_indices, *clf + ) - if args.return_indices: - indices[lang_code][dataset_name] = { - "u": u_indices["pred_indices"], - "t": t_indices["pred_indices"] if t_indices is not None else None, - "punct": punct_indices["pred_indices"] if punct_indices is not None else None, - "true_indices": u_indices["true_indices"], - "length": u_indices["length"], + if isinstance(sentences[0], list): + score_u = np.mean(score_u) + score_t = np.mean(score_t) if score_t and not args.skip_adaptation else None + score_punct = ( + np.mean(score_punct) if score_punct and not (args.skip_punct or args.skip_adaptation) else None + ) + acc_u = np.mean(acc_u) + acc_t = np.mean(acc_t) if score_t else None + acc_punct = np.mean(acc_punct) if score_punct else None + + results[lang_code][dataset_name] = { + "u": score_u, + "t": score_t, + "punct": score_punct, + "acc_u": acc_u, + "acc_t": acc_t, + "acc_punct": acc_punct, + } + else: + results[lang_code][dataset_name] = { + "u": score_u, + "t": score_t, + "punct": score_punct, } + if args.return_indices: + if isinstance(sentences[0], list): + indices[lang_code][dataset_name] = { + "u": u_indices, + "t": t_indices, + "punct": punct_indices, + "true_indices": true_indices, + "length": length, + } + else: + indices[lang_code][dataset_name] = { + "u": u_indices["pred_indices"], + "t": t_indices["pred_indices"] if t_indices is not None else None, + "punct": punct_indices["pred_indices"] if punct_indices is not None else None, + "true_indices": u_indices["true_indices"], + "length": u_indices["length"], + } + if score_u is not None: u_scores.append((score_u, lang_code)) if score_t is not None: @@ -453,6 +609,7 @@ def main(args): Constants.CACHE_DIR / "intrinsic" / f"{save_str}_IDX.json", "w", ), + default=int, # indent=4, ) print(Constants.CACHE_DIR / "intrinsic" / f"{save_str}_IDX.json")
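
A minimal sketch (not part of the patch; names and shapes below are illustrative only) of the (start, end) bookkeeping the short-sequence path above relies on: per-sequence logits of varying length are concatenated into one array so h5py can store a single dataset, and cumulative-length index pairs (mirroring all_logit_lengths / test_logit_lengths in the patch) are kept so each short sequence's logits can be sliced back out at evaluation time, with the final position of each sequence excluded from the slice.

import numpy as np

# pretend per-short-sequence logits with different lengths (n_chars x n_labels)
per_seq_logits = [np.random.rand(5, 1), np.random.rand(3, 1), np.random.rand(7, 1)]

# cumulative offsets, e.g. [0, 5, 8, 15]
offsets = np.append(0, np.cumsum([len(l) for l in per_seq_logits]))

# (start, end) pairs; end is offsets[i + 1] - 1, so the last position of each
# sequence is left out when slicing, matching the [:-1] label handling above
logit_lengths = [(offsets[i], offsets[i + 1] - 1) for i in range(len(per_seq_logits))]

# single flat array, which is what goes into the h5 dataset
flat_logits = np.concatenate(per_seq_logits)

# later, recover the logits of sequence i with a plain slice
for i, (start, end) in enumerate(logit_lengths):
    assert np.array_equal(flat_logits[start:end], per_seq_logits[i][:-1])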