Skip to content

Commit

Permalink
fix: ngram_bench (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
shenxiangzhuang authored May 27, 2024
1 parent 8785907 commit b80988a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 47 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed
- `ngram` bench use counter lib's function rather than the truly used function.

### Changed
- Use AHash in ngram module
Expand Down
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ name = "bleuscore"
crate-type = ["cdylib", "rlib"]

[dependencies]
cached = "0.50.0"
cached = "0.51.3"
regex = "1.10.4"
lazy_static = "1.4.0"
counter = "0.5.7"
rayon = "1.10.0"
ahash = "0.8.11"

Expand Down
54 changes: 9 additions & 45 deletions src/ngram.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use ahash::AHashMap;
use counter::Counter;

/// Here the tokens' type is `&[String]` rather than `&Vec<String>`
/// to fix `clippy::not_unsafe_ptr_arg_deref` error.
Expand All @@ -17,24 +16,9 @@ pub fn get_token_ngram_counter(tokens: &[String], max_order: usize) -> AHashMap<
count_map
}

/// TODO: change to use Counter to count ngram
#[allow(dead_code)]
fn get_ngram_counter(line: &str, max_order: usize) -> Counter<&str> {
let mut counts: Counter<&str> = Counter::new();
for order in 1..=max_order {
for start_index in 0..(line.len().saturating_sub(order - 1)) {
// println!("line: {}, start_index: {}, order: {}", line, start_index, order);
let ngram = &line[start_index..(start_index + order)];
// println!("ngram: {}", ngram);
counts[&ngram] += 1;
}
}
counts
}

#[cfg(test)]
mod test {
use crate::ngram::{get_ngram_counter, get_token_ngram_counter};
use crate::ngram::get_token_ngram_counter;

#[test]
fn test_get_token_ngram_short() {
Expand Down Expand Up @@ -69,47 +53,27 @@ mod test {

assert_eq!(counter.len(), 9);
}

#[test]
fn test_get_ngram_short() {
let counter = get_ngram_counter("ab", 4);
assert_eq!(counter[&"a"], 1);
assert_eq!(counter[&"b"], 1);
assert_eq!(counter[&"ab"], 1);
}

#[test]
fn test_get_ngram_long() {
let counter = get_ngram_counter("aabc", 4);
assert_eq!(counter[&"a"], 2);
assert_eq!(counter[&"b"], 1);
assert_eq!(counter[&"c"], 1);
assert_eq!(counter[&"d"], 0);

assert_eq!(counter[&"aa"], 1);
assert_eq!(counter[&"ab"], 1);
assert_eq!(counter[&"bc"], 1);
assert_eq!(counter[&"ac"], 0);

assert_eq!(counter[&"aab"], 1);
assert_eq!(counter[&"aabc"], 1);
}
}

#[cfg(test)]
mod benchmark {
use crate::ngram::get_ngram_counter;
use crate::ngram::get_token_ngram_counter;
use test::Bencher;

#[bench]
fn bench_ngram(b: &mut Bencher) {
let line = "aabc";
let tokens: Vec<String> = vec![
"a".to_string(),
"a".to_string(),
"b".to_string(),
"c".to_string(),
];
let max_order = 4;

let iter_num: usize = 100;
b.iter(|| {
std::hint::black_box(for _ in 1..=iter_num {
get_ngram_counter(line, max_order);
get_token_ngram_counter(&tokens, max_order);
});
});
}
Expand Down

0 comments on commit b80988a

Please sign in to comment.