
Commit

Merge branch 'efficient-transfer' of https://github.com/bminixhofer/wtpsplit into efficient-transfer
markus583 committed May 16, 2024
2 parents 9feb8d5 + 2c99884 commit 9994910
Showing 6 changed files with 323 additions and 82 deletions.
40 changes: 37 additions & 3 deletions wtpsplit/evaluation/__init__.py
@@ -72,9 +72,9 @@ def evaluate_sentences(

assert len(labels) == len(predictions)

return f1_score(labels, predictions), {
"recall": recall_score(labels, predictions),
"precision": precision_score(labels, predictions),
return f1_score(labels, predictions, zero_division=0), {
"recall": recall_score(labels, predictions, zero_division=0),
"precision": precision_score(labels, predictions, zero_division=0),
# pairwise: ignore end-of-text label
# only correct if we correctly predict the single newline in between the sentence pair
# --> no false positives, no false negatives allowed!
@@ -84,6 +84,40 @@ def evaluate_sentences(
"length": len(labels),
}

def evaluate_sentences_llm(
labels, predictions, return_indices: bool = False, exclude_every_k: int = 0
):

assert len(labels) == len(predictions)

if exclude_every_k > 0:
true_end_indices = np.where(labels == 1)[0]
# every k-th from those where labels are 1
indices_to_remove = true_end_indices[exclude_every_k-1::exclude_every_k]

# mask for indices to keep
mask = np.ones_like(labels, dtype=bool)
mask[indices_to_remove] = False
mask[-1] = False # last is always excluded

# remove indices
labels = labels[mask]
predictions = predictions[mask]

assert len(labels) == len(predictions)

return {
"f1": f1_score(labels, predictions, zero_division=0),
"recall": recall_score(labels, predictions, zero_division=0),
"precision": precision_score(labels, predictions, zero_division=0),
# pairwise: ignore end-of-text label
# only correct if we correctly predict the single newline in between the sentence pair
# --> no false positives, no false negatives allowed!
"correct_pairwise": int(np.all(labels[:-1] == predictions[:-1])),
"true_indices": np.where(labels)[0].tolist() if return_indices else None,
"predicted_indices": np.where(predictions)[0].tolist() if return_indices else None,
"length": len(labels),
}

def train_mixture(lang_code, original_train_x, train_y, n_subsample=None, features=None, skip_punct: bool = False):
original_train_x = torch.from_numpy(original_train_x).float()
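For reference, here is a minimal standalone sketch of what the new exclude_every_k option in evaluate_sentences_llm does (not part of the commit; the arrays and the value of k are invented for illustration):

import numpy as np

# Toy boundary labels/predictions; 1 marks a sentence end. Values are illustrative only.
labels = np.array([0, 1, 0, 0, 1, 0, 1, 0, 1, 1])
predictions = np.array([0, 1, 0, 1, 1, 0, 0, 0, 1, 1])
exclude_every_k = 2

true_end_indices = np.where(labels == 1)[0]                                 # positions of true boundaries
indices_to_remove = true_end_indices[exclude_every_k - 1::exclude_every_k]  # every k-th true boundary

mask = np.ones_like(labels, dtype=bool)
mask[indices_to_remove] = False  # drop every k-th true boundary position
mask[-1] = False                 # the final (end-of-text) position is always excluded

labels, predictions = labels[mask], predictions[mask]
print(labels, predictions)  # both arrays lose the same positions

Because the same mask is applied to labels and predictions, the excluded positions contribute neither hits nor misses to the reported F1, recall, and precision.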
34 changes: 34 additions & 0 deletions wtpsplit/evaluation/download_spacy.py
@@ -0,0 +1,34 @@
import subprocess

SPACY_LANG_TO_DP_MODEL = {
"ca": "ca_core_news_sm",
"zh": "zh_core_web_sm",
"hr": "hr_core_news_sm",
"da": "da_core_news_sm",
"nl": "nl_core_news_sm",
"en": "en_core_web_sm",
"fi": "fi_core_news_sm",
"fr": "fr_core_news_sm",
"de": "de_core_news_sm",
"el": "el_core_news_sm",
"it": "it_core_news_sm",
"ja": "ja_core_news_sm",
"ko": "ko_core_news_sm",
"lt": "lt_core_news_sm",
"mk": "mk_core_news_sm",
"nb": "nb_core_news_sm",
"pl": "pl_core_news_sm",
"pt": "pt_core_news_sm",
"ro": "ro_core_news_sm",
"ru": "ru_core_news_sm",
"es": "es_core_news_sm",
"sv": "sv_core_news_sm",
"uk": "uk_core_news_sm",
}

def download_models():
for lang, model in SPACY_LANG_TO_DP_MODEL.items():
subprocess.run(["python3", "-m", "spacy", "download", model])

if __name__ == "__main__":
download_models()
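A usage sketch for the new helper (not part of the commit; assumes wtpsplit and spacy are importable and that the downloads succeed):

import spacy
from wtpsplit.evaluation.download_spacy import SPACY_LANG_TO_DP_MODEL, download_models

download_models()                               # shells out to `python3 -m spacy download <model>` for every entry
nlp = spacy.load(SPACY_LANG_TO_DP_MODEL["en"])  # e.g. en_core_web_sm
print(nlp.pipe_names)                           # confirms the English pipeline loaded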
84 changes: 60 additions & 24 deletions wtpsplit/evaluation/intrinsic.py
@@ -6,10 +6,8 @@
import time
import logging
import sys
import re

import h5py
import skops.io as sio
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
@@ -19,8 +17,9 @@

import wtpsplit.models # noqa: F401
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
from wtpsplit.evaluation.intrinsic_baselines import split_language_data
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants, corrupt
from wtpsplit.utils import Constants

logger = logging.getLogger()
logger.setLevel(logging.WARNING)
@@ -53,17 +52,16 @@ class Args:
include_langs: List[str] = None
custom_language_list: str = None
threshold: float = 0.01
max_n_train_sentences: int = 1000
max_n_test_sentences: int = 1000
save_suffix: str = ""
# XXX: these are not used in the current implementation! done within data.pth already.
max_n_train_sentences: int = 10000
max_n_test_sentences: int = -1
keep_logits: bool = False
skip_adaptation: bool = False
skip_punct: bool = True
skip_corrupted: bool = False
clf_from_scratch: bool = False
clf_from_scratch: bool = False # for FT + LoRA
return_indices: bool = True
skip_punct: bool = True
exclude_every_k: int = 10
save_suffix: str = ""


def process_logits(text, model, lang_code, args):
@@ -168,14 +166,25 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
for dataset_name, dataset in tqdm(eval_data[lang_code]["sentence"].items(), desc=lang_code):
if args.skip_corrupted and "corrupted" in dataset_name:
continue
elif "nllb" in dataset_name:
continue
if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr":
if "corrupted-asr" in dataset_name and (
"lyrics" not in dataset_name
and "short" not in dataset_name
and "code" not in dataset_name
and "ted" not in dataset_name
and "legal" not in dataset_name
):
print("SKIP: ", lang_code, dataset_name)
continue
if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
print("SKIP: ", lang_code, dataset_name)
continue
if "social-media" in dataset_name:
continue
if "nllb" in dataset_name:
continue
if "-" in lang_code and "canine" in args.model_path and not "no-adapters" in args.model_path:
# code-switched data: eval 2x
lang_code = lang_code.split("_")[1].lower()
try:
if args.adapter_path:
if args.clf_from_scratch:
@@ -263,7 +272,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
dset_group.create_dataset("test_logit_lengths", data=test_logit_lengths)
else:
test_labels = get_labels(lang_code, test_sentences, after_space=False)

if args.skip_punct:
# remove punct logits
test_logits = test_logits[:, 0]
# back to [N, 1]
test_logits = np.expand_dims(test_logits, axis=1)
dset_group.create_dataset("test_logits", data=test_logits)
dset_group.create_dataset("test_labels", data=test_labels)

@@ -291,6 +304,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
else:
train_labels = get_labels(lang_code, train_sentences, after_space=False)

if args.skip_punct:
# remove punct logits
train_logits = train_logits[:, 0]
# back to [N, 1]
train_logits = np.expand_dims(train_logits, axis=1)
dset_group.create_dataset("train_logits", data=train_logits)
dset_group.create_dataset("train_labels", data=train_labels)

@@ -323,6 +341,8 @@ def main(args):
save_str = f"{save_model_path.replace('/','_')}_b{args.block_size}_s{args.stride}"

eval_data = torch.load(args.eval_data_path)
if "canine" in args.model_path and not "no-adapters" in args.model_path:
eval_data = split_language_data(eval_data)
if args.valid_text_path is not None:
valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
else:
@@ -343,7 +363,7 @@ def main(args):
model.model.classifier = torch.nn.Sequential(clf, torch.nn.Linear(clf.out_features, 1))

save_str += f"{args.save_suffix}"
if args.max_n_test_sentences < sys.maxsize or args.max_n_test_sentences != -1:
if args.max_n_test_sentences < sys.maxsize and args.max_n_test_sentences != -1:
save_str += f"_n{args.max_n_test_sentences}"
if args.max_n_test_sentences == -1:
args.max_n_test_sentences = sys.maxsize
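The or -> and fix above matters for the -1 sentinel ("use all test sentences"): since -1 < sys.maxsize is always true, the old condition appended an "_n-1" suffix to save_str even when no limit was requested. A standalone sketch of the changed condition:

import sys

max_n_test_sentences = -1  # sentinel for "no limit"
old = max_n_test_sentences < sys.maxsize or max_n_test_sentences != -1   # True  -> suffix "_n-1" was appended
new = max_n_test_sentences < sys.maxsize and max_n_test_sentences != -1  # False -> suffix is skipped
print(old, new)  # True False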
@@ -517,19 +537,35 @@ def main(args):
if args.return_indices:
if isinstance(sentences[0], list):
indices[lang_code][dataset_name] = {
"u": u_indices,
"t": t_indices,
"punct": punct_indices,
"true_indices": true_indices,
"length": length,
"u": {"predicted_indices": u_indices, "true_indices": true_indices, "length": length},
"t": {"predicted_indices": t_indices, "true_indices": t_indices, "length": length}
if t_indices
else None,
"punct": {"predicted_indices": punct_indices, "true_indices": t_indices, "length": length}
if punct_indices
else None,
}
else:
indices[lang_code][dataset_name] = {
"u": u_indices["pred_indices"],
"t": t_indices["pred_indices"] if t_indices is not None else None,
"punct": punct_indices["pred_indices"] if punct_indices is not None else None,
"true_indices": u_indices["true_indices"],
"length": u_indices["length"],
"u": {
"predicted_indices": [u_indices["pred_indices"]],
"true_indices": [u_indices["true_indices"]],
"length": [u_indices["length"]],
},
"t": {
"predicted_indices": [t_indices["pred_indices"]],
"true_indices": [t_indices["true_indices"]],
"length": [t_indices["length"]],
}
if t_indices is not None
else None,
"punct": {
"predicted_indices": [punct_indices["pred_indices"]],
"true_indices": [punct_indices["true_indices"]],
"length": [punct_indices["length"]],
}
if punct_indices is not None
else None,
}

if score_u is not None:
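The restructuring above changes each indices[lang_code][dataset_name] entry from flat lists to one sub-dict per scoring variant (u, t, punct), each carrying its own predicted indices, true indices, and length; in the single-dataset branch the values are additionally wrapped in one-element lists. Roughly (placeholder numbers, not real output):

indices[lang_code][dataset_name] = {
    "u":     {"predicted_indices": [[3, 17, 42]], "true_indices": [[3, 18, 42]], "length": [57]},
    "t":     {"predicted_indices": [[3, 18, 42]], "true_indices": [[3, 18, 42]], "length": [57]},  # None if t_indices is None
    "punct": {"predicted_indices": [[3, 18, 42]], "true_indices": [[3, 18, 42]], "length": [57]},  # None if punct_indices is None
}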