Skip to content

Commit

Permalink
Add/bleu (#6)
Browse files Browse the repository at this point in the history
* add: bleu calculation partly
  • Loading branch information
shenxiangzhuang authored Apr 19, 2024
1 parent 60b1ff0 commit 215f0b1
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 3 deletions.
5 changes: 5 additions & 0 deletions benchmark/py_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,8 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
bleu = geo_mean * bp

return (bleu, precisions, bp, ratio, translation_length, reference_length)


if __name__ == "__main__":
    # Quick manual smoke check: character-level BLEU of the candidate
    # "Yellow" against the single reference "Hello", with smoothing.
    result = compute_bleu(
        reference_corpus=[["Hello"]],
        translation_corpus=["Yellow"],
        max_order=4,
        smooth=True,
    )
    print(result)
59 changes: 59 additions & 0 deletions benchmark/test_benchmark_bleu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import random
import time

from typing import Tuple, List
from hypothesis import given
from hypothesis import strategies as st

import bleuscore
import py_bleu


def shuffle_string(s: str) -> str:
    """Return a new string with the characters of `s` in random order."""
    chars = list(s)          # strings are immutable; shuffle a char list instead
    random.shuffle(chars)
    return "".join(chars)


def shrink_string(s: str) -> str:
    """Return a random non-empty prefix of `s` (or `s` itself when empty).

    Fixed: `random.randint` is inclusive on both ends, so the original
    `randint(1, len(s) + 1)` could produce a stop index one past the end;
    slicing clamps that silently, which double-weighted the full string.
    """
    if not s:
        # Guard the empty case explicitly; randint(1, 0) would raise.
        return s
    return s[: random.randint(1, len(s))]


def build_translation_pair(text: str, n: int = 10) -> Tuple[List[str], List[List[str]]]:
    """Produce `n` prediction/reference rows from one seed text.

    Every reference row is just `[text]`; every prediction is a shuffled,
    randomly truncated variant of `text`.
    """
    predictions: List[str] = []
    references: List[List[str]] = []
    for _ in range(n):
        references.append([text])
        predictions.append(shrink_string(shuffle_string(text)))
    return predictions, references


@given(st.text(alphabet=st.characters(min_codepoint=32, max_codepoint=126),
               min_size=10, max_size=20))
def test_bleu(input_text):
    """Property test: the Rust BLEU must agree with the reference Python BLEU.

    Also prints informal timings of 10 runs of each implementation; the
    timing comparison is informational only and never asserted on.
    """
    predictions, references = build_translation_pair(text=input_text, n=10)
    max_order = 4
    smooth = True

    py_result = rust_result = 0.0
    t0 = time.perf_counter()  # perf_counter: monotonic, higher resolution than time()
    for _ in range(10):
        py_result = py_bleu.compute_bleu(reference_corpus=references,
                                         translation_corpus=predictions,
                                         max_order=max_order,
                                         smooth=smooth)[0]
    t1 = time.perf_counter()
    for _ in range(10):
        rust_result = bleuscore.compute_bleu(reference_corpus=references,
                                             translation_corpus=predictions,
                                             max_order=max_order,
                                             smooth=smooth).get("bleu")
    t2 = time.perf_counter()
    print(t1 - t0, t2 - t1, (t1 - t0) > (t2 - t1))
    # Fixed: compare with abs() — the original one-sided `<` comparison also
    # passed whenever py_result was far BELOW rust_result.
    assert abs(py_result - rust_result) < 1e-10
1 change: 1 addition & 0 deletions python/bleuscore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
# Names exported via `from bleuscore import *` (the Rust-backed tokenizers
# plus the BLEU entry point).
__all__ = [
"tokenizer_regex",
"tokenizer_13a",
"compute_bleu",
]
95 changes: 95 additions & 0 deletions src/bleu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use counter::Counter;
use crate::ngram::get_ngram_counter;

#[derive(Debug, Default)]
pub struct BleuScore {
pub bleu: f64,
pub precisions: Vec<f64>,
pub bp: f64,
pub ratio: f64,
pub translation_length: usize,
pub reference_length: usize,
}


/// Corpus-level BLEU, mirroring the reference Python implementation.
///
/// * `reference_corpus` — one list of reference strings per candidate.
/// * `translation_corpus` — the candidate translations, row-aligned with
///   `reference_corpus`.
/// * `max_order` — largest n-gram order scored (typically 4).
/// * `smooth` — apply add-one smoothing to every precision.
///
/// # Panics
/// Panics if any row of `reference_corpus` is empty (the shortest-reference
/// `min().unwrap()` below has nothing to unwrap).
pub fn compute_bleu(
    reference_corpus: Vec<Vec<String>>,
    translation_corpus: Vec<String>,
    max_order: usize,
    smooth: bool,
) -> BleuScore {
    let mut matches_by_order: Vec<usize> = vec![0; max_order];
    let mut possible_matches_by_order: Vec<usize> = vec![0; max_order];
    let mut references_length: usize = 0;
    let mut translation_length: usize = 0;

    for (references, translation) in
        reference_corpus.iter().zip(translation_corpus.iter()) {
        // The shortest reference of each row defines the effective
        // reference length used by the brevity penalty.
        references_length += references.iter().map(|x| x.len()).min().unwrap();
        translation_length += translation.len();

        let translation_ngram_counts = get_ngram_counter(translation, max_order);
        let mut merged_ref_ngram_counts = Counter::new();
        for reference in references {
            // `|=` keeps the per-ngram maximum across references.
            merged_ref_ngram_counts |= get_ngram_counter(reference, max_order);
        }
        // Clipped counts: `&` keeps the per-ngram minimum of candidate
        // counts vs. merged reference counts.
        let overlap = translation_ngram_counts & merged_ref_ngram_counts;

        for ngram in overlap.keys() {
            // NOTE(review): `ngram.len()` is a BYTE length, so this index is
            // only in range for single-byte (ASCII) text — confirm against
            // get_ngram_counter's byte slicing before feeding non-ASCII input.
            matches_by_order[ngram.len() - 1] += overlap[ngram]
        }
        for order in 1..=max_order {
            // Fixed: the original `translation.len() - order + 1` underflows
            // `usize` (debug-build panic) whenever translation.len() + 1 <
            // order — reachable with candidates shorter than max_order - 1.
            // The saturating form yields 0 in that case, which the `> 0`
            // guard below was always meant to handle. Same fix pattern as
            // get_ngram_counter in ngram.rs.
            let possible_matches = (translation.len() + 1).saturating_sub(order);
            if possible_matches > 0 {
                possible_matches_by_order[order - 1] += possible_matches
            }
        }
    }

    // Per-order modified precisions, optionally add-one smoothed.
    let mut precisions: Vec<f64> = vec![0.0; max_order];
    for i in 0..max_order {
        if smooth {
            precisions[i] = (matches_by_order[i] as f64 + 1.0)
                / (possible_matches_by_order[i] as f64 + 1.0);
        } else if possible_matches_by_order[i] > 0 {
            precisions[i] =
                (matches_by_order[i] as f64) / (possible_matches_by_order[i] as f64);
        } else {
            precisions[i] = 0.0;
        }
    }

    // Geometric mean of the precisions; defined as 0 when any precision is 0.
    let mut geo_mean = 0.0;
    if precisions.iter().fold(f64::INFINITY, |a, &b| a.min(b)) > 0.0 {
        let p_log_sum: f64 =
            (1.0 / max_order as f64) * precisions.iter().map(|&x| x.ln()).sum::<f64>();
        geo_mean = p_log_sum.exp();
    }

    // Brevity penalty: 1.0 when the candidate is at least as long as the
    // reference, exp(1 - 1/ratio) otherwise.
    let ratio: f64 = translation_length as f64 / references_length as f64;
    let mut bp = 1.0;
    if ratio <= 1.0 {
        bp = (1.0 - 1.0 / ratio).exp();
    }
    let bleu = geo_mean * bp;
    BleuScore {
        bleu,
        precisions,
        bp,
        ratio,
        translation_length,
        reference_length: references_length,
    }
}


#[cfg(test)]
mod test {
    use crate::bleu::compute_bleu;

    /// Smoke test against the value produced by the reference Python
    /// implementation for the same inputs.
    #[test]
    fn test_bleu() {
        let reference_corpus: Vec<Vec<String>> = vec![vec!["Hello".to_string()]];
        let translation_corpus: Vec<String> = vec!["Yellow".to_string()];
        let max_order: usize = 4;
        let smooth: bool = true;
        let res = compute_bleu(reference_corpus, translation_corpus, max_order, smooth);
        // Python reference output:
        // (0.6147881529512643, [0.7142857142857143, 0.6666666666666666, 0.6, 0.5], 1.0, 1.2, 6, 5)
        // Fixed: compare with abs() — the original one-sided `< 1e-10` also
        // accepted any score arbitrarily far BELOW the expected value.
        assert!((res.bleu - 0.6147881529512643).abs() < 1e-10);
    }
}
25 changes: 25 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod tokenizer;
mod ngram;
mod bleu;

use pyo3::prelude::*;
use crate::tokenizer::Tokenizer;
use pyo3::types::IntoPyDict;


#[pyfunction]
Expand All @@ -19,10 +21,33 @@ fn tokenizer_13a(line: &str) -> PyResult<Vec<String>> {
Ok(res)
}

/// Python binding for BLEU: delegates to [`bleu::compute_bleu`] and returns
/// the components as a Python dict with keys "bleu", "precisions", "bp",
/// "ratio", "translation_length", and "reference_length".
#[pyfunction]
pub fn compute_bleu(
reference_corpus: Vec<Vec<String>>,
translation_corpus: Vec<String>,
max_order: usize,
smooth: bool, ) -> PyResult<PyObject> {
// All numeric work happens in the pure-Rust implementation.
let bleu = bleu::compute_bleu(reference_corpus, translation_corpus, max_order, smooth);
Python::with_gil(|py| {
// Convert each result field to a Python object and collect the pairs
// into a GIL-bound dict.
let bleu_dict = [
("bleu", bleu.bleu.to_object(py)),
("precisions", bleu.precisions.to_object(py)),
("bp", bleu.bp.to_object(py)),
("ratio", bleu.ratio.to_object(py)),
("translation_length", bleu.translation_length.to_object(py)),
("reference_length", bleu.reference_length.to_object(py)),
].into_py_dict_bound(py);
Ok(bleu_dict.into())
})

}


/// A Python module implemented in Rust.
/// Registers the two tokenizers and the BLEU entry point so Python sees
/// `bleuscore.tokenizer_regex`, `bleuscore.tokenizer_13a`, and
/// `bleuscore.compute_bleu`.
#[pymodule]
fn bleuscore(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(tokenizer_regex, m)?)?;
m.add_function(wrap_pyfunction!(tokenizer_13a, m)?)?;
m.add_function(wrap_pyfunction!(compute_bleu, m)?)?;
Ok(())
}
7 changes: 5 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
/// Scratch driver: prints a whitespace tokenization, then the (order, start)
/// index pairs explored by the n-gram loops.
fn main() {
    let line = "Hello , World !";
    println!("{:?}", line.split_whitespace().map(|x| x.to_string()).collect::<Vec<String>>());
    // Fixed: the original `0..=(2 - i)` underflows `usize` at i = 3, which
    // panics in debug builds. `0..3usize.saturating_sub(i)` yields exactly
    // the same indices for i <= 2 ({0,1} then {0}) and simply skips the
    // inner loop for larger i.
    for i in 1..=4usize {
        for j in 0..3usize.saturating_sub(i) {
            println!("{i}, {j}");
        }
    }
}
12 changes: 11 additions & 1 deletion src/ngram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ use counter::Counter;
pub fn get_ngram_counter(line: &str, max_order: usize) -> Counter<&str> {
let mut counts: Counter<&str> = Counter::new();
for order in 1..=max_order {
for start_index in 0..=(line.len() - order) {
for start_index in 0..(line.len().saturating_sub(order - 1)) {
// println!("line: {}, start_index: {}, order: {}", line, start_index, order);
let ngram = &line[start_index..(start_index + order)];
// println!("ngram: {}", ngram);
counts[&ngram] += 1;
}
}
Expand Down Expand Up @@ -33,4 +35,12 @@ mod test {
assert_eq!(counter[&"aab"], 1);
assert_eq!(counter[&"aabc"], 1);
}

// A 2-character string can only yield 1- and 2-grams; the order-3 and
// order-4 passes must produce nothing rather than panic (regression check
// for the saturating_sub bound in get_ngram_counter).
#[test]
fn test_get_ngram_short() {
let counter = get_ngram_counter("ab", 4);
assert_eq!(counter[&"a"], 1);
assert_eq!(counter[&"b"], 1);
assert_eq!(counter[&"ab"], 1);
}
}

0 comments on commit 215f0b1

Please sign in to comment.