diff --git a/results/models/disease_pipeline.pickle b/results/models/disease_pipeline.pickle index 4a9ca8b..ebec341 100644 Binary files a/results/models/disease_pipeline.pickle and b/results/models/disease_pipeline.pickle differ diff --git a/results/tables/cross_val_score.csv b/results/tables/cross_val_score.csv index 98c9b5c..c734b28 100644 --- a/results/tables/cross_val_score.csv +++ b/results/tables/cross_val_score.csv @@ -1,11 +1,12 @@ -dummy,logreg,svc,logreg_bal,svc_bal -0.001,0.009,0.007,0.009,0.008 -0.003,0.006,0.006,0.005,0.006 -0.746,0.797,0.818,0.731,0.747 -0.746,0.85,0.896,0.822,0.916 -0.0,0.638,0.824,0.477,0.524 -0.0,0.789,0.969,0.609,0.779 -0.0,0.44,0.38,0.68,0.62 -0.0,0.56,0.61,0.84,0.94 -0.0,0.517,0.512,0.555,0.563 -0.0,0.655,0.747,0.706,0.851 +index,dummy,dummy,logreg,logreg,svc,svc,logreg_bal,logreg_bal,svc_bal,svc_bal +,mean,std,mean,std,mean,std,mean,std,mean,std +fit_time,0.001,0.0,0.009,0.0,0.007,0.0,0.009,0.0,0.008,0.001 +score_time,0.003,0.0,0.006,0.0,0.006,0.001,0.006,0.0,0.006,0.0 +test_accuracy,0.746,0.004,0.797,0.063,0.818,0.036,0.731,0.035,0.747,0.099 +train_accuracy,0.746,0.001,0.85,0.012,0.896,0.015,0.822,0.02,0.916,0.01 +test_precision,0.0,0.0,0.638,0.159,0.824,0.182,0.477,0.056,0.524,0.154 +train_precision,0.0,0.0,0.789,0.032,0.969,0.018,0.609,0.031,0.779,0.034 +test_recall,0.0,0.0,0.44,0.167,0.38,0.084,0.68,0.179,0.62,0.11 +train_recall,0.0,0.0,0.56,0.038,0.61,0.065,0.84,0.038,0.94,0.022 +test_f1,0.0,0.0,0.517,0.167,0.512,0.094,0.555,0.09,0.563,0.13 +train_f1,0.0,0.0,0.655,0.033,0.747,0.047,0.706,0.031,0.851,0.014 diff --git a/results/tables/cross_val_std.csv b/results/tables/cross_val_std.csv index 5b7da1b..c734b28 100644 --- a/results/tables/cross_val_std.csv +++ b/results/tables/cross_val_std.csv @@ -1,11 +1,12 @@ -dummy,logreg,svc,logreg_bal,svc_bal -0.0,0.001,0.0,0.0,0.001 -0.001,0.001,0.0,0.0,0.0 -0.004,0.063,0.036,0.035,0.099 -0.001,0.012,0.015,0.02,0.01 -0.0,0.159,0.182,0.056,0.154 -0.0,0.032,0.018,0.031,0.034 -0.0,0.167,0.084,0.179,0.11 -0.0,0.038,0.065,0.038,0.022 -0.0,0.167,0.094,0.09,0.13 -0.0,0.033,0.047,0.031,0.014 +index,dummy,dummy,logreg,logreg,svc,svc,logreg_bal,logreg_bal,svc_bal,svc_bal +,mean,std,mean,std,mean,std,mean,std,mean,std +fit_time,0.001,0.0,0.009,0.0,0.007,0.0,0.009,0.0,0.008,0.001 +score_time,0.003,0.0,0.006,0.0,0.006,0.001,0.006,0.0,0.006,0.0 +test_accuracy,0.746,0.004,0.797,0.063,0.818,0.036,0.731,0.035,0.747,0.099 +train_accuracy,0.746,0.001,0.85,0.012,0.896,0.015,0.822,0.02,0.916,0.01 +test_precision,0.0,0.0,0.638,0.159,0.824,0.182,0.477,0.056,0.524,0.154 +train_precision,0.0,0.0,0.789,0.032,0.969,0.018,0.609,0.031,0.779,0.034 +test_recall,0.0,0.0,0.44,0.167,0.38,0.084,0.68,0.179,0.62,0.11 +train_recall,0.0,0.0,0.56,0.038,0.61,0.065,0.84,0.038,0.94,0.022 +test_f1,0.0,0.0,0.517,0.167,0.512,0.094,0.555,0.09,0.563,0.13 +train_f1,0.0,0.0,0.655,0.033,0.747,0.047,0.706,0.031,0.851,0.014 diff --git a/scripts/4_training_models.py b/scripts/4_training_models.py index 6645110..eb093b1 100644 --- a/scripts/4_training_models.py +++ b/scripts/4_training_models.py @@ -90,11 +90,15 @@ def main(train, write_to): return_train_score=True) ).agg(['mean', 'std']).round(3).T - # Save cross-validation results - pd.concat(cross_val_results, axis='columns').xs('std', axis='columns', level=1).to_csv( + # Save cross-validation results (standard deviation) + std_df = pd.concat(cross_val_results, axis='columns').reset_index() + std_df.to_csv( os.path.join(write_to, "tables", "cross_val_std.csv"), index=False ) - pd.concat(cross_val_results, axis='columns').xs('mean', axis='columns', level=1).to_csv( + + # Save cross-validation results (mean) + mean_df = pd.concat(cross_val_results, axis='columns').reset_index() + mean_df.to_csv( os.path.join(write_to, "tables", "cross_val_score.csv"), index=False )