Skip to content

Commit

Permalink
Feat/compare current weights with another (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Feb 6, 2024
1 parent 85643d4 commit 6b0207a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.2.0"
version = "0.3.0"
authors = ["Open Spaced Repetition"]
categories = ["Algorithms", "Science"]
edition = "2021"
Expand Down
74 changes: 66 additions & 8 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub(crate) const FACTOR: f64 = 19f64 / 81f64;
pub(crate) const S_MIN: f32 = 0.01;
/// This is a slice for efficiency, but should always be 17 in length.
pub type Weights = [f32];
use itertools::izip;

pub static DEFAULT_WEIGHTS: [f32; 17] = [
0.5614, 1.2546, 3.5878, 7.9731, 5.1043, 1.1303, 0.823, 0.0465, 1.629, 0.135, 1.0045, 2.132,
Expand Down Expand Up @@ -236,6 +237,56 @@ impl<B: Backend> FSRS<B> {
pub fn current_retrievability(&self, state: MemoryState, days_elapsed: u32) -> f32 {
(days_elapsed as f32 / (state.stability * 9.0) + 1.0).powi(-1)
}

/// Returns the universal metrics for the existing and provided parameters. If the first value
/// is smaller than the second value, the existing parameters are better than the provided ones.
pub fn universal_metrics<F>(
&self,
items: Vec<FSRSItem>,
parameters: &Weights,
mut progress: F,
) -> Result<(f32, f32)>
where
F: FnMut(ItemProgress) -> bool,
{
if items.is_empty() {
return Err(FSRSError::NotEnoughData);
}
let batcher = FSRSBatcher::new(self.device());
let mut all_predictions_self = vec![];
let mut all_predictions_other = vec![];
let mut all_true_val = vec![];
let mut progress_info = ItemProgress {
current: 0,
total: items.len(),
};
let model_self = self.model();
let fsrs_other = FSRS::<B>::new_with_backend(Some(parameters), self.device())?;
let model_other = fsrs_other.model();
for chunk in items.chunks(512) {
let batch = batcher.batch(chunk.to_vec());

let (_state, retention) = infer::<B>(model_self, batch.clone());
let pred = retention.clone().to_data().convert::<f32>().value;
all_predictions_self.extend(pred);

let (_state, retention) = infer::<B>(model_other, batch.clone());
let pred = retention.clone().to_data().convert::<f32>().value;
all_predictions_other.extend(pred);

let true_val = batch.labels.clone().to_data().convert::<f32>().value;
all_true_val.extend(true_val);
progress_info.current += chunk.len();
if !progress(progress_info) {
return Err(FSRSError::Interrupted);
}
}
let self_by_other =
measure_a_by_b(&all_predictions_self, &all_predictions_other, &all_true_val);
let other_by_self =
measure_a_by_b(&all_predictions_other, &all_predictions_self, &all_true_val);
Ok((self_by_other, other_by_self))
}
}

#[derive(Debug, Copy, Clone)]
Expand Down Expand Up @@ -274,17 +325,17 @@ fn calibration_rmse(pred: &[f32], true_val: &[f32]) -> f32 {
if pred.len() != true_val.len() {
panic!("Vectors pred and true_val must have the same length");
}
measure_a_by_b(pred, pred, true_val)
}

fn measure_a_by_b(pred_a: &[f32], pred_b: &[f32], true_val: &[f32]) -> f32 {
let mut groups = HashMap::new();

for (p, t) in pred.iter().zip(true_val) {
let bin = get_bin(*p, 20);
groups.entry(bin).or_insert_with(Vec::new).push((p, t));
}

izip!(pred_a, pred_b, true_val).for_each(|(a, b, t)| {
let bin = get_bin(*b, 20);
groups.entry(bin).or_insert_with(Vec::new).push((a, t));
});
let mut total_sum = 0.0;
let mut total_count = 0.0;

for group in groups.values() {
let count = group.len() as f32;
let pred_mean = group.iter().map(|(p, _)| *p).sum::<f32>() / count;
Expand Down Expand Up @@ -414,10 +465,17 @@ mod tests {
.assert_approx_eq(&Data::from([0.204_001, 0.025_387]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.201_908, 0.013_894]), 5);

let (self_by_other, other_by_self) = fsrs
.universal_metrics(items, &DEFAULT_WEIGHTS, |_| true)
.unwrap();

Data::from([self_by_other, other_by_self])
.assert_approx_eq(&Data::from([0.015_987_674, 0.019_702_684]), 5);
Ok(())
}

Expand Down

0 comments on commit 6b0207a

Please sign in to comment.