Skip to content

Commit

Permalink
Add additional argument to script 4 and 5
Browse files Browse the repository at this point in the history
  • Loading branch information
Piloxita committed Dec 14, 2024
1 parent fe923fb commit 726ca3c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 4 additions & 2 deletions scripts/4_training_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@

@click.command()
@click.option('--train', type=str, help="Location of train data file")
@click.option('--seed', type =int, help="Set seed for reproducibility")
@click.option('--write-to', type=str, help="Path to master directory where outputs will be written")
def main(train, write_to):

def main(train, seed, write_to):

# Ensure necessary directories exist
os.makedirs(os.path.join(write_to, "tables"), exist_ok=True)
Expand Down Expand Up @@ -82,7 +84,7 @@ def main(train, write_to):

# 3. Training models
models = class_model_trainer(preprocessor, X_train, y_train, pos_lable = '> 50% diameter narrowing',
seed=123, write_to=write_to,
seed=seed, write_to=write_to,
cv = 5, metrics = classification_metrics)

# 4. HYPERPARAMETER OPTIMIZATION
Expand Down
5 changes: 3 additions & 2 deletions scripts/5_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
@click.command()
@click.option('--train', type=str, help="Location of train data file")
@click.option('--test', type=str, help="Path to the test data file", required=True)
@click.option('--pipeline', type=str, help="Path to the model pickle", required=True)
@click.option('--write-to', type=str, help="Path to the master directory where outputs will be written", required=True)
def main(train, test, write_to):
def main(train, test, pipeline, write_to):
"""
Evaluate a trained model on test data and save evaluation metrics and confusion matrix.
"""
# Define the path to the saved model file
model_path = os.path.join(write_to, "models", "disease_pipeline.pickle")
model_path = pipeline

# Ensure necessary directories exist
os.makedirs(os.path.join(write_to, "tables"), exist_ok=True)
Expand Down

0 comments on commit 726ca3c

Please sign in to comment.