Fix/calculate average recall in item level & fix laplace smoothing & use f64 in pretrain (#161)
L-M-Sherlock authored Mar 4, 2024
1 parent ce712c9 commit 9a20ec2
Showing 4 changed files with 71 additions and 53 deletions.
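
In short: calculate_average_recall now samples one review per item (item.current()) instead of every review of every item; the Laplace smoothing in search_parameters now weights each delta_t bin by that bin's own review count rather than by the total count across bins; and the pretraining path (AverageRecall, power_forgetting_curve, loss, and the stability search) moves from f32 to f64, with the test expectations updated to match.
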
2 changes: 1 addition & 1 deletion Cargo.lock


2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.4.4"
version = "0.4.5"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
110 changes: 61 additions & 49 deletions src/pre_training.rs
@@ -53,9 +53,9 @@ fn create_pretrain_data(fsrs_items: Vec<FSRSItem>) -> HashMap<FirstRating, Vec<A
for (second_delta_t, ratings) in inner_map {
let avg = ratings.iter().map(|&x| x as f64).sum::<f64>() / ratings.len() as f64;
data.push(AverageRecall {
-delta_t: *second_delta_t as f32,
-recall: avg as f32,
-count: ratings.len() as f32,
+delta_t: *second_delta_t as f64,
+recall: avg,
+count: ratings.len() as f64,
})
}

@@ -69,9 +69,9 @@ fn create_pretrain_data(fsrs_items: Vec<FSRSItem>) -> HashMap<FirstRating, Vec<A

/// The average pass rate & count for a single delta_t for a given first rating.
struct AverageRecall {
-delta_t: f32,
-recall: f32,
-count: f32,
+delta_t: f64,
+recall: f64,
+count: f64,
}

fn total_rating_count(
@@ -80,23 +80,23 @@ fn total_rating_count(
pretrainset
.iter()
.map(|(first_rating, data)| {
-let count = data.iter().map(|d| d.count).sum::<f32>() as u32;
+let count = data.iter().map(|d| d.count).sum::<f64>() as u32;
(*first_rating, count)
})
.collect()
}

-fn power_forgetting_curve(t: &Array1<f32>, s: f32) -> Array1<f32> {
-(t / s * FACTOR as f32 + 1.0).mapv(|v| v.powf(DECAY as f32))
+fn power_forgetting_curve(t: &Array1<f64>, s: f64) -> Array1<f64> {
+(t / s * FACTOR + 1.0).mapv(|v| v.powf(DECAY))
}

fn loss(
-delta_t: &Array1<f32>,
-recall: &Array1<f32>,
-count: &Array1<f32>,
-init_s0: f32,
-default_s0: f32,
-) -> f32 {
+delta_t: &Array1<f64>,
+recall: &Array1<f64>,
+count: &Array1<f64>,
+init_s0: f64,
+default_s0: f64,
+) -> f64 {
let y_pred = power_forgetting_curve(delta_t, init_s0);
let logloss = (-(recall * y_pred.clone().mapv_into(|v| v.ln())
+ (1.0 - recall) * (1.0 - &y_pred).mapv_into(|v| v.ln()))
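
The hunk above only changes types, but for orientation, the math it computes can be written as a minimal standalone f64 sketch. The constants FACTOR = 19/81 and DECAY = -0.5 are an assumption consistent with the expected values in test_power_forgetting_curve below (retention is exactly 0.9 one day after a one-day stability); the weighting by count mirrors the loss signature, while the default_s0 argument presumably feeds a regularization term that is collapsed out of this hunk.

// Minimal sketch, not the crate's implementation.
const FACTOR: f64 = 19.0 / 81.0; // chosen so that retention(s, s) == 0.9
const DECAY: f64 = -0.5;

// Predicted probability of recall after t days at stability s.
fn retention(t: f64, s: f64) -> f64 {
    (1.0 + FACTOR * t / s).powf(DECAY)
}

// Count-weighted binary log-loss of a candidate initial stability s0,
// mirroring the delta_t/recall/count columns used by `loss` above.
fn weighted_logloss(delta_t: &[f64], recall: &[f64], count: &[f64], s0: f64) -> f64 {
    delta_t
        .iter()
        .zip(recall)
        .zip(count)
        .map(|((&t, &r), &n)| {
            let p = retention(t, s0);
            -(r * p.ln() + (1.0 - r) * (1.0 - p).ln()) * n
        })
        .sum()
}
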
@@ -113,25 +113,23 @@ fn search_parameters(
average_recall: f32,
) -> HashMap<u32, f32> {
let mut optimal_stabilities = HashMap::new();
-let epsilon = f32::EPSILON;
+let epsilon = f64::EPSILON;

for (first_rating, data) in &mut pretrainset {
let r_s0_default: HashMap<u32, f32> = R_S0_DEFAULT_ARRAY.iter().cloned().collect();
-let default_s0 = r_s0_default[first_rating];
+let default_s0 = r_s0_default[first_rating] as f64;
let delta_t = Array1::from_iter(data.iter().map(|d| d.delta_t));
+let count = Array1::from_iter(data.iter().map(|d| d.count));
let recall = {
// Laplace smoothing
// (real_recall * n + average_recall * 1) / (n + 1)
// https://github.com/open-spaced-repetition/fsrs4anki/pull/358/files#diff-35b13c8e3466e8bd1231a51c71524fc31a945a8f332290726214d3a6fa7f442aR491
let real_recall = Array1::from_iter(data.iter().map(|d| d.recall));
-let n = data.iter().map(|d| d.count).sum::<f32>();
-(real_recall * n + average_recall) / (n + 1.0)
+(real_recall * count.clone() + average_recall as f64) / (count.clone() + 1.0)
};
-let count = Array1::from_iter(data.iter().map(|d| d.count));

-let mut low = S_MIN;
-let mut high = INIT_S_MAX;
-let mut optimal_s = 1.0;
+let mut low = S_MIN as f64;
+let mut high = INIT_S_MAX as f64;
+let mut optimal_s = default_s0;

let mut iter = 0;
while high - low > epsilon && iter < 1000 {
@@ -151,7 +149,7 @@
optimal_s = (high + low) / 2.0;
}

-optimal_stabilities.insert(*first_rating, optimal_s);
+optimal_stabilities.insert(*first_rating, optimal_s as f32);
}

optimal_stabilities
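
The smoothing change above is the "fix laplace smoothing" part of the commit: the old code used n, the total review count summed over all delta_t bins, so the prior was effectively invisible for any first rating with many reviews; the new code blends each bin with the dataset-wide average recall in proportion to that bin's own count. A scalar sketch of the corrected formula (the real code operates on ndarray Array1<f64> columns):

// (real_recall * count + average_recall) / (count + 1): each bin is pulled
// toward the global average with the weight of one pseudo-observation.
fn smoothed_recall(real_recall: f64, count: f64, average_recall: f64) -> f64 {
    (real_recall * count + average_recall) / (count + 1.0)
}

With the values from the tests below, a bin observed at 0.86666667 over 435 reviews and a global average of 0.9430285915990116 smooths to about 0.86684, which is the first recall value used in test_loss; a hypothetical bin with only 2 reviews would move much further toward the prior.
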
@@ -274,70 +272,84 @@ mod tests {
use burn::tensor::Data;

use super::*;
-use crate::dataset::split_data;
+use crate::dataset::filter_outlier;
use crate::training::calculate_average_recall;

#[test]
fn test_power_forgetting_curve() {
let t = Array1::from(vec![0.0, 1.0, 2.0, 3.0]);
let s = 1.0;
let y = power_forgetting_curve(&t, s);
-let expected = Array1::from(vec![1.0, 0.90000004, 0.82502866, 0.76613086]);
+let expected = Array1::from(vec![1.0, 0.9, 0.8250286473253902, 0.7661308776828737]);
assert_eq!(y, expected);
}

#[test]
fn test_loss() {
-let delta_t = Array1::from(vec![1.0, 2.0, 3.0]);
-let recall = Array1::from(vec![0.9, 0.8181818, 0.75]);
-let count = Array1::from(vec![100.0, 100.0, 100.0]);
-let init_s0 = 1.0;
-let actual = loss(&delta_t, &recall, &count, init_s0, init_s0);
-assert_eq!(actual, 13.624332);
-Data::from([loss(&delta_t, &recall, &count, 2.0, init_s0)])
-.assert_approx_eq(&Data::from([14.5771]), 5);
+let delta_t = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
+let recall = Array1::from(vec![
+0.86684181, 0.90758192, 0.73348482, 0.76776996, 0.68769064,
+]);
+let count = Array1::from(vec![435.0, 97.0, 63.0, 38.0, 28.0]);
+let default_s0 = DEFAULT_PARAMETERS[0] as f64;
+let actual = loss(&delta_t, &recall, &count, 1.017056, default_s0);
+dbg!(actual);
+assert_eq!(actual, 22.922578338789826);
+let actual = loss(&delta_t, &recall, &count, 1.017011, default_s0);
+dbg!(actual);
+assert_eq!(actual, 22.922578344493953);
}

#[test]
fn test_search_parameters() {
+let first_rating = 1;
let pretrainset = HashMap::from([(
-4,
+first_rating,
vec![
AverageRecall {
delta_t: 1.0,
-recall: 0.9,
-count: 30.0,
+recall: 0.86666667,
+count: 435.0,
},
AverageRecall {
delta_t: 2.0,
-recall: 0.8181818,
-count: 30.0,
+recall: 0.90721649,
+count: 97.0,
},
AverageRecall {
delta_t: 3.0,
-recall: 0.75,
-count: 30.0,
+recall: 0.73015873,
+count: 63.0,
},
AverageRecall {
delta_t: 4.0,
-recall: 0.6923077,
-count: 30.0,
+recall: 0.76315789,
+count: 38.0,
},
+AverageRecall {
+delta_t: 5.0,
+recall: 0.67857143,
+count: 28.0,
+},
],
)]);
-let actual = search_parameters(pretrainset, 0.9);
-Data::from([*actual.get(&4).unwrap()]).assert_approx_eq(&Data::from([0.943_921]), 3);
+let actual = search_parameters(pretrainset, 0.9430285915990116);
+Data::from([*actual.get(&first_rating).unwrap()])
+.assert_approx_eq(&Data::from([1.017_056]), 6);
}

#[test]
fn test_pretrain() {
use crate::convertor_tests::anki21_sample_file_converted_to_fsrs;
let items = anki21_sample_file_converted_to_fsrs();
+let (mut pretrainset, mut trainset) =
+items.into_iter().partition(|item| item.reviews.len() == 2);
+(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
+let items = [pretrainset.clone(), trainset].concat();
let average_recall = calculate_average_recall(&items);
-let pretrainset = split_data(items, 1).0;
Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq(
-&Data::from([1.001_131, 1.810_561, 4.403_226, 10.935_509]),
-4,
+&Data::from([1.017_056, 1.829_625, 4.414_563, 10.935_500]),
+6,
)
}

10 changes: 8 additions & 2 deletions src/training.rs
@@ -223,15 +223,14 @@ pub(crate) struct TrainingConfig {
pub fn calculate_average_recall(items: &[FSRSItem]) -> f32 {
let (total_recall, total_reviews) = items
.iter()
-.flat_map(|item| item.reviews.iter())
+.map(|item| item.current())
.fold((0u32, 0u32), |(sum, count), review| {
(sum + (review.rating > 1) as u32, count + 1)
});

if total_reviews == 0 {
return 0.0;
}

total_recall as f32 / total_reviews as f32
}
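
This is the "calculate average recall in item level" half of the commit title: instead of flattening every review of every item into the average, each item now contributes a single sample taken from item.current(), its most recent review. A pared-down sketch of the new behavior, with simplified stand-ins for FSRSItem and FSRSReview (the real structs carry more fields):

struct FSRSReview { rating: u32 }
struct FSRSItem { reviews: Vec<FSRSReview> }

impl FSRSItem {
    // Stand-in for the `current()` used above: the review this item is trained on.
    fn current(&self) -> &FSRSReview {
        self.reviews.last().expect("an FSRSItem always has at least one review")
    }
}

// One sample per item: a rating above 1 counts as a successful recall.
fn calculate_average_recall(items: &[FSRSItem]) -> f32 {
    let (total_recall, total_reviews) = items
        .iter()
        .map(|item| item.current())
        .fold((0u32, 0u32), |(sum, count), review| {
            (sum + (review.rating > 1) as u32, count + 1)
        });
    if total_reviews == 0 {
        return 0.0;
    }
    total_recall as f32 / total_reviews as f32
}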

@@ -441,6 +440,13 @@ mod tests {
use burn::backend::ndarray::NdArrayDevice;
use rayon::prelude::IntoParallelIterator;

+#[test]
+fn test_calculate_average_recall() {
+let items = anki21_sample_file_converted_to_fsrs();
+let average_recall = calculate_average_recall(&items);
+assert_eq!(average_recall, 0.9435269);
+}

#[test]
fn training() {
if std::env::var("SKIP_TRAINING").is_ok() {
