Skip to content

Commit

Permalink
Enable COMET unittests when running test workflow on GitHub (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvamvas authored Feb 6, 2024
1 parent b04530a commit 6cedbdc
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 8 deletions.
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
sacrebleu==2.4.0
unbabel-comet==2.1.1
unbabel-comet==2.2.1
git+https://github.com/google-research/bleurt.git
sentencepiece==0.1.99 # M2M100 model
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
sacrebleu==2.4.0

unbabel-comet==2.2.1
6 changes: 0 additions & 6 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def test_is_source_based__chrf(self):
chrf = evaluate.load("chrf")
self.assertFalse(metric_is_source_based(chrf))

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_is_source_based__comet(self):
comet = evaluate.load("comet", "eamt22-cometinho-da")
self.assertTrue(metric_is_source_based(comet))
Expand All @@ -63,7 +62,6 @@ def test_load_metric(self):
self.assertIsInstance(metric, evaluate.Metric)
self.assertEqual(metric.name, "chr_f")

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_metric_config_name(self):
self.mbr_config.metric = "comet"
self.mbr_config.metric_config_name = "eamt22-cometinho-da"
Expand All @@ -89,7 +87,6 @@ def test_compute_metric__chrf(self):
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_compute_metric__comet(self):
self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
self.mbr_config.metric.scorer.eval()
Expand Down Expand Up @@ -130,7 +127,6 @@ def test_compute_metric__bleurt(self):
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_comet_metric_runner(self):
from mbr.metrics.comet import CometMetricRunner
self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
Expand All @@ -146,7 +142,6 @@ def test_comet_metric_runner(self):
metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
torch.testing.assert_close(base_metric_scores, metric_scores)

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_comet_metric_runner__cache(self):
"""Output should be identical irrespective of cache size"""
from mbr.metrics.comet import CometMetricRunner
Expand All @@ -161,7 +156,6 @@ def test_comet_metric_runner__cache(self):
metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
torch.testing.assert_close(base_metric_scores, metric_scores)

@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
def test_comet_metric_runner__aggregate(self):
from mbr.metrics.comet import AggregateCometMetricRunner
self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
Expand Down

0 comments on commit 6cedbdc

Please sign in to comment.