-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
201 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import random | ||
import time | ||
|
||
from typing import Tuple, List | ||
from hypothesis import given | ||
from hypothesis import strategies as st | ||
|
||
import bleuscore | ||
import py_bleu | ||
|
||
|
||
def shuffle_string(s: str) -> str: | ||
# Convert the string to a list of characters | ||
char_list = list(s) | ||
|
||
# Shuffle the list of characters | ||
random.shuffle(char_list) | ||
|
||
# Join the list back into a string | ||
shuffled_string = ''.join(char_list) | ||
|
||
return shuffled_string | ||
|
||
|
||
def shrink_string(s: str) -> str: | ||
return s[0:random.randint(1, len(s) + 1)] | ||
|
||
|
||
def build_translation_pair(text: str, n: int = 10) -> Tuple[List[str], List[List[str]]]: | ||
references = [[text] for _ in range(n)] | ||
predictions = [shrink_string(shuffle_string(text)) for _ in range(n)] | ||
return predictions, references | ||
|
||
|
||
@given(st.text(alphabet=st.characters(min_codepoint=32, max_codepoint=126), | ||
min_size=10, max_size=20)) | ||
def test_bleu(input_text): | ||
predictions, references = build_translation_pair(text=input_text, n=10) | ||
max_order = 4 | ||
smooth = True | ||
|
||
py_result = rust_result = {} | ||
t0 = time.time() | ||
for i in range(10): | ||
py_result = py_bleu.compute_bleu(reference_corpus=references, | ||
translation_corpus=predictions, | ||
max_order=max_order, | ||
smooth=smooth)[0] | ||
t1 = time.time() | ||
for i in range(10): | ||
rust_result = bleuscore.compute_bleu(reference_corpus=references, | ||
translation_corpus=predictions, | ||
max_order=max_order, | ||
smooth=smooth) | ||
print(rust_result) | ||
rust_result = rust_result.get("bleu") | ||
t2 = time.time() | ||
print(t1 - t0, t2 - t1, (t1 - t0) > (t2 - t1)) | ||
assert (py_result - rust_result) < 1e-10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,5 @@ | |
__all__ = [ | ||
"tokenizer_regex", | ||
"tokenizer_13a", | ||
"compute_bleu", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
use counter::Counter; | ||
use crate::ngram::get_ngram_counter; | ||
|
||
#[derive(Debug, Default)] | ||
pub struct BleuScore { | ||
pub bleu: f64, | ||
pub precisions: Vec<f64>, | ||
pub bp: f64, | ||
pub ratio: f64, | ||
pub translation_length: usize, | ||
pub reference_length: usize, | ||
} | ||
|
||
|
||
pub fn compute_bleu( | ||
reference_corpus: Vec<Vec<String>>, | ||
translation_corpus: Vec<String>, | ||
max_order: usize, | ||
smooth: bool, | ||
) -> BleuScore { | ||
let mut matches_by_order: Vec<usize> = vec![0; max_order]; | ||
let mut possible_matches_by_order: Vec<usize> = vec![0; max_order]; | ||
let mut references_length: usize = 0; | ||
let mut translation_length: usize = 0; | ||
|
||
for (references, translation) in | ||
reference_corpus.iter().zip(translation_corpus.iter()) { | ||
references_length += references.iter().map(|x| x.len()).min().unwrap(); | ||
translation_length += translation.len(); | ||
let translation_ngram_counts = get_ngram_counter(translation, max_order); | ||
let mut merged_ref_ngram_counts = Counter::new(); | ||
for reference in references { | ||
merged_ref_ngram_counts |= get_ngram_counter(&reference, max_order); | ||
} | ||
let overlap = translation_ngram_counts & merged_ref_ngram_counts; | ||
|
||
for ngram in overlap.keys() { | ||
matches_by_order[ngram.len() - 1] += overlap[ngram] | ||
} | ||
for order in 1..=max_order { | ||
let possible_matches = translation.len() - order + 1; | ||
if possible_matches > 0 { | ||
possible_matches_by_order[order - 1] += possible_matches | ||
} | ||
} | ||
} | ||
|
||
let mut precisions:Vec<f64> = vec![0.0; max_order]; | ||
for i in 0..max_order { | ||
match smooth { | ||
true => { | ||
precisions[i] = (matches_by_order[i] as f64 + 1.0) / (possible_matches_by_order[i] as f64 + 1.0); | ||
}, | ||
false => { | ||
if possible_matches_by_order[i] > 0 { | ||
precisions[i] = (matches_by_order[i] as f64) / (possible_matches_by_order[i] as f64) | ||
} | ||
else { | ||
precisions[i] = 0.0; | ||
} | ||
} | ||
} | ||
} | ||
|
||
let mut geo_mean = 0.0; | ||
|
||
if precisions.iter().fold(f64::INFINITY, |a, &b| a.min(b)) > 0.0 { | ||
let p_log_sum: f64 = (1.0 / max_order as f64) * precisions.iter().map(|&x| x.ln()).sum::<f64>(); | ||
geo_mean = p_log_sum.exp(); | ||
} | ||
|
||
let ratio: f64 = translation_length as f64 / references_length as f64; | ||
let mut bp = 1.0; | ||
if ratio <= 1.0 { | ||
bp = (1.0 - 1.0 / ratio).exp(); | ||
} | ||
let bleu = geo_mean * bp; | ||
BleuScore{bleu, precisions, bp, ratio, translation_length, reference_length: references_length} | ||
} | ||
|
||
|
||
#[cfg(test)] | ||
mod test { | ||
use crate::bleu::{compute_bleu}; | ||
#[test] | ||
fn test_bleu() { | ||
let reference_corpus: Vec<Vec<String>> = vec![vec!["Hello".to_string()]]; | ||
let translation_corpus: Vec<String> = vec!["Yellow".to_string()]; | ||
let max_order: usize = 4; | ||
let smooth: bool = true; | ||
let res = compute_bleu(reference_corpus, translation_corpus, max_order, smooth); | ||
// (0.6147881529512643, [0.7142857142857143, 0.6666666666666666, 0.6, 0.5], 1.0, 1.2, 6, 5) | ||
assert_eq!((res.bleu - 0.6147881529512643) < 1e-10, true); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
fn main() { | ||
let line = "Hello , World !"; | ||
println!("{:?}", line.split_whitespace().map(|x| x.to_string()).collect::<Vec<String>>()); | ||
for i in 1..=4 as usize { | ||
for j in 0..=(2 - i) { | ||
println!("{i}, {j}"); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters