From 67b7074915bb45641aef53104ff6cce002d162c8 Mon Sep 17 00:00:00 2001
From: Jannis Vamvas
Date: Wed, 31 Jan 2024 20:17:05 +0100
Subject: [PATCH] Revise comet metric runner

- bugfix: disable dropout by calling eval()
- torch.no_grad
- use cache for embeddings, too
---
 src/mbr/generation/configuration_utils.py |  3 +-
 src/mbr/metrics/base.py                   | 10 ++--
 src/mbr/metrics/comet.py                  | 59 ++++++++++++++---------
 tests/test_metrics.py                     | 19 ++++++++
 4 files changed, 64 insertions(+), 27 deletions(-)

diff --git a/src/mbr/generation/configuration_utils.py b/src/mbr/generation/configuration_utils.py
index 07fba8d..6ecac1f 100644
--- a/src/mbr/generation/configuration_utils.py
+++ b/src/mbr/generation/configuration_utils.py
@@ -117,4 +117,5 @@ def validate(self, is_init=False):
         Note that some parameters are best validated at generate runtime, as they may depend on other inputs
         and/or the model, such as parameters related to the generation length.
         """
-        pass
+        if self.metric_cache_size <= 0:
+            raise ValueError(f"`metric_cache_size` ({self.metric_cache_size}) must be greater than 0.")
diff --git a/src/mbr/metrics/base.py b/src/mbr/metrics/base.py
index 11d64fa..117c58b 100644
--- a/src/mbr/metrics/base.py
+++ b/src/mbr/metrics/base.py
@@ -40,7 +40,7 @@ def __init__(self, mbr_config: MBRConfig, tokenizer: PreTrainedTokenizerBase):
         # Ensure that mbr_config.metric_kwargs is hashable (because _compute_metric() uses lru_cache)
         if mbr_config.metric_kwargs:
             try:
-                hash(self.mbr_config.metric_kwargs)
+                hash(tuple(self.mbr_config.metric_kwargs))
             except TypeError as e:
                 raise TypeError(f"mbr_config.metric_kwargs must be hashable.") from e
         self.tokenizer = tokenizer
@@ -56,6 +56,8 @@ def _load_metric(self) -> MetricType:
             metric = evaluate.load(metric, self.mbr_config.metric_config_name)
         else:
             raise ValueError(f"Invalid metric type: {type(metric)}")
+        if metric.name == "comet":
+            metric.scorer.eval()
         return metric
 
     def __call__(self,
@@ -111,11 +113,11 @@ def _compute_str_metric(self,
                             inputs: List[str] = None,
                             ) -> torch.FloatTensor:
         batch_size = len(samples[0])
-        metric_scores = torch.zeros((batch_size, self.mbr_config.num_samples, self.mbr_config.num_references))
+        metric_scores = torch.zeros((batch_size, len(samples), len(references)))
         for i in range(batch_size):
-            for j in range(self.mbr_config.num_samples):
+            for j in range(len(samples)):
                 sample = samples[j][i]
-                for k in range(self.mbr_config.num_references):
+                for k in range(len(references)):
                     reference = references[k][i]
                     if inputs is not None:
                         score = self.compute_metric(
diff --git a/src/mbr/metrics/comet.py b/src/mbr/metrics/comet.py
index d5ff38c..beee9e8 100644
--- a/src/mbr/metrics/comet.py
+++ b/src/mbr/metrics/comet.py
@@ -40,13 +40,16 @@ def __init__(self,
                                       "comet.models.RegressionMetric")
         if device is not None:
             self.comet.scorer = self.comet.scorer.to(device)
+        self.comet.scorer.eval()
         self.batch_size_embed = batch_size_embed
         self.batch_size_estimate = batch_size_estimate
         self.progress_bar = progress_bar
         # We use a key-value cache, which is needed if the metric is called multiple times with similar inputs
         # (e.g. for MBR with iterative pruning).
-        self.cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size)
+        self.embedding_cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size)
+        self.score_cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size)
 
+    @torch.no_grad()
     def _compute_str_metric(self,
                             samples: List[List[str]],
                             references: List[List[str]],
@@ -55,35 +58,47 @@ def _compute_str_metric(self,
         if inputs is None:
             raise NotImplementedError("CometMetricRunner requires source sequences (`inputs`) to be provided")
         batch_size = len(samples[0])
-        metric_scores = torch.zeros((batch_size, self.mbr_config.num_samples, self.mbr_config.num_references))
+        metric_scores = torch.zeros((batch_size, len(samples), len(references)))
         for i in tqdm(list(range(batch_size)), desc="comet", disable=not self.progress_bar):
             # Embed all sequences
             all_samples = [sample[i] for sample in samples]
             all_references = [reference[i] for reference in references]
-            all_sequences = list(set(all_samples + all_references + inputs))
-            all_encodings = self.comet.scorer.encoder.prepare_sample(all_sequences).to(self.comet.scorer.device)
+            all_sequences = set(all_samples + all_references + inputs)
+
             all_embeddings: Dict[str, torch.FloatTensor] = {}
-            batches = itertools.zip_longest(range(0, len(all_sequences), self.batch_size_embed),
-                                            range(self.batch_size_embed, len(all_sequences), self.batch_size_embed))
-            for start_idx, end_idx in batches:
-                embeddings = self.comet.scorer.get_sentence_embedding(
-                    input_ids=all_encodings["input_ids"][start_idx:end_idx],
-                    attention_mask=all_encodings["attention_mask"][start_idx:end_idx],
-                )
-                for j in range(start_idx, end_idx if end_idx is not None else len(all_sequences)):
-                    all_embeddings[all_sequences[j]] = embeddings[j - start_idx]
+            # Populate embeddings from cache
+            for sequence in list(all_sequences):
+                if sequence in self.embedding_cache:
+                    all_embeddings[sequence] = self.embedding_cache[sequence]
+                    all_sequences.remove(sequence)
+
+            # Compute embeddings for remaining sequences
+            if all_sequences:
+                all_sequences = list(all_sequences)
+                encodings = self.comet.scorer.encoder.prepare_sample(all_sequences).to(self.comet.scorer.device)
+                batches = itertools.zip_longest(range(0, len(all_sequences), self.batch_size_embed),
+                                                range(self.batch_size_embed, len(all_sequences), self.batch_size_embed))
+                for start_idx, end_idx in batches:
+                    embeddings = self.comet.scorer.get_sentence_embedding(
+                        input_ids=encodings["input_ids"][start_idx:end_idx],
+                        attention_mask=encodings["attention_mask"][start_idx:end_idx],
+                    )
+                    for j in range(start_idx, end_idx if end_idx is not None else len(all_sequences)):
+                        embedding = embeddings[j - start_idx]
+                        all_embeddings[all_sequences[j]] = embedding
+                        self.embedding_cache[all_sequences[j]] = embedding
 
             # Collect all input triples in a list
             input_triples: Set[Tuple[str, str, str]] = set()
-            for j in range(self.mbr_config.num_samples):
-                for k in range(self.mbr_config.num_references):
+            for j in range(len(samples)):
+                for k in range(len(references)):
                     input_triples.add((inputs[i], samples[j][i], references[k][i]))
 
-            input_triple_scores = {}
-            # Check if any of the triples are in the cache
+            input_triple_scores: Dict[Tuple[str, str, str], torch.FloatTensor] = {}
+            # Populate scores from cache
             for triple in list(input_triples):
-                if triple in self.cache:
-                    input_triple_scores[triple] = self.cache[triple]
+                if triple in self.score_cache:
+                    input_triple_scores[triple] = self.score_cache[triple]
                     input_triples.remove(triple)
 
             # Compute scores for remaining input triples
@@ -102,10 +117,10 @@ def _compute_str_metric(self,
                     triple = batch[j - start_idx]
                     score = batch_scores.score[j - start_idx]
                     input_triple_scores[triple] = score
-                    self.cache[triple] = score
+                    self.score_cache[triple] = score
 
-            for j in range(self.mbr_config.num_samples):
-                for k in range(self.mbr_config.num_references):
+            for j in range(len(samples)):
+                for k in range(len(references)):
                     metric_scores[i, j, k] = input_triple_scores[(inputs[i], samples[j][i], references[k][i])]
 
         return metric_scores
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 4b469d3..89c9ecf 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -92,6 +92,7 @@ def test_compute_metric__chrf(self):
     @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
     def test_compute_metric__comet(self):
         self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
+        self.mbr_config.metric.scorer.eval()
         self.mbr_config.metric_output_field = "mean_score"
         self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
         self.assertEqual(self.metric_runner.metric.name, "comet")
@@ -133,15 +134,33 @@ def test_compute_metric__bleurt(self):
     def test_comet_metric_runner(self):
         from mbr.metrics.comet import CometMetricRunner
         self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
+        self.mbr_config.metric.scorer.eval()
         self.mbr_config.metric_output_field = "mean_score"
         base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
         self.assertEqual(base_metric_runner.metric.name, "comet")
+        self.assertFalse(base_metric_runner.metric.scorer.training)
         comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
+        self.assertFalse(comet_metric_runner.metric.scorer.training)
         # Output should be the same as the base MetricRunner
         base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
         metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
         torch.testing.assert_close(base_metric_scores, metric_scores)
 
+    @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
+    def test_comet_metric_runner__cache(self):
+        """Output should be identical irrespective of cache size"""
+        from mbr.metrics.comet import CometMetricRunner
+        self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
+        self.mbr_config.metric_output_field = "mean_score"
+        base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
+        base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+        self.assertEqual(base_metric_runner.metric.name, "comet")
+        for cache_size in [1, 4, 8]:
+            self.mbr_config.metric_cache_size = cache_size
+            comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
+            metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+            torch.testing.assert_close(base_metric_scores, metric_scores)
+
     def test_fastchrf_metric_runner__aggregate(self):
         from mbr.metrics.fastchrf import FastChrfMetricRunner
         metric_runner = FastChrfMetricRunner(self.mbr_config, self.tokenizer, compute_pairwise_average=False)
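
Note on the pattern behind this patch: the three changes (calling eval() on the COMET scorer so dropout is disabled, wrapping metric computation in torch.no_grad(), and caching sentence embeddings in addition to pairwise scores) form a general recipe for deterministic, repeated scoring. The sketch below illustrates that recipe with a toy encoder rather than the actual COMET model; ToyEncoder and CachedEmbedder are hypothetical names, not part of the mbr package, and FIFOCache is assumed here to come from cachetools.

from typing import Dict, List

import torch
from cachetools import FIFOCache  # assumption: cachetools provides the FIFOCache used above


class ToyEncoder(torch.nn.Module):
    """Stand-in for comet.scorer: a model whose outputs vary under dropout."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.embedding = torch.nn.EmbeddingBag(1000, dim, mode="mean")
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # (1, seq_len) -> (1, dim) -> (dim,)
        return self.dropout(self.embedding(token_ids.unsqueeze(0)))[0]


class CachedEmbedder:
    """Embeds strings once and serves repeated requests from a FIFO cache."""

    def __init__(self, encoder: torch.nn.Module, cache_size: int = 1024):
        self.encoder = encoder
        self.encoder.eval()  # analogue of the bugfix: disable dropout for deterministic outputs
        self.embedding_cache = FIFOCache(maxsize=cache_size)

    @torch.no_grad()  # no gradients are needed at scoring time
    def embed(self, sequences: List[str]) -> Dict[str, torch.Tensor]:
        embeddings: Dict[str, torch.Tensor] = {}
        remaining: List[str] = []
        # Populate embeddings from the cache
        for sequence in sequences:
            if sequence in self.embedding_cache:
                embeddings[sequence] = self.embedding_cache[sequence]
            else:
                remaining.append(sequence)
        # Compute embeddings only for sequences the cache does not hold
        for sequence in remaining:
            token_ids = torch.tensor([hash(token) % 1000 for token in sequence.split()])
            embedding = self.encoder(token_ids)
            embeddings[sequence] = embedding
            self.embedding_cache[sequence] = embedding
        return embeddings


if __name__ == "__main__":
    embedder = CachedEmbedder(ToyEncoder())
    first = embedder.embed(["a simple test", "another sentence"])
    second = embedder.embed(["a simple test"])  # served from the cache
    torch.testing.assert_close(first["a simple test"], second["a simple test"])

Calling eval() once in the constructor mirrors the bugfix in comet.py above: without it, dropout makes repeated calls return slightly different scores, which is what the new assertFalse(...scorer.training) assertions in tests/test_metrics.py guard against.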