diff --git a/Cargo.lock b/Cargo.lock index 950f1941..6ccdf6f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1065,7 +1065,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.1.0" +version = "0.2.0" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 76de9ffe..430ad95b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.1.0" +version = "0.2.0" authors = ["Open Spaced Repetition"] categories = ["Algorithms", "Science"] edition = "2021" diff --git a/src/training.rs b/src/training.rs index 1fe6f71b..3d7c12f5 100644 --- a/src/training.rs +++ b/src/training.rs @@ -5,7 +5,7 @@ use crate::error::Result; use crate::model::{Model, ModelConfig}; use crate::pre_training::pretrain; use crate::weight_clipper::weight_clipper; -use crate::{FSRSError, FSRS}; +use crate::{FSRSError, DEFAULT_WEIGHTS, FSRS}; use burn::backend::Autodiff; use burn::data::dataloader::DataLoaderBuilder; use burn::module::Module; @@ -239,6 +239,7 @@ impl FSRS { pub fn compute_weights( &self, items: Vec, + pretrain_only: bool, progress: Option>>, ) -> Result> { let finish_progress = || { @@ -261,6 +262,14 @@ impl FSRS { finish_progress(); e })?; + if pretrain_only { + finish_progress(); + let weights = initial_stability + .into_iter() + .chain(DEFAULT_WEIGHTS[4..].iter().copied()) + .collect(); + return Ok(weights); + } let config = TrainingConfig::new( ModelConfig { freeze_stability: true,