Skip to content

Commit

Permalink
update default parameters (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Mar 1, 2024
1 parent a567486 commit 333a63c
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.4.2"
version = "0.4.3"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
10 changes: 5 additions & 5 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ pub type Parameters = [f32];
use itertools::izip;

pub static DEFAULT_PARAMETERS: [f32; 17] = [
0.5614, 1.2546, 3.5878, 7.9731, 5.1043, 1.1303, 0.823, 0.0465, 1.629, 0.135, 1.0045, 2.132,
0.0839, 0.3204, 1.3547, 0.219, 2.7849,
0.5701, 1.4436, 4.1386, 10.9355, 5.1443, 1.2006, 0.8627, 0.0362, 1.629, 0.1342, 1.0166, 2.1174,
0.0839, 0.3204, 1.4676, 0.219, 2.8237,
];

fn infer<B: Backend>(
Expand Down Expand Up @@ -478,7 +478,7 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.204_001, 0.025_387]), 5);
.assert_approx_eq(&Data::from([0.203_023, 0.024_624]), 5);

let fsrs = FSRS::new(Some(PARAMETERS))?;
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
Expand All @@ -491,7 +491,7 @@ mod tests {
.unwrap();

Data::from([self_by_other, other_by_self])
.assert_approx_eq(&Data::from([0.015_987, 0.019_767]), 5);
.assert_approx_eq(&Data::from([0.016_727, 0.019_213]), 5);
Ok(())
}

Expand Down Expand Up @@ -578,7 +578,7 @@ mod tests {
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(),
MemoryState {
stability: 9.999995,
difficulty: 7.200902
difficulty: 7.255334
}
);
assert_eq!(
Expand Down
14 changes: 7 additions & 7 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ mod tests {
let stability = model.init_stability(rating);
assert_eq!(
stability.to_data(),
Data::from([0.5614, 1.2546, 3.5878, 7.9731, 0.5614, 1.2546])
Data::from([0.5701, 1.4436, 4.1386, 10.9355, 0.5701, 1.4436])
)
}

Expand All @@ -295,7 +295,7 @@ mod tests {
let difficulty = model.init_difficulty(rating);
assert_eq!(
difficulty.to_data(),
Data::from([7.3649, 6.2346, 5.1043, 3.974, 7.3649, 6.2346])
Data::from([7.5455, 6.3449, 5.1443, 3.9436998, 7.5455, 6.3449])
)
}

Expand Down Expand Up @@ -331,13 +331,13 @@ mod tests {
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.646, 5.823, 5.0, 4.177])
Data::from([6.7254, 5.8627, 5.0, 4.1373])
);
let next_difficulty = model.mean_reversion(next_difficulty);
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.574311, 5.7895803, 5.00485, 4.2201195])
Data::from([6.6681643, 5.836694, 5.0052238, 4.1737533])
)
}

Expand All @@ -358,19 +358,19 @@ mod tests {
s_recall.clone().backward();
assert_eq!(
s_recall.to_data(),
Data::from([26.678038, 13.996968, 62.718544, 202.76956])
Data::from([26.980936, 14.128489, 63.600677, 208.72739])
);
let s_forget = model.stability_after_failure(stability, difficulty, retention);
s_forget.clone().backward();
assert_eq!(
s_forget.to_data(),
Data::from([1.8932177, 2.0453987, 2.2637987, 2.5304008])
Data::from([1.9016013, 2.0777826, 2.3257504, 2.6291647])
);
let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget);
next_stability.clone().backward();
assert_eq!(
next_stability.to_data(),
Data::from([1.8932177, 13.996968, 62.718544, 202.76956])
Data::from([1.9016013, 14.128489, 63.600677, 208.72739])
)
}

Expand Down
8 changes: 4 additions & 4 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ mod tests {
None,
)
.0;
assert_eq!(memorization[memorization.len() - 1], 3022.055014122344)
assert_eq!(memorization[memorization.len() - 1], 3130.8465582271774)
}

#[test]
Expand Down Expand Up @@ -732,8 +732,8 @@ mod tests {
assert_eq!(
results.1.to_vec(),
vec![
0, 16, 27, 29, 86, 73, 96, 95, 96, 105, 112, 113, 124, 131, 139, 124, 130, 141,
162, 175, 168, 179, 186, 185, 198, 189, 200, 200, 200, 200
0, 16, 27, 34, 84, 80, 91, 92, 103, 107, 111, 113, 138, 132, 133, 116, 134, 148,
152, 162, 172, 177, 188, 189, 200, 185, 185, 200, 198, 200
]
);
assert_eq!(
Expand All @@ -747,7 +747,7 @@ mod tests {
let config = SimulatorConfig::default();
let fsrs = FSRS::new(None)?;
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.864870726919112);
assert_eq!(optimal_retention, 0.8468471175527587);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down
8 changes: 5 additions & 3 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,10 @@ mod tests {
let items = anki21_sample_file_converted_to_fsrs();
let average_recall = calculate_average_recall(&items);
let pretrainset = split_data(items, 1).0;
Data::from(pretrain(pretrainset, average_recall).unwrap())
.assert_approx_eq(&Data::from([1.001_131, 1.810_561, 4.403_481, 8.530_161]), 4)
Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq(
&Data::from([1.001_131, 1.810_561, 4.403_226, 10.935_509]),
4,
)
}

#[test]
Expand All @@ -349,6 +351,6 @@ mod tests {
let mut rating_stability = HashMap::from([(2, 0.35)]);
let rating_count = HashMap::from([(2, 1)]);
let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap();
assert_eq!(actual, [0.15661564, 0.35, 1.0009006, 2.2242827,]);
assert_eq!(actual, [0.13822041, 0.35, 1.0034012, 2.6513057,]);
}
}

0 comments on commit 333a63c

Please sign in to comment.