final eval setup?
markus583 committed May 15, 2024
1 parent f7554d3 · commit 7ad9a3e
Showing 1 changed file with 41 additions and 21 deletions.
wtpsplit/evaluation/intrinsic.py: 62 changes (41 additions, 21 deletions)
@@ -6,10 +6,8 @@
 import time
 import logging
 import sys
-import re
 
 import h5py
-import skops.io as sio
 import torch
 from datasets import load_dataset
 from tqdm.auto import tqdm
@@ -20,7 +18,7 @@
 import wtpsplit.models  # noqa: F401
 from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
 from wtpsplit.extract import PyTorchWrapper, extract
-from wtpsplit.utils import Constants, corrupt
+from wtpsplit.utils import Constants
 
 logger = logging.getLogger()
 logger.setLevel(logging.WARNING)
@@ -53,17 +51,16 @@ class Args:
     include_langs: List[str] = None
     custom_language_list: str = None
     threshold: float = 0.01
-    max_n_train_sentences: int = 1000
+    max_n_train_sentences: int = 10000
     max_n_test_sentences: int = -1
-    save_suffix: str = ""
-    # XXX: these are not used in the current implementation! done within data.pth already.
     keep_logits: bool = False
     skip_adaptation: bool = False
-    skip_punct: bool = True
     skip_corrupted: bool = False
-    clf_from_scratch: bool = False
+    clf_from_scratch: bool = False  # for FT + LoRA
     return_indices: bool = True
+    skip_punct: bool = True
     exclude_every_k: int = 10
+    save_suffix: str = ""
 
 
 def process_logits(text, model, lang_code, args):
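Note: a minimal sanity-check sketch of the reordered defaults above, assuming Args remains a dataclass importable from wtpsplit/evaluation/intrinsic.py. Since Args may have required fields not shown in this diff, the sketch inspects defaults rather than constructing an instance.

import dataclasses

from wtpsplit.evaluation.intrinsic import Args

# Collect this commit's defaults without instantiating Args (it may have
# required fields that do not appear in this hunk).
defaults = {
    f.name: f.default
    for f in dataclasses.fields(Args)
    if f.default is not dataclasses.MISSING
}
assert defaults["max_n_train_sentences"] == 10000  # raised from 1000 here
assert defaults["skip_punct"] is True              # moved below return_indices
assert defaults["save_suffix"] == ""               # moved to the end of the class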
@@ -168,14 +165,21 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
         for dataset_name, dataset in tqdm(eval_data[lang_code]["sentence"].items(), desc=lang_code):
             if args.skip_corrupted and "corrupted" in dataset_name:
                 continue
-            elif "nllb" in dataset_name:
-                continue
-            if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr":
+            if "corrupted-asr" in dataset_name and (
+                "lyrics" not in dataset_name
+                and "short" not in dataset_name
+                and "code" not in dataset_name
+                and "ted" not in dataset_name
+            ):
                 print("SKIP: ", lang_code, dataset_name)
                 continue
+            if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
+                print("SKIP: ", lang_code, dataset_name)
+                continue
             if "social-media" in dataset_name:
                 continue
+            if "nllb" in dataset_name:
+                continue
             try:
                 if args.adapter_path:
                     if args.clf_from_scratch:
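Note: the skip rules above are order-independent and side-effect free apart from the prints, so they can be restated as a single predicate. A hedged sketch: the helper below is illustrative, not part of the repository, and mirrors the conditions in this hunk.

def should_skip(dataset_name: str, skip_corrupted: bool = False) -> bool:
    # Mirrors the filtering in load_or_compute_logits after this commit.
    if skip_corrupted and "corrupted" in dataset_name:
        return True
    # Corrupted-ASR data is kept only for lyrics/short/code/ted variants.
    if "corrupted-asr" in dataset_name and not any(
        k in dataset_name for k in ("lyrics", "short", "code", "ted")
    ):
        return True
    # Legal data is kept only for the "laws" and "judgements" subsets.
    if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
        return True
    if "social-media" in dataset_name or "nllb" in dataset_name:
        return True
    return False

assert should_skip("ud-corrupted-asr")           # generic corrupted ASR is skipped
assert not should_skip("ted2020-corrupted-asr")  # TED corrupted ASR is kept
assert not should_skip("legal-laws")             # "laws" subsets are kept
assert should_skip("legal-contracts")            # other legal subsets are skipped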
@@ -517,19 +521,35 @@ def main(args):
             if args.return_indices:
                 if isinstance(sentences[0], list):
                     indices[lang_code][dataset_name] = {
-                        "u": u_indices,
-                        "t": t_indices,
-                        "punct": punct_indices,
-                        "true_indices": true_indices,
-                        "length": length,
+                        "u": {"predicted_indices": u_indices, "true_indices": true_indices, "length": length},
+                        "t": {"predicted_indices": t_indices, "true_indices": t_indices, "length": length}
+                        if t_indices
+                        else None,
+                        "punct": {"predicted_indices": punct_indices, "true_indices": t_indices, "length": length}
+                        if punct_indices
+                        else None,
                     }
                 else:
                     indices[lang_code][dataset_name] = {
-                        "u": u_indices["pred_indices"],
-                        "t": t_indices["pred_indices"] if t_indices is not None else None,
-                        "punct": punct_indices["pred_indices"] if punct_indices is not None else None,
-                        "true_indices": u_indices["true_indices"],
-                        "length": u_indices["length"],
+                        "u": {
+                            "predicted_indices": [u_indices["pred_indices"]],
+                            "true_indices": [u_indices["true_indices"]],
+                            "length": [u_indices["length"]],
+                        },
+                        "t": {
+                            "predicted_indices": [t_indices["pred_indices"]],
+                            "true_indices": [t_indices["true_indices"]],
+                            "length": [t_indices["length"]],
+                        }
+                        if t_indices is not None
+                        else None,
+                        "punct": {
+                            "predicted_indices": [punct_indices["pred_indices"]],
+                            "true_indices": [punct_indices["true_indices"]],
+                            "length": [punct_indices["length"]],
+                        }
+                        if punct_indices is not None
+                        else None,
                     }
 
             if score_u is not None:
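Note: both branches now emit the same nested schema per method key ("u", "t", "punct"), with the single-text branch wrapping each value in a one-element list so downstream code can treat the two cases uniformly. An illustrative entry with invented values:

# Shape of indices[lang_code][dataset_name] after this commit
# (single-text branch). All numbers are invented for illustration.
example_entry = {
    "u": {
        "predicted_indices": [[17, 54, 102]],
        "true_indices": [[17, 55, 102]],
        "length": [120],
    },
    "t": None,      # left as None when t_indices is None
    "punct": None,  # left as None when punct_indices is None
}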
