Return MetricOutput dict (#2)
jvamvas authored Dec 27, 2023
1 parent 50db2dc commit fea9ecb
Showing 6 changed files with 79 additions and 46 deletions.
src/mbr/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
from mbr.generation.configuration_utils import MBRGenerationConfig
from mbr.generation.utils import MBROutput, MBRGenerationMixin
from mbr.metrics.base import MetricRunner
from mbr.metrics.base import MetricOutput, MetricRunner
from mbr.modeling import MBR


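With the re-export above, downstream code can pull the new dataclass straight from the package root. A minimal import sketch (mirroring the updated test imports further down):

    from mbr import MBR, MBRGenerationConfig, MBROutput, MetricOutput, MetricRunner
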
src/mbr/generation/utils.py (18 changes: 9 additions & 9 deletions)
@@ -15,7 +15,7 @@
from transformers.utils import logging, ModelOutput

from mbr.generation.configuration_utils import MBRGenerationConfig
from mbr.metrics.base import MetricRunner
from mbr.metrics.base import MetricRunner, MetricOutput

if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -42,16 +42,16 @@ class MBROutput(ModelOutput):
The indices (in `all_samples`) of the selected sequences for each batch item.
references (`tuple(ModelOutput)`), *optional*, returned when `output_all_samples=True` is passed or when
`config.output_all_samples=True`):
metric_scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`), *optional*, returned when
`output_metric_scores=True` is passed or when `config.output_metric_scores=True`):
The metric score for each sample.
metric_scores (`MetricOutput`), *optional*, returned when `output_metric_scores=True` is passed or when
`config.output_metric_scores=True`):
The output of the metric.
"""

sequences: torch.LongTensor = None
all_samples: Optional[Tuple[ModelOutput]] = None
selected_samples_indices: Optional[torch.LongTensor] = None
references: Optional[Tuple[ModelOutput]] = None
metric_scores: Optional[torch.FloatTensor] = None
metric_scores: Optional[MetricOutput] = None


class MBRGenerationMixin(GenerationMixin):
@@ -483,11 +483,11 @@ def generate(
else:
reference_ids = references

metric_scores = metric_runner(input_ids, sample_ids, reference_ids)
metric_output = metric_runner(input_ids, sample_ids, reference_ids)
if not mbr_config.lower_is_better:
top_metric_scores, top_metric_indices = metric_scores.max(dim=-1)
top_metric_scores, top_metric_indices = metric_output.scores.max(dim=-1)
else:
top_metric_scores, top_metric_indices = metric_scores.min(dim=-1)
top_metric_scores, top_metric_indices = metric_output.scores.min(dim=-1)

# Copy top samples into a tensor of shape (batch_size, max_length)
max_length = max(sample.shape[1] for sample in sample_ids)
@@ -496,7 +496,7 @@
all_samples=(tuple(samples) if mbr_config.output_all_samples else None),
selected_samples_indices=(top_metric_indices if mbr_config.output_all_samples else None),
references=(tuple(references) if mbr_config.output_all_samples else None),
metric_scores=(metric_scores if mbr_config.output_metric_scores else None),
metric_scores=(metric_output if mbr_config.output_metric_scores else None),
)
for batch_idx, sample_idx in enumerate(top_metric_indices):
output.sequences[batch_idx][:sample_ids[sample_idx].shape[1]] = sample_ids[sample_idx][batch_idx]
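
Taken together, the change in this file means MBROutput.metric_scores now carries a MetricOutput object rather than a bare tensor of scores. A hedged end-to-end sketch of consuming it follows; the checkpoint name, the MBR(...) wrapping call, the num_samples field, and the exact generate() keyword arguments are assumptions based on the surrounding tests, not part of this diff:

    from transformers import AutoTokenizer, GPT2LMHeadModel
    from mbr import MBR, MBRGenerationConfig, MBROutput

    # Assumption: the MBR(...) wrapper and checkpoint follow the pattern in tests/test_generate.py.
    model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2")
    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

    mbr_config = MBRGenerationConfig(
        num_samples=5,              # assumed field name; matches the (1, 5) shapes asserted in the tests
        num_references=5,
        output_all_samples=True,
        output_metric_scores=True,  # attaches a MetricOutput to the returned MBROutput
    )

    encoding = tokenizer("Hello,", return_tensors="pt")
    output = model.generate(
        **encoding,
        mbr_config=mbr_config,
        tokenizer=tokenizer,        # assumed: the metric runner needs the tokenizer to detokenize samples
        do_sample=True,
    )

    # Depending on the configuration, generate() may return a plain tensor of sequences;
    # when an MBROutput is returned, metric_scores is now a MetricOutput:
    if isinstance(output, MBROutput) and output.metric_scores is not None:
        print(output.metric_scores.scores.shape)                # (batch_size, num_samples)
        print(output.metric_scores.scores_per_reference.shape)  # (batch_size, num_samples, num_references)
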
src/mbr/metrics/base.py (31 changes: 24 additions & 7 deletions)
@@ -1,17 +1,32 @@
import functools
from typing import Tuple, Union, List
from dataclasses import dataclass
from typing import Tuple, Union, List, Optional

import evaluate
import torch
from datasets import Metric
from evaluate import EvaluationModule
from transformers import PreTrainedTokenizerBase
from transformers.utils import ModelOutput

from mbr import MBRGenerationConfig

MetricType = Union[Metric, EvaluationModule]


@dataclass
class MetricOutput(ModelOutput):
"""
Args:
scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`):
The metric scores for each sample (aggregated over all references).
scores_per_reference (`torch.FloatTensor` of shape `(batch_size, num_samples, num_references)`):
The pairwise metric scores for each sample and reference. `None` if the metric is computed corpus-level.
"""
scores: torch.FloatTensor
scores_per_reference: Optional[torch.FloatTensor] = None


class MetricRunner:
"""
Applies the metric to samples and references (and optionally inputs) and calculates a metric score for each sample.
@@ -46,7 +61,7 @@ def __call__(self,
input_ids: torch.LongTensor,
sample_ids: Tuple[torch.LongTensor],
reference_ids: Tuple[torch.LongTensor],
) -> torch.FloatTensor:
) -> MetricOutput:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -59,8 +74,7 @@
the reference sequences.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_samples)`:
The metric scores for each sample (aggregated over all references).
`MetricOutput` containing the metric scores.
"""

# Detokenize
@@ -83,8 +97,12 @@
raise ValueError("Number of references must match `mbr_config.num_references`")

# Compute metric
metric_scores = self._compute_str_metric(str_samples, str_references, str_inputs)
return metric_scores
scores_per_reference = self._compute_str_metric(str_samples, str_references, str_inputs)

return MetricOutput(
scores=scores_per_reference.mean(dim=-1),
scores_per_reference=scores_per_reference,
)

def _compute_str_metric(self,
samples: List[List[str]],
@@ -112,7 +130,6 @@ def _compute_str_metric(self,
**self.mbr_config.metric_kwargs,
)
metric_scores[i, j, k] = score
metric_scores = metric_scores.mean(dim=-1) # average over references
return metric_scores

@functools.lru_cache(maxsize=(1024 ** 2))
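
Since MetricRunner.__call__ now returns a MetricOutput and performs the averaging over references itself, a runner can be exercised directly, much like in tests/test_metrics.py below. A minimal sketch; the tokenizer checkpoint, the config constructor arguments, and the encode helper are illustrative assumptions, and the default metric is assumed to behave like sentence-level chrF:

    import torch
    from transformers import AutoTokenizer
    from mbr import MBRGenerationConfig, MetricOutput, MetricRunner

    tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")  # illustrative checkpoint

    # num_samples/num_references must match the tuples passed below
    # (the runner raises a ValueError on a reference-count mismatch, see the check above).
    mbr_config = MBRGenerationConfig(num_samples=3, num_references=2)  # assumed constructor kwargs

    def encode(texts):
        # Hypothetical helper: batch-encode a list of strings into a LongTensor of token ids.
        return tokenizer(texts, return_tensors="pt", padding=True).input_ids

    runner = MetricRunner(mbr_config, tokenizer)

    input_ids = encode(["Hello", "World"])                                      # (batch_size=2, seq_len)
    sample_ids = tuple(encode(s) for s in (["Hello!", "World!"],
                                           ["Hello!", "World!"],
                                           ["Goodbye", "Moon"]))                # num_samples=3 tensors
    reference_ids = tuple(encode(r) for r in (["Hello", "World"],
                                              ["Hello there", "Whole world"]))  # num_references=2 tensors

    out = runner(input_ids, sample_ids, reference_ids)
    assert isinstance(out, MetricOutput)
    print(out.scores.shape)                # (2, 3): batch_size x num_samples
    print(out.scores_per_reference.shape)  # (2, 3, 2): batch_size x num_samples x num_references
    torch.testing.assert_close(out.scores, out.scores_per_reference.mean(dim=-1))
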
src/mbr/metrics/comet.py (1 change: 0 additions & 1 deletion)
@@ -95,5 +95,4 @@ def _compute_str_metric(self,
for k in range(self.mbr_config.num_references):
metric_scores[i, j, k] = input_triple_scores[(inputs[i], samples[j][i], references[k][i])]

metric_scores = metric_scores.mean(dim=-1) # average over references
return metric_scores
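
The removed line mirrors the change in the base class: aggregation over references now happens once in MetricRunner.__call__, so _compute_str_metric implementations are expected to return the raw per-reference tensor of shape (batch_size, num_samples, num_references). A sketch of what a custom subclass might look like under the new contract; the full method signature is truncated in the diff above, so the parameter list here is an assumption:

    from typing import List

    import torch

    from mbr import MetricRunner

    class ExactMatchMetricRunner(MetricRunner):  # hypothetical subclass, for illustration only
        def _compute_str_metric(self,
                                samples: List[List[str]],      # indexed as samples[sample_idx][batch_idx]
                                references: List[List[str]],   # indexed as references[reference_idx][batch_idx]
                                inputs: List[str],
                                ) -> torch.FloatTensor:
            batch_size = len(inputs)
            metric_scores = torch.zeros(batch_size, len(samples), len(references))
            for i in range(batch_size):
                for j in range(len(samples)):
                    for k in range(len(references)):
                        # Toy score: 1.0 if the sample equals the reference, else 0.0.
                        metric_scores[i, j, k] = float(samples[j][i] == references[k][i])
            # Return per-reference scores; __call__ averages over the last dimension
            # to fill MetricOutput.scores and keeps this tensor as scores_per_reference.
            return metric_scores
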
tests/test_generate.py (18 changes: 11 additions & 7 deletions)
@@ -6,7 +6,7 @@
from transformers import AutoTokenizer, GPT2LMHeadModel, M2M100ForConditionalGeneration, GenerationConfig
from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput

from mbr import MBR, MBRGenerationConfig, MBROutput, MetricRunner
from mbr import MBR, MBRGenerationConfig, MBROutput, MetricRunner, MetricOutput


class DecoderOnlyTestCase(TestCase):
@@ -90,9 +90,11 @@ def test_model_output_extended(self):
self.assertEqual(5, len(output.references))
self.assertIsInstance(output.references[0], SampleDecoderOnlyOutput)
self.assertIsNotNone(output.metric_scores)
self.assertTrue(torch.is_floating_point(output.metric_scores))
self.assertEqual(1, output.metric_scores.shape[0])
self.assertEqual(5, output.metric_scores.shape[1])
self.assertIsInstance(output.metric_scores, MetricOutput)
self.assertTrue(torch.is_floating_point(output.metric_scores.scores))
self.assertTrue(torch.is_floating_point(output.metric_scores.scores_per_reference))
self.assertEqual([1, 5], list(output.metric_scores.scores.shape))
self.assertEqual([1, 5, 5], list(output.metric_scores.scores_per_reference.shape))

# Test the model output for a selected sample
sample = output.all_samples[output.selected_samples_indices[0]]
@@ -267,9 +269,11 @@ def test_model_output_extended(self):
self.assertEqual(5, len(output.references))
self.assertIsInstance(output.references[0], SampleEncoderDecoderOutput)
self.assertIsNotNone(output.metric_scores)
self.assertTrue(torch.is_floating_point(output.metric_scores))
self.assertEqual(2, output.metric_scores.shape[0])
self.assertEqual(5, output.metric_scores.shape[1])
self.assertIsInstance(output.metric_scores, MetricOutput)
self.assertTrue(torch.is_floating_point(output.metric_scores.scores))
self.assertTrue(torch.is_floating_point(output.metric_scores.scores_per_reference))
self.assertEqual([2, 5], list(output.metric_scores.scores.shape))
self.assertEqual([2, 5, 5], list(output.metric_scores.scores_per_reference.shape))

# Test the model output for a selected sample (batch index 0)
sample = output.all_samples[output.selected_samples_indices[0]]
tests/test_metrics.py (55 changes: 34 additions & 21 deletions)
@@ -75,47 +75,59 @@ def test_metric_config_name(self):
self.assertEqual(metric.scorer.encoder.__class__.__name__, "MiniLMEncoder")

def test_compute_metric__chrf(self):
metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
self.assertTrue(torch.is_floating_point(metric_scores))
metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
self.assertTrue(torch.is_floating_point(metric_output.scores))
self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references
# Duplicate samples should have the same scores
self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
# The metric scores should rank as expected, given the test strings in self.samples and self.references
self.assertEqual(metric_scores.shape, (2, 3)) # batch_size x num_samples
self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_compute_metric__comet(self):
self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
self.mbr_config.metric_output_field = "mean_score"
self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
self.assertEqual(self.metric_runner.metric.name, "comet")
metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
self.assertTrue(torch.is_floating_point(metric_scores))
metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
self.assertTrue(torch.is_floating_point(metric_output.scores))
self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references
# Duplicate samples should have the same scores
self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
# The metric scores should rank as expected, given the test strings in self.samples and self.references
self.assertEqual(metric_scores.shape, (2, 3)) # batch_size x num_samples
self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_compute_metric__bleurt(self):
self.mbr_config.metric = evaluate.load("bleurt")
self.mbr_config.metric_output_field = "scores"
self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
self.assertEqual(self.metric_runner.metric.name, "bleurt")
metric_scores = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
self.assertTrue(torch.is_floating_point(metric_scores))
metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
self.assertTrue(torch.is_floating_point(metric_output.scores))
self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references
# Duplicate samples should have the same scores
self.assertEqual(metric_scores[0, 0], metric_scores[0, 1])
torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
# The metric scores should rank as expected, given the test strings in self.samples and self.references
self.assertEqual(metric_scores.shape, (2, 3)) # batch_size x num_samples
self.assertGreater(metric_scores[0, 0], metric_scores[0, 2])
self.assertLess(metric_scores[1, 0], metric_scores[1, 1])
self.assertLess(metric_scores[1, 0], metric_scores[1, 2])
self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_comet_metric_runner(self):
@@ -125,6 +137,7 @@ def test_comet_metric_runner(self):
base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
self.assertEqual(base_metric_runner.metric.name, "comet")
comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
# Output should be the same as the base MetricRunner
base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
torch.testing.assert_close(base_metric_scores, metric_scores)
