Revise comet metric runner
- bugfix: disable dropout by calling eval()
- torch.no_grad
- use cache for embeddings, too
jvamvas committed Jan 31, 2024
1 parent 3949960 commit 67b7074
Showing 4 changed files with 64 additions and 27 deletions.
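
The first two commit-message bullets follow the standard PyTorch inference recipe: switch the scorer to eval mode so dropout stops injecting randomness into the metric scores, and run scoring under torch.no_grad() so no autograd graph is built. A minimal sketch of that pattern, using a toy nn.Sequential model as a stand-in for COMET's scorer (not the repository's code):

import torch
from torch import nn

# Toy stand-in for a regression scorer; the commit applies this pattern to COMET's scorer module.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1), nn.Linear(8, 1))
model.eval()  # disable dropout so repeated scoring of the same input is deterministic


@torch.no_grad()  # skip building the autograd graph: less memory, faster inference
def score(features: torch.Tensor) -> torch.Tensor:
    return model(features)


x = torch.randn(4, 8)
assert torch.allclose(score(x), score(x))  # identical outputs once dropout is off

With dropout active, two calls on the same input would generally differ, which is the non-determinism the "bugfix" bullet refers to.
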
src/mbr/generation/configuration_utils.py (3 changes: 2 additions & 1 deletion)
@@ -117,4 +117,5 @@ def validate(self, is_init=False):
Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
model, such as parameters related to the generation length.
"""
- pass
+ if self.metric_cache_size <= 0:
+ raise ValueError(f"`metric_cache_size` ({self.metric_cache_size}) must be greater than 0.")
src/mbr/metrics/base.py (10 changes: 6 additions & 4 deletions)
@@ -40,7 +40,7 @@ def __init__(self, mbr_config: MBRConfig, tokenizer: PreTrainedTokenizerBase):
# Ensure that mbr_config.metric_kwargs is hashable (because _compute_metric() uses lru_cache)
if mbr_config.metric_kwargs:
try:
- hash(self.mbr_config.metric_kwargs)
+ hash(tuple(self.mbr_config.metric_kwargs))
except TypeError as e:
raise TypeError(f"mbr_config.metric_kwargs must be hashable.") from e
self.tokenizer = tokenizer
@@ -56,6 +56,8 @@ def _load_metric(self) -> MetricType:
metric = evaluate.load(metric, self.mbr_config.metric_config_name)
else:
raise ValueError(f"Invalid metric type: {type(metric)}")
+ if metric.name == "comet":
+ metric.scorer.eval()
return metric

def __call__(self,
@@ -111,11 +113,11 @@ def _compute_str_metric(self,
inputs: List[str] = None,
) -> torch.FloatTensor:
batch_size = len(samples[0])
- metric_scores = torch.zeros((batch_size, self.mbr_config.num_samples, self.mbr_config.num_references))
+ metric_scores = torch.zeros((batch_size, len(samples), len(references)))
for i in range(batch_size):
- for j in range(self.mbr_config.num_samples):
+ for j in range(len(samples)):
sample = samples[j][i]
- for k in range(self.mbr_config.num_references):
+ for k in range(len(references)):
reference = references[k][i]
if inputs is not None:
score = self.compute_metric(
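
A side note on the hashability guard in base.py above: functools.lru_cache builds its memoization key from the call arguments, so an unhashable value only fails at the first cached call; probing with hash() in the constructor surfaces the problem early. A small, self-contained illustration (cached_score is a hypothetical stand-in, not part of the repository):

from functools import lru_cache


@lru_cache(maxsize=None)
def cached_score(sample: str, reference: str, extra=()) -> int:
    # Stand-in for an expensive metric call.
    return len(sample) + len(reference) + len(extra)


cached_score("a sample", "a reference")  # fine: str arguments are hashable
try:
    cached_score("a sample", "a reference", extra=["not", "hashable"])
except TypeError as exc:
    print(f"lru_cache rejects unhashable arguments: {exc}")
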
src/mbr/metrics/comet.py (59 changes: 37 additions & 22 deletions)
@@ -40,13 +40,16 @@ def __init__(self,
"comet.models.RegressionMetric")
if device is not None:
self.comet.scorer = self.comet.scorer.to(device)
+ self.comet.scorer.eval()
self.batch_size_embed = batch_size_embed
self.batch_size_estimate = batch_size_estimate
self.progress_bar = progress_bar
# We use a key-value cache, which is needed if the metric is called multiple times with similar inputs
# (e.g. for MBR with iterative pruning).
- self.cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size)
+ self.embedding_cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size)
+ self.score_cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size)

+ @torch.no_grad()
def _compute_str_metric(self,
samples: List[List[str]],
references: List[List[str]],
@@ -55,35 +58,47 @@ def _compute_str_metric(self,
if inputs is None:
raise NotImplementedError("CometMetricRunner requires source sequences (`inputs`) to be provided")
batch_size = len(samples[0])
- metric_scores = torch.zeros((batch_size, self.mbr_config.num_samples, self.mbr_config.num_references))
+ metric_scores = torch.zeros((batch_size, len(samples), len(references)))
for i in tqdm(list(range(batch_size)), desc="comet", disable=not self.progress_bar):
# Embed all sequences
all_samples = [sample[i] for sample in samples]
all_references = [reference[i] for reference in references]
- all_sequences = list(set(all_samples + all_references + inputs))
- all_encodings = self.comet.scorer.encoder.prepare_sample(all_sequences).to(self.comet.scorer.device)
+ all_sequences = set(all_samples + all_references + inputs)

all_embeddings: Dict[str, torch.FloatTensor] = {}
- batches = itertools.zip_longest(range(0, len(all_sequences), self.batch_size_embed),
- range(self.batch_size_embed, len(all_sequences), self.batch_size_embed))
- for start_idx, end_idx in batches:
- embeddings = self.comet.scorer.get_sentence_embedding(
- input_ids=all_encodings["input_ids"][start_idx:end_idx],
- attention_mask=all_encodings["attention_mask"][start_idx:end_idx],
- )
- for j in range(start_idx, end_idx if end_idx is not None else len(all_sequences)):
- all_embeddings[all_sequences[j]] = embeddings[j - start_idx]
+ # Populate embeddings from cache
+ for sequence in list(all_sequences):
+ if sequence in self.embedding_cache:
+ all_embeddings[sequence] = self.embedding_cache[sequence]
+ all_sequences.remove(sequence)

+ # Compute embeddings for remaining sequences
+ if all_sequences:
+ all_sequences = list(all_sequences)
+ encodings = self.comet.scorer.encoder.prepare_sample(all_sequences).to(self.comet.scorer.device)
+ batches = itertools.zip_longest(range(0, len(all_sequences), self.batch_size_embed),
+ range(self.batch_size_embed, len(all_sequences), self.batch_size_embed))
+ for start_idx, end_idx in batches:
+ embeddings = self.comet.scorer.get_sentence_embedding(
+ input_ids=encodings["input_ids"][start_idx:end_idx],
+ attention_mask=encodings["attention_mask"][start_idx:end_idx],
+ )
+ for j in range(start_idx, end_idx if end_idx is not None else len(all_sequences)):
+ embedding = embeddings[j - start_idx]
+ all_embeddings[all_sequences[j]] = embedding
+ self.embedding_cache[all_sequences[j]] = embedding

# Collect all input triples in a list
input_triples: Set[Tuple[str, str, str]] = set()
- for j in range(self.mbr_config.num_samples):
- for k in range(self.mbr_config.num_references):
+ for j in range(len(samples)):
+ for k in range(len(references)):
input_triples.add((inputs[i], samples[j][i], references[k][i]))
- input_triple_scores = {}

- # Check if any of the triples are in the cache
+ input_triple_scores: Dict[Tuple[str, str, str], torch.FloatTensor] = {}
+ # Populate scores from cache
for triple in list(input_triples):
- if triple in self.cache:
- input_triple_scores[triple] = self.cache[triple]
+ if triple in self.score_cache:
+ input_triple_scores[triple] = self.score_cache[triple]
input_triples.remove(triple)

# Compute scores for remaining input triples
@@ -102,10 +117,10 @@ def _compute_str_metric(self,
triple = batch[j - start_idx]
score = batch_scores.score[j - start_idx]
input_triple_scores[triple] = score
- self.cache[triple] = score
+ self.score_cache[triple] = score

- for j in range(self.mbr_config.num_samples):
- for k in range(self.mbr_config.num_references):
+ for j in range(len(samples)):
+ for k in range(len(references)):
metric_scores[i, j, k] = input_triple_scores[(inputs[i], samples[j][i], references[k][i])]

return metric_scores
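
The runner above applies the same look-up-then-fill pattern twice, once for sentence embeddings and once for segment scores, each backed by a bounded FIFO cache so repeated calls with overlapping inputs (e.g. MBR with iterative pruning) skip recomputation. A stripped-down sketch of that pattern, assuming FIFOCache comes from cachetools as the name suggests; expensive_embed and embed_all are placeholders rather than the repository's API:

from cachetools import FIFOCache

cache = FIFOCache(maxsize=1024)  # bounded: the oldest entries are evicted first once full


def expensive_embed(text: str):
    # Placeholder for an expensive forward pass (e.g. a sentence-embedding model).
    return [float(len(text))]


def embed_all(texts: set) -> dict:
    """Return an embedding for every text, reusing cached results where possible."""
    results = {}
    # 1) Serve what we can from the cache (mutates `texts`, like the runner above).
    for text in list(texts):
        if text in cache:
            results[text] = cache[text]
            texts.remove(text)
    # 2) Compute the rest and store it for the next call.
    for text in texts:
        embedding = expensive_embed(text)
        results[text] = embedding
        cache[text] = embedding
    return results


print(embed_all({"hello", "world"}))
print(embed_all({"hello", "again"}))  # "hello" is now served from the cache

FIFO eviction keeps the bookkeeping cheap, and since consecutive MBR calls tend to reuse recent sequences, evicting the oldest entries first is usually sufficient.
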
tests/test_metrics.py (19 changes: 19 additions & 0 deletions)
@@ -92,6 +92,7 @@ def test_compute_metric__chrf(self):
@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_compute_metric__comet(self):
self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
+ self.mbr_config.metric.scorer.eval()
self.mbr_config.metric_output_field = "mean_score"
self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
self.assertEqual(self.metric_runner.metric.name, "comet")
@@ -133,15 +134,33 @@ def test_compute_metric__bleurt(self):
def test_comet_metric_runner(self):
from mbr.metrics.comet import CometMetricRunner
self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
+ self.mbr_config.metric.scorer.eval()
self.mbr_config.metric_output_field = "mean_score"
base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
self.assertEqual(base_metric_runner.metric.name, "comet")
+ self.assertFalse(base_metric_runner.metric.scorer.training)
comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
+ self.assertFalse(comet_metric_runner.metric.scorer.training)
# Output should be the same as the base MetricRunner
base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
torch.testing.assert_close(base_metric_scores, metric_scores)

+ @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
+ def test_comet_metric_runner__cache(self):
+ """Output should be identical irrespective of cache size"""
+ from mbr.metrics.comet import CometMetricRunner
+ self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
+ self.mbr_config.metric_output_field = "mean_score"
+ base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
+ base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+ self.assertEqual(base_metric_runner.metric.name, "comet")
+ for cache_size in [1, 4, 8]:
+ self.mbr_config.metric_cache_size = cache_size
+ comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
+ metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
+ torch.testing.assert_close(base_metric_scores, metric_scores)

def test_fastchrf_metric_runner__aggregate(self):
from mbr.metrics.fastchrf import FastChrfMetricRunner
metric_runner = FastChrfMetricRunner(self.mbr_config, self.tokenizer, compute_pairwise_average=False)
