Skip to content

Commit

Permalink
Fix/early stop (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Mar 5, 2024
1 parent 9a20ec2 commit ba9d65a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.4.5"
version = "0.4.6"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
15 changes: 12 additions & 3 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,9 @@ fn train<B: AutodiffBackend>(
Aggregate::Mean,
Direction::Lowest,
Split::Valid,
StoppingCondition::NoImprovementSince { n_epochs: 1 },
StoppingCondition::NoImprovementSince {
n_epochs: config.num_epochs,
},
))
.devices(vec![device])
.num_epochs(config.num_epochs)
Expand Down Expand Up @@ -457,8 +459,11 @@ mod tests {
let device = NdArrayDevice::Cpu;
let items = anki21_sample_file_converted_to_fsrs();
let (pre_trainset, trainsets, testset) = split_data(items.clone(), n_splits);
let average_recall = calculate_average_recall(&pre_trainset);
let items = [pre_trainset.clone(), testset.clone()].concat();
let average_recall = calculate_average_recall(&items);
dbg!(average_recall);
let initial_stability = pretrain(pre_trainset, average_recall).unwrap();
dbg!(initial_stability);
let config = TrainingConfig::new(
ModelConfig {
freeze_stability: true,
Expand Down Expand Up @@ -492,6 +497,10 @@ mod tests {
.par_iter()
.map(|&sum| sum / n_splits as f32)
.collect();
dbg!(average_parameters);
dbg!(&average_parameters);

let fsrs = FSRS::new(Some(&average_parameters)).unwrap();
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
dbg!(&metrics);
}
}

0 comments on commit ba9d65a

Please sign in to comment.