diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py
index 4cc42fcb..f93fdd01 100644
--- a/wtpsplit/evaluation/intrinsic.py
+++ b/wtpsplit/evaluation/intrinsic.py
@@ -6,10 +6,8 @@
 import time
 import logging
 import sys
-import re
 
 import h5py
-import skops.io as sio
 import torch
 from datasets import load_dataset
 from tqdm.auto import tqdm
@@ -20,7 +18,7 @@
 import wtpsplit.models  # noqa: F401
 from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
 from wtpsplit.extract import PyTorchWrapper, extract
-from wtpsplit.utils import Constants, corrupt
+from wtpsplit.utils import Constants
 
 logger = logging.getLogger()
 logger.setLevel(logging.WARNING)
@@ -53,17 +51,16 @@ class Args:
     include_langs: List[str] = None
     custom_language_list: str = None
     threshold: float = 0.01
-    max_n_train_sentences: int = 1000
+    max_n_train_sentences: int = 10000
     max_n_test_sentences: int = -1
-    save_suffix: str = ""
-    # XXX: these are not used in the current implementation! done within data.pth already.
     keep_logits: bool = False
     skip_adaptation: bool = False
+    skip_punct: bool = True
     skip_corrupted: bool = False
-    clf_from_scratch: bool = False
+    clf_from_scratch: bool = False  # for FT + LoRA
    return_indices: bool = True
-    skip_punct: bool = True
     exclude_every_k: int = 10
+    save_suffix: str = ""
 
 
 def process_logits(text, model, lang_code, args):
@@ -168,14 +165,21 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
             for dataset_name, dataset in tqdm(eval_data[lang_code]["sentence"].items(), desc=lang_code):
                 if args.skip_corrupted and "corrupted" in dataset_name:
                     continue
-                elif "nllb" in dataset_name:
-                    continue
-                if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr":
+                if "corrupted-asr" in dataset_name and (
+                    "lyrics" not in dataset_name
+                    and "short" not in dataset_name
+                    and "code" not in dataset_name
+                    and "ted" not in dataset_name
+                ):
                     print("SKIP: ", lang_code, dataset_name)
                     continue
                 if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
                     print("SKIP: ", lang_code, dataset_name)
                     continue
+                if "social-media" in dataset_name:
+                    continue
+                if "nllb" in dataset_name:
+                    continue
                 try:
                     if args.adapter_path:
                         if args.clf_from_scratch:
@@ -517,19 +521,35 @@ def main(args):
             if args.return_indices:
                 if isinstance(sentences[0], list):
                     indices[lang_code][dataset_name] = {
-                        "u": u_indices,
-                        "t": t_indices,
-                        "punct": punct_indices,
-                        "true_indices": true_indices,
-                        "length": length,
+                        "u": {"predicted_indices": u_indices, "true_indices": true_indices, "length": length},
+                        "t": {"predicted_indices": t_indices, "true_indices": t_indices, "length": length}
+                        if t_indices
+                        else None,
+                        "punct": {"predicted_indices": punct_indices, "true_indices": t_indices, "length": length}
+                        if punct_indices
+                        else None,
                     }
                 else:
                     indices[lang_code][dataset_name] = {
-                        "u": u_indices["pred_indices"],
-                        "t": t_indices["pred_indices"] if t_indices is not None else None,
-                        "punct": punct_indices["pred_indices"] if punct_indices is not None else None,
-                        "true_indices": u_indices["true_indices"],
-                        "length": u_indices["length"],
+                        "u": {
+                            "predicted_indices": [u_indices["pred_indices"]],
+                            "true_indices": [u_indices["true_indices"]],
+                            "length": [u_indices["length"]],
+                        },
+                        "t": {
+                            "predicted_indices": [t_indices["pred_indices"]],
+                            "true_indices": [t_indices["true_indices"]],
+                            "length": [t_indices["length"]],
+                        }
+                        if t_indices is not None
+                        else None,
+                        "punct": {
+                            "predicted_indices": [punct_indices["pred_indices"]],
+                            "true_indices": [punct_indices["true_indices"]],
+                            "length": [punct_indices["length"]],
+                        }
+                        if punct_indices is not None
+                        else None,
                     }
 
             if score_u is not None:
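
Note on the last hunk (commentary, not part of the diff): each per-dataset
indices entry now stores one record per method ("u", "t", "punct") with the
uniform keys "predicted_indices", "true_indices", and "length"; the
single-text branch wraps its values in singleton lists so both branches
serialize the same way, and "t"/"punct" become None when skipped. A minimal
sketch of reading one such entry, assuming the indices dict was dumped to
JSON by the script (the file name and the "en"/"ersatz" keys below are
illustrative, not taken from the diff):

    import json

    # Hypothetical output file; the real path depends on save_str/save_suffix.
    with open("intrinsic_indices.json") as f:
        indices = json.load(f)

    entry = indices["en"]["ersatz"]  # example lang_code and dataset_name
    for method in ("u", "t", "punct"):
        record = entry[method]
        if record is None:  # "t" and "punct" may be None (e.g. skip_punct)
            continue
        for pred, true, length in zip(
            record["predicted_indices"], record["true_indices"], record["length"]
        ):
            print(method, len(pred), len(true), length)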