diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml
index a39270d..186947f 100644
--- a/.github/workflows/unittest.yml
+++ b/.github/workflows/unittest.yml
@@ -25,6 +25,7 @@ jobs:
         python -m pip install --upgrade pip
         pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
         pip install .
+        pip install -r requirements-test.txt
     - name: Lint with flake8
       run: |
         pip install flake8
diff --git a/README.md b/README.md
index de6ee6a..ed586ff 100644
--- a/README.md
+++ b/README.md
@@ -126,24 +126,25 @@ model.generate(..., references_config=references_config)
 ```
 
 ### Choosing a metric
-By default, **mbr** integrates metrics via the [Hugging Face Evaluate](https://github.com/huggingface/evaluate) library.
+By default, **mbr** uses [fastChrF](https://github.com/jvamvas/fastChrF), which is optimized for efficient comparison of many samples to many references.
+
+You can also plug in metrics from the [**Hugging Face Evaluate**](https://github.com/huggingface/evaluate) library.
 A full list of metrics is found [here](https://huggingface.co/metrics). Some typical choices are:
-- [ChrF](https://huggingface.co/spaces/evaluate-metric/chrf) ([Popović, 2015](https://www.aclweb.org/anthology/W15-3049/))
 - [COMET](https://huggingface.co/spaces/evaluate-metric/comet) ([Rei et al., 2020](https://aclanthology.org/2020.emnlp-main.213/))
 - [BLEURT](https://huggingface.co/spaces/evaluate-metric/bleurt) ([Sellam et al., 2020](https://aclanthology.org/2020.acl-main.704))
 
-In the MBR config, you can either specify the metric's name (e.g., `"chrf"`, `"comet"`) or pass an `evaluate.Metric` object directly.
+To use a metric from Hugging Face, either specify the metric's name (e.g., `"comet"`, `"bleurt"`) or pass an `evaluate.Metric` object directly.
 Since different metrics output differently structured dicts, you need to specify the `metric_output_field` that should be used as the metric score.
 
 ```python
 from evaluate import load
 
-metric = load('chrf')
+metric = load('bleu')
 mbr_config = MBRGenerationConfig(
     metric=metric,
-    metric_output_field="score",  # the ChrF metric returns a dict with a "score" field
+    metric_output_field="bleu",  # the BLEU metric returns a dict with a "bleu" field
     ...
 )
 ```
@@ -188,8 +189,9 @@ model.generate(..., metric_runner=metric_runner)
 
 ### Optimizations
 MBR decoding is notoriously slow. **mbr** implements some optimizations:
 - Cached encoder outputs: For encoder-decoder models, the encoder outputs are computed only once and reused during sampling.
-- Cached metric: The metric is computed only once for each unique sample–reference pair (since there will be duplicate samples and references).
-- Optimized COMET metric: Inspired by [Amrhein & Sennrich (2022)](https://aclanthology.org/2022.aacl-main.83/), sequence embeddings are cached and reused for all pairwise comparisons.
+- Optimized ChrF metric: [fastChrF](https://github.com/jvamvas/fastChrF), a streamlined ChrF variant for MBR implemented in Rust, is used by default (see the example below).
+- Optimized COMET metric: Inspired by [Amrhein & Sennrich (2022)](https://aclanthology.org/2022.aacl-main.83/), `CometMetricRunner` caches sequence embeddings and reuses them for all pairwise comparisons.
+- Cached metrics: Most metrics are computed only once for each unique sample–reference pair (since there will be duplicate samples and references).
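+
+For example, to choose between the default aggregate fastChrF scores and exact pairwise ChrF, set the `metric` field accordingly (a minimal sketch; both metric names are among those handled by `mbr.metrics.load_metric_runner`):
+
+```python
+from mbr import MBRConfig
+
+# Default: aggregate fastChrF scores (streamlined, fastest)
+mbr_config = MBRConfig(num_samples=10, metric="fastchrf")
+
+# Alternative: exact sentence-level ChrF, averaged over the references
+mbr_config = MBRConfig(num_samples=10, metric="fastchrf.pairwise_chrf")
+
+model.generate(..., mbr_config=mbr_config)
+```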
 
 ## Example scripts
@@ -199,6 +201,9 @@ The [experiments](experiments) directory contains the code for reproductions of
 - [MBR with neural metrics and epsilon sampling for machine translation](experiments/freitag-et-al-2023-epsilon) ([Freitag et al., 2023](https://arxiv.org/abs/2305.09860))
 - [MBR for summarization](experiments/bertsch-et-al-2023-mbr) ([Bertsch et al., 2023](https://arxiv.org/abs/2310.01387))
 
+### Other experiments
+- Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR
+
 ## Related projects
 - https://github.com/roxot/mbr-nmt: Original implementation ([demo](https://colab.research.google.com/github/probabll/demo-mbr-nmt/blob/main/German-English.ipynb))
 - https://github.com/ZurichNLP/understanding-mbr: MBR with Sockeye
@@ -206,6 +211,10 @@ The [experiments](experiments) directory contains the code for reproductions of
 - https://github.com/rainavyas/mbr_gec: MBR for Grammatical Error Correction
 
 ## Changelog
+
+- v0.3.0 (draft)
+  - Use [fastChrF](https://github.com/jvamvas/fastChrF) as default metric
+
 - v0.2.0
   - **Breaking change:** Rename `MBRGenerationConfig` to `MBRConfig`
   - **Breaking change:** `MetricRunner` now returns a `MetricOutput` dict instead of the raw tensor of scores.
diff --git a/experiments/README.md b/experiments/README.md
index b062d12..76bfffa 100644
--- a/experiments/README.md
+++ b/experiments/README.md
@@ -6,3 +6,6 @@
 - It's MBR All the Way Down: Modern Generation Techniques Through the Lens of Minimum Bayes Risk (Bertsch et al., 2023)
 - Epsilon Sampling Rocks: Investigating Sampling Strategies for Minimum Bayes Risk Decoding for Machine Translation (Freitag et al., 2023)
 - Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation (Müller & Sennrich, ACL-IJCNLP 2021)
+
+**Other experiments**
+- Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR.
\ No newline at end of file
diff --git a/experiments/chrf-vs-fastchrf/README.md b/experiments/chrf-vs-fastchrf/README.md
new file mode 100644
index 0000000..dc15dd1
--- /dev/null
+++ b/experiments/chrf-vs-fastchrf/README.md
@@ -0,0 +1,32 @@
+Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR.
+
+## Setup
+* Task: Machine translation
+* Translation directions: en–de, de–en, en–ru, ru–en
+* Model: [facebook/wmt19-*](https://huggingface.co/facebook/wmt19-en-de) ([Ng et al., 2019](https://aclanthology.org/W19-5333/))
+* MBR metrics: `fastchrf.pairwise_chrf` (a fast implementation of standard ChrF) and `fastchrf.aggregate_chrf` (a streamlined ChrF variant for MBR)
+* Number of samples: 256
+* Sampling approach: epsilon sampling with ε=0.02
+* Samples and references are the same
+* Test set: newstest2019
+* Evaluation metrics: chrF ([sacreBLEU](https://github.com/mjpost/sacrebleu)) and COMET-22 ([Rei et al., 2022](https://aclanthology.org/2022.wmt-1.52/))
+* Baseline: beam search with beam size 4
+
+## Results
+| Language Pair | Method                                  |     ChrF |     COMET | Duration (s) |
+|---------------|-----------------------------------------|---------:|----------:|-------------:|
+| en-de         | MBR with `fastchrf.pairwise_chrf`       |     67.7 |     0.867 |         7798 |
+| en-de         | MBR with `fastchrf.aggregate_chrf`      |     67.7 |     0.867 |         7480 |
+| en-de         | Beam search                             |     67.7 |     0.868 |           62 |
+| de-en         | MBR with `fastchrf.pairwise_chrf`       |     65.4 |     0.851 |         6894 |
+| de-en         | MBR with `fastchrf.aggregate_chrf`      |     65.6 |     0.850 |         6849 |
+| de-en         | Beam search                             |     65.1 |     0.851 |           53 |
+| en-ru         | MBR with `fastchrf.pairwise_chrf`       |     57.5 |     0.862 |         7802 |
+| en-ru         | MBR with `fastchrf.aggregate_chrf`      |     57.5 |     0.862 |         7465 |
+| en-ru         | Beam search                             |     56.9 |     0.863 |           64 |
+| ru-en         | MBR with `fastchrf.pairwise_chrf`       |     64.2 |     0.847 |         7541 |
+| ru-en         | MBR with `fastchrf.aggregate_chrf`      |     64.3 |     0.848 |         6689 |
+| ru-en         | Beam search                             |     63.5 |     0.847 |           61 |
+| **Average**   | **MBR with `fastchrf.pairwise_chrf`**   | **63.7** | **0.857** |     **7509** |
+| **Average**   | **MBR with `fastchrf.aggregate_chrf`**  | **63.7** | **0.857** |     **7121** |
+| **Average**   | **Beam search**                         | **63.3** | **0.857** |       **60** |
\ No newline at end of file
diff --git a/experiments/chrf-vs-fastchrf/run_experiment.py b/experiments/chrf-vs-fastchrf/run_experiment.py
new file mode 100644
index 0000000..04befcd
--- /dev/null
+++ b/experiments/chrf-vs-fastchrf/run_experiment.py
@@ -0,0 +1,142 @@
+import sys
+import time
+from copy import deepcopy
+from pathlib import Path
+
+import evaluate
+import jsonlines
+import sacrebleu
+import torch
+from datasets import load_dataset
+from tqdm import tqdm
+from transformers import FSMTForConditionalGeneration, AutoTokenizer, pipeline, set_seed, GenerationConfig
+
+from mbr import MBR, MBRConfig
+
+language_pair = sys.argv[1]
+assert language_pair in ["de-en", "en-de", "en-ru", "ru-en"]
+
+batch_size = 32
+
+results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}.jsonl", "w")
+
+model_name = f"facebook/wmt19-{language_pair}"
+model = MBR(FSMTForConditionalGeneration).from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+mt_pipeline = pipeline(
+    "translation_" + language_pair.split("-")[0] + "_to_" + language_pair.split("-")[1],
+    model=model,
+    tokenizer=tokenizer,
+    device=(0 if torch.cuda.is_available() else -1),
+)
+evaluation_metric_chrf = evaluate.load("chrf")
+evaluation_metric_comet = evaluate.load("comet", "Unbabel/wmt22-comet-da")
+
+src_path = sacrebleu.get_source_file("wmt19", language_pair)
+ref_path = sacrebleu.get_reference_files("wmt19", language_pair)[0]
+dataset = load_dataset("text", data_files={"test": src_path})
+references = Path(ref_path).read_text().splitlines()
+assert len(dataset["test"]) == len(references)
+
+# MBR
+generation_config = GenerationConfig.from_pretrained(model_name)
+generation_config.do_sample = True
+generation_config.num_beams = 1
+generation_config.early_stopping = False
+generation_config.epsilon_cutoff = 0.02
+
+base_mbr_config = MBRConfig(
+    num_samples=256,
+    num_references=256,
+)
+base_mbr_config.metric_cache_size = batch_size * base_mbr_config.num_samples * base_mbr_config.num_references
+mbr_configs = {}
+
+# MBR with fastchrf.pairwise_chrf
+mbr_config = deepcopy(base_mbr_config)
+mbr_config.metric = "fastchrf.pairwise_chrf"
+mbr_configs["MBR with fastchrf.pairwise_chrf"] = mbr_config
+
+# MBR with fastchrf.aggregate_chrf
+mbr_config = deepcopy(base_mbr_config)
+mbr_config.metric = "fastchrf.aggregate_chrf"
+mbr_configs["MBR with fastchrf.aggregate_chrf"] = mbr_config
+
+for method, mbr_config in mbr_configs.items():
+
+    set_seed(42)
+    time_start = time.time()
+    outputs = mt_pipeline(
+        dataset["test"]["text"],
+        mbr_config=mbr_config,
+        generation_config=generation_config,
+        tokenizer=tokenizer,
+        batch_size=batch_size,
+        progress_bar=True
+    )
+    translations = []
+    for batch in tqdm(outputs):
+        if isinstance(batch, dict):
+            batch = [batch]
+        translations += [translation["translation_text"] for translation in batch]
+    time_end = time.time()
+
+    chrf_score = evaluation_metric_chrf.compute(
+        predictions=translations,
+        references=references,
+    )
+    comet_score = evaluation_metric_comet.compute(
+        predictions=translations,
+        references=references,
+        sources=dataset["test"]["text"],
+        gpus=0,
+    )
+    results_file.write({
+        "language_pair": language_pair,
+        "method": method,
+        "chrf": chrf_score["score"],
+        "comet22": comet_score["mean_score"],
+        "duration": time_end - time_start,
+        "translations": translations,
+    })
+
+# Beam search
+model = FSMTForConditionalGeneration.from_pretrained(model_name).half().to(mt_pipeline.device)
+mt_pipeline.model = model
+generation_config = GenerationConfig.from_pretrained(model_name)
+generation_config.num_beams = 4
+
+set_seed(42)
+time_start = time.time()
+outputs = mt_pipeline(
+    dataset["test"]["text"],
+    generation_config=generation_config,
+    batch_size=batch_size,
+)
+translations = []
+for batch in tqdm(outputs):
+    if isinstance(batch, dict):
+        batch = [batch]
+    translations += [translation["translation_text"] for translation in batch]
+time_end = time.time()
+
+chrf_score = evaluation_metric_chrf.compute(
+    predictions=translations,
+    references=references,
+)
+comet_score = evaluation_metric_comet.compute(
+    predictions=translations,
+    references=references,
+    sources=dataset["test"]["text"],
+    gpus=0,
+)
+results_file.write({
+    "language_pair": language_pair,
+    "method": f"beam search (beam size {generation_config.num_beams})",
+    "chrf": chrf_score["score"],
+    "comet22": comet_score["mean_score"],
+    "duration": time_end - time_start,
+    "translations": translations,
+})
+
+results_file.close()
diff --git a/experiments/requirements.txt b/experiments/requirements.txt
index 99636bd..a5b9f59 100644
--- a/experiments/requirements.txt
+++ b/experiments/requirements.txt
@@ -1,5 +1,6 @@
 jsonlines==4.0.0
 datasets==2.14.6
+sacrebleu==2.3.1
 sacremoses==0.0.53 # For OpusMT
 nltk==3.8.1
 rouge_score==0.1.2
diff --git a/pyproject.toml b/pyproject.toml
index 2900685..74e1f1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,9 +10,9 @@ requires-python = ">=3.9"
 dependencies = [
     "transformers",
     "evaluate",
-    "sacrebleu",
     "cachetools",
     "tqdm",
+    "fastchrf",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 6ee6607..aca6bbc 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,3 +1,4 @@
+sacrebleu==2.4.0
 unbabel-comet==2.1.1
 git+https://github.com/google-research/bleurt.git
 sentencepiece==0.1.99 # M2M100 model
diff --git a/requirements-test.txt b/requirements-test.txt
new file mode 100644
index 0000000..16cffc9
--- /dev/null
+++ b/requirements-test.txt
@@ -0,0 +1,2 @@
+sacrebleu==2.4.0
+
diff --git a/src/mbr/generation/configuration_utils.py b/src/mbr/generation/configuration_utils.py
index 7d3923c..07fba8d 100644
--- a/src/mbr/generation/configuration_utils.py
+++ b/src/mbr/generation/configuration_utils.py
@@ -12,7 +12,7 @@ class MBRConfig:
     Example:
 
     ```python
-    >>> config = MBRConfig(num_samples=10, num_references=10, metric="chrf")
+    >>> config = MBRConfig(num_samples=10, num_references=10, metric="fastchrf")
     >>> model.generate(..., mbr_config=config)
     ```
 
@@ -31,7 +31,7 @@ class MBRConfig:
             Number of samples generated. 1 means no MBR decoding.
         num_references (`int`, *optional*, defaults to `num_samples`):
             Number of pseudo-references used for MBR decoding.
-        metric (`str` or `~evaluate.Metric`, *optional*, defaults to 'chrf'):
+        metric (`str` or `~evaluate.Metric`, *optional*, defaults to 'fastchrf'):
             Metric used for MBR decoding.
         metric_config_name (`str`, *optional*, defaults to None):
             Metric configuration to pass to `evaluate.load` (e.g., the model for a trained metric, such as
@@ -71,7 +71,7 @@ def __init__(self, **kwargs):
         # Parameters that control the generation strategy used
         self.num_samples = kwargs.pop("num_samples", 10)
         self.num_references = kwargs.pop("num_references", self.num_samples)
-        self.metric = kwargs.pop("metric", "chrf")
+        self.metric = kwargs.pop("metric", "fastchrf")
         self.metric_config_name = kwargs.pop("metric_config_name", None)
         self.metric_output_field = kwargs.pop("metric_output_field", "score")
         self.metric_kwargs = kwargs.pop("metric_kwargs", {})
diff --git a/src/mbr/generation/utils.py b/src/mbr/generation/utils.py
index a6397f5..2c0bf5d 100644
--- a/src/mbr/generation/utils.py
+++ b/src/mbr/generation/utils.py
@@ -15,6 +15,7 @@ from transformers.utils import logging, ModelOutput
 
 from mbr.generation.configuration_utils import MBRConfig
+from mbr.metrics import load_metric_runner
 from mbr.metrics.base import MetricRunner, MetricOutput
 
 if TYPE_CHECKING:
@@ -476,7 +477,7 @@ def generate(
 
         # 15. apply metric to samples
         if metric_runner is None:
-            metric_runner = MetricRunner(mbr_config, tokenizer)
+            metric_runner = load_metric_runner(mbr_config, tokenizer)
 
         if isinstance(samples[0], ModelOutput):
             sample_ids = tuple(sample.sequences for sample in samples)
diff --git a/src/mbr/metrics/__init__.py b/src/mbr/metrics/__init__.py
index 9199218..fcbb6a0 100644
--- a/src/mbr/metrics/__init__.py
+++ b/src/mbr/metrics/__init__.py
@@ -1 +1,13 @@
-from mbr.metrics.base import metric_is_source_based
+from mbr import MBRConfig
+from mbr.metrics.base import metric_is_source_based, MetricRunner
+
+
+def load_metric_runner(mbr_config: MBRConfig, tokenizer=None) -> MetricRunner:
+    if mbr_config.metric in {"fastchrf", "aggregate_chrf", "fastchrf.aggregate_chrf"}:
+        from mbr.metrics.fastchrf import FastChrfMetricRunner
+        return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=False)
+    elif mbr_config.metric in {"pairwise_chrf", "fastchrf.pairwise_chrf"}:
+        from mbr.metrics.fastchrf import FastChrfMetricRunner
+        return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=True)
+    else:
+        return MetricRunner(mbr_config, tokenizer)
diff --git a/src/mbr/metrics/fastchrf.py b/src/mbr/metrics/fastchrf.py
new file mode 100644
index 0000000..881376f
--- /dev/null
+++ b/src/mbr/metrics/fastchrf.py
@@ -0,0 +1,111 @@
+from typing import List, Tuple
+
+import torch
+from fastchrf import pairwise_chrf, aggregate_chrf
+from transformers import PreTrainedTokenizerBase
+
+from mbr import MetricRunner, MBRConfig, MetricOutput
+
+
+class FastChrfMetricRunner(MetricRunner):
+    """
+    MetricRunner for fastChrF. See https://github.com/jvamvas/fastChrF for more information.
+
+    Args:
+        mbr_config
+        tokenizer
+        compute_pairwise_average: Defaults to False. If True, use fastchrf.pairwise_chrf() to calculate exact ChrF
+            scores for each sample-reference pair and then average them; this corresponds to a fast implementation
+            of the original ChrF metric. If False, use fastchrf.aggregate_chrf() to directly calculate aggregate
+            fastChrF scores across all references; note that the result will differ from the original ChrF metric.
+    """
+
+    def __init__(self,
+                 mbr_config: MBRConfig,
+                 tokenizer: PreTrainedTokenizerBase,
+                 compute_pairwise_average: bool = False,
+                 ):
+        self.mbr_config = mbr_config
+        self.tokenizer = tokenizer
+        self.metric_is_source_based = False
+        self.char_order = mbr_config.metric_kwargs.get("char_order", 6)
+        self.beta = mbr_config.metric_kwargs.get("beta", 2)
+        self.remove_whitespace = mbr_config.metric_kwargs.get("remove_whitespace", True)
+        self.eps_smoothing = mbr_config.metric_kwargs.get("eps_smoothing", False)
+        self.compute_pairwise_average = compute_pairwise_average
+
+    def __call__(self,
+                 input_ids: torch.LongTensor,
+                 sample_ids: Tuple[torch.LongTensor],
+                 reference_ids: Tuple[torch.LongTensor],
+                 ) -> MetricOutput:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The input sequence ids.
+            sample_ids (`tuple(torch.LongTensor)`):
+                Tuple (one element per sample, `num_samples` in total) of tensors of shape
+                `(batch_size, sequence_length)` containing the sampled sequences.
+            reference_ids:
+                Tuple (one element per reference, `num_references` in total) of tensors of shape
+                `(batch_size, sequence_length)` containing the reference sequences.
+
+        Returns:
+            `MetricOutput` containing the metric scores.
+        """
+
+        # Detokenize
+        str_samples = []  # num_samples x batch_size
+        for sample in sample_ids:
+            str_samples.append(self.tokenizer.batch_decode(sample, skip_special_tokens=True))
+        str_references = []  # num_references x batch_size
+        for reference in reference_ids:
+            str_references.append(self.tokenizer.batch_decode(reference, skip_special_tokens=True))
+
+        if len(str_samples[0]) != len(str_references[0]):
+            raise ValueError("Batch size of samples and references must match")
+        if len(str_samples) != self.mbr_config.num_samples:
+            raise ValueError("Number of samples must match `mbr_config.num_samples`")
+        if len(str_references) != self.mbr_config.num_references:
+            raise ValueError("Number of references must match `mbr_config.num_references`")
+
+        # Transpose to batch_size x num_samples/num_references
+        str_samples = list(zip(*str_samples))
+        str_references = list(zip(*str_references))
+
+        if self.compute_pairwise_average:
+            output = self._compute_pairwise_chrf(str_samples, str_references)
+        else:
+            output = self._compute_aggregate_chrf(str_samples, str_references)
+        return output
+
+    def _compute_pairwise_chrf(self, samples: List[List[str]], references: List[List[str]]) -> MetricOutput:
+        scores_per_reference = pairwise_chrf(
+            samples,
+            references,
+            char_order=self.char_order,
+            beta=self.beta,
+            remove_whitespace=self.remove_whitespace,
+            eps_smoothing=self.eps_smoothing,
+        )
+        scores_per_reference = torch.tensor(scores_per_reference)
+        scores = scores_per_reference.mean(dim=-1)
+        return MetricOutput(
+            scores=scores,
+            scores_per_reference=scores_per_reference,
+        )
+
+    def _compute_aggregate_chrf(self, samples: List[List[str]], references: List[List[str]]) -> MetricOutput:
+        scores = aggregate_chrf(
+            samples,
+            references,
+            char_order=self.char_order,
+            beta=self.beta,
+            remove_whitespace=self.remove_whitespace,
+            eps_smoothing=self.eps_smoothing,
+        )
+        scores = torch.tensor(scores)
+        return MetricOutput(
+            scores=scores,
+            scores_per_reference=None,
+        )
diff --git a/tests/test_generate.py b/tests/test_generate.py
index ad5a31f..45bfb82 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -6,7 +6,8 @@
 from transformers import AutoTokenizer, GPT2LMHeadModel, M2M100ForConditionalGeneration, GenerationConfig
 from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput
 
-from mbr import MBR, MBRConfig, MBROutput, MetricRunner, MetricOutput
+from mbr import MBR, MBRConfig, MBROutput, MetricOutput
+from mbr.metrics import load_metric_runner
 
 
 class DecoderOnlyTestCase(TestCase):
@@ -60,6 +61,7 @@ def test_model_output(self):
 
     def test_model_output_extended(self):
         mbr_config = MBRConfig(
+            metric="pairwise_chrf",
             num_samples=5,
             return_dict_in_generate=True,
             output_scores=True,
@@ -117,7 +119,7 @@ def test_metric_runner(self):
             "Hello, my name is",
         ]
         encoding = self.tokenizer(input_sentences, return_tensors="pt")
-        metric_runner = MetricRunner(mbr_config, self.tokenizer)
+        metric_runner = load_metric_runner(mbr_config, self.tokenizer)
         output = self.model.generate(
             **encoding,
             mbr_config=mbr_config,
@@ -236,6 +238,7 @@ def test_model_output(self):
 
     def test_model_output_extended(self):
         mbr_config = MBRConfig(
+            metric="pairwise_chrf",
             num_samples=5,
             return_dict_in_generate=True,
             output_scores=True,
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index a7e081d..4b469d3 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -141,3 +141,33 @@ def test_comet_metric_runner(self):
         base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
         metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
         torch.testing.assert_close(base_metric_scores, metric_scores)
+
+    def test_fastchrf_metric_runner__aggregate(self):
+        from mbr.metrics.fastchrf import FastChrfMetricRunner
+        metric_runner = FastChrfMetricRunner(self.mbr_config, self.tokenizer, compute_pairwise_average=False)
+        metric_output = metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertTrue(torch.is_floating_point(metric_output.scores))
+        self.assertIsNone(metric_output.scores_per_reference)
+        self.assertEqual(metric_output.scores.shape, (2, 3))  # batch_size x num_samples
+        # Duplicate samples should have the same scores
+        torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
+        # The metric scores should rank as expected, given the test strings in self.samples and self.references
+        self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
+
+    def test_fastchrf_metric_runner__pairwise(self):
+        from mbr.metrics.fastchrf import FastChrfMetricRunner
+        metric_runner = FastChrfMetricRunner(self.mbr_config, self.tokenizer, compute_pairwise_average=True)
+        metric_output = metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertTrue(torch.is_floating_point(metric_output.scores))
+        self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
+        self.assertEqual(metric_output.scores.shape, (2, 3))  # batch_size x num_samples
+        self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2))  # batch_size x num_samples x num_references
+        # Duplicate samples should have the same scores
+        torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
+        torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
+        # The metric scores should rank as expected, given the test strings in self.samples and self.references
+        self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
+        self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])