Skip to content

Commit

Permalink
use median in auto bootstrap kde bandwidth estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
wgurecky committed Jul 1, 2024
1 parent 4747a4a commit 9a5b717
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/lib_math_utils/univariate_rv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,13 @@ pub fn build_kde(init_bandwidth: f64, samples: ArrayView1<f64>, n_iter: usize, m
.est_bandwidth(s_test.view(), Some(method));
bandwidth_ests.push(bwe.unwrap());
}
// median bandwidth
bandwidth_ests.sort_by(|a, b| a.total_cmp(b));
let med_idx: usize = bandwidth_ests.len() / 2;
let bw: f64 = bandwidth_ests[med_idx];
// avg bandwidth
let bw: f64 = bandwidth_ests.into_iter().sum::<f64>()
/ n_iter as f64;
// let bw: f64 = bandwidth_ests.into_iter().sum::<f64>()
// / n_iter as f64;
KdeRv::new(bw, samples)
}

Expand Down Expand Up @@ -569,7 +573,7 @@ mod univariate_rv_unit_tests {
fn test_kde_rv() {
// generate random samples from known norml dist
let rv_known = Normal::new(5.25, 10.).unwrap();
let ns = 100;
let ns = 400;
let mut tst_s = Array1::zeros(ns);
let mut support_s = Array1::zeros(ns);
for i in 0..ns {tst_s[i] = rv_known.sample(&mut rand::thread_rng()); }
Expand All @@ -587,12 +591,12 @@ mod univariate_rv_unit_tests {
println!("Real pop mean: {:?}, KDE Mean: {:?}", support_s.mean(), kde_samples.mean());
println!("Real pop std: {:?}, KDE std: {:?}", support_s.std(0.), kde_samples.std(0.));
assert_approx_eq!(support_s.mean().unwrap(), kde_samples.mean().unwrap(), 9e-1);
assert_approx_eq!(support_s.std(0.), kde_samples.std(0.), 3.);
assert_approx_eq!(support_s.std(0.), kde_samples.std(0.), 5.);

// test kde automated builder
let auto_kde_dist = build_kde(1.0, support_s.view(), 10, 2).unwrap();
let kde_samples = auto_kde_dist.sample(10000, None);
println!("Fitted KDE bandwidth: {:?}", auto_kde_dist.bandwidth);
let auto_kde_dist = build_kde(1.0, support_s.view(), 20, 2).unwrap();
let kde_samples = auto_kde_dist.sample(100000, None);
println!("Fitted auto KDE bandwidth: {:?}", auto_kde_dist.bandwidth);
println!("Real pop mean: {:?}, KDE Mean: {:?}", support_s.mean(), kde_samples.mean());
println!("Real pop std: {:?}, KDE std: {:?}", support_s.std(0.), kde_samples.std(0.));
assert_approx_eq!(support_s.mean().unwrap(), kde_samples.mean().unwrap(), 9e-1);
Expand Down

0 comments on commit 9a5b717

Please sign in to comment.