Skip to content

Commit

Permalink
Feat/option for pretrain only (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Jan 29, 2024
1 parent 51d38b4 commit 85643d4
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.1.0"
version = "0.2.0"
authors = ["Open Spaced Repetition"]
categories = ["Algorithms", "Science"]
edition = "2021"
Expand Down
11 changes: 10 additions & 1 deletion src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::error::Result;
use crate::model::{Model, ModelConfig};
use crate::pre_training::pretrain;
use crate::weight_clipper::weight_clipper;
use crate::{FSRSError, FSRS};
use crate::{FSRSError, DEFAULT_WEIGHTS, FSRS};
use burn::backend::Autodiff;
use burn::data::dataloader::DataLoaderBuilder;
use burn::module::Module;
Expand Down Expand Up @@ -239,6 +239,7 @@ impl<B: Backend> FSRS<B> {
pub fn compute_weights(
&self,
items: Vec<FSRSItem>,
pretrain_only: bool,
progress: Option<Arc<Mutex<CombinedProgressState>>>,
) -> Result<Vec<f32>> {
let finish_progress = || {
Expand All @@ -261,6 +262,14 @@ impl<B: Backend> FSRS<B> {
finish_progress();
e
})?;
if pretrain_only {
finish_progress();
let weights = initial_stability
.into_iter()
.chain(DEFAULT_WEIGHTS[4..].iter().copied())
.collect();
return Ok(weights);
}
let config = TrainingConfig::new(
ModelConfig {
freeze_stability: true,
Expand Down

0 comments on commit 85643d4

Please sign in to comment.