diff --git a/ai-image-classifier/main.py b/ai-image-classifier/main.py index e7e474c..da5c206 100644 --- a/ai-image-classifier/main.py +++ b/ai-image-classifier/main.py @@ -17,7 +17,7 @@ models = {} -def load_models(): +def load_genai_image_classifiers(): model_dir = 'models' for model_file in os.listdir(model_dir): if model_file.endswith('.pkl'): @@ -26,8 +26,8 @@ def load_models(): models[model_name] = load_learner(os.path.join(model_dir, model_file)) -load_models() -BEST_MODEL = 'resnet34_image_classifier_v4' +load_genai_image_classifiers() +BEST_CLASSIFIER = 'resnet34_image_classifier_v4' def get_prediction(model, request): @@ -57,7 +57,7 @@ def get_prediction(model, request): @app.route('/prediction', methods=['POST']) def predict_using_best_model(): - return get_prediction(BEST_MODEL, request) + return get_prediction(BEST_CLASSIFIER, request) @app.route('//prediction', methods=['POST']) diff --git a/ai-image-classifier/model_training.py b/ai-image-classifier/model_training.py index b2fc9a5..d715f5d 100644 --- a/ai-image-classifier/model_training.py +++ b/ai-image-classifier/model_training.py @@ -1,8 +1,16 @@ -from fastai.vision.all import * +import os + +import matplotlib.pyplot as plt +from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback +from fastai.interpret import ClassificationInterpretation +from fastai.metrics import accuracy from fastai.vision.augment import Resize +from fastai.vision.data import ImageDataLoaders +from fastai.vision.learner import vision_learner +from fastai.vision.models import resnet34 -def setup_and_train(): +def setup_and_train_genai_classifier(): dataset_path = 'dataset\\train' model_path = 'models' @@ -14,37 +22,37 @@ def setup_and_train(): item_tfms=Resize(224) ) - learn = vision_learner(dls, resnet34, metrics=accuracy) + genai_classifier_learn = vision_learner(dls, resnet34, metrics=accuracy) early_stop_callback = EarlyStoppingCallback(monitor='valid_loss', min_delta=0.01, patience=3) save_model_callback = SaveModelCallback(fname='best_model', every_epoch=False, with_opt=True) - learn.lr_find() - learn.fit_one_cycle(20, cbs=[early_stop_callback, save_model_callback]) + genai_classifier_learn.lr_find() + genai_classifier_learn.fit_one_cycle(20, cbs=[early_stop_callback, save_model_callback]) - learn.save(f'resnet34_image_classifier_v4.pth') - learn.export(f'resnet34_image_classifier_v4.pkl') - print("Saved GenAI image classifier model to disk") + genai_classifier_learn.save('resnet34_image_classifier_v4.pth') + genai_classifier_learn.export('resnet34_image_classifier_v4.pkl') + print('Saved GenAI image classifier model to disk') - return learn, model_path + return genai_classifier_learn, model_path -def plot_results(learn): +def test_and_plot_results(genai_classifier_learn): plt.figure(figsize=(10, 4)) - learn.recorder.plot_lr_find() + genai_classifier_learn.recorder.plot_lr_find() plt.title("Learning Rate Finder") plt.xlabel("Learning Rate") plt.ylabel("Loss") plt.savefig(f'learning_rate_finder.png') plt.figure(figsize=(10, 4)) - learn.recorder.plot_loss() + genai_classifier_learn.recorder.plot_loss() plt.title("Training and Validation Loss") plt.xlabel("Iterations") plt.ylabel("Loss") plt.savefig(f'training_validation_loss.png') - interp = ClassificationInterpretation.from_learner(learn) + interp = ClassificationInterpretation.from_learner(genai_classifier_learn) interp.plot_top_losses(6, figsize=(15, 11)) plt.savefig(f'top_losses.png') @@ -54,5 +62,5 @@ def plot_results(learn): if __name__ == '__main__': - learn, model_path = setup_and_train() - plot_results(learn) + genai_classifier_learn, model_path = setup_and_train_genai_classifier() + test_and_plot_results(genai_classifier_learn) diff --git a/ai-image-classifier/requirements.txt b/ai-image-classifier/requirements.txt index 9b572b3..1c1b0a4 100644 Binary files a/ai-image-classifier/requirements.txt and b/ai-image-classifier/requirements.txt differ