Skip to content

Commit

Permalink
fix genai image classification
Browse files Browse the repository at this point in the history
  • Loading branch information
Ekedani committed May 27, 2024
1 parent 4debf2f commit afc381e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 19 deletions.
8 changes: 4 additions & 4 deletions ai-image-classifier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
models = {}


def load_models():
def load_genai_image_classifiers():
model_dir = 'models'
for model_file in os.listdir(model_dir):
if model_file.endswith('.pkl'):
Expand All @@ -26,8 +26,8 @@ def load_models():
models[model_name] = load_learner(os.path.join(model_dir, model_file))


load_models()
BEST_MODEL = 'resnet34_image_classifier_v4'
load_genai_image_classifiers()
BEST_CLASSIFIER = 'resnet34_image_classifier_v4'


def get_prediction(model, request):
Expand Down Expand Up @@ -57,7 +57,7 @@ def get_prediction(model, request):

@app.route('/prediction', methods=['POST'])
def predict_using_best_model():
return get_prediction(BEST_MODEL, request)
return get_prediction(BEST_CLASSIFIER, request)


@app.route('/<model>/prediction', methods=['POST'])
Expand Down
38 changes: 23 additions & 15 deletions ai-image-classifier/model_training.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from fastai.vision.all import *
import os

import matplotlib.pyplot as plt
from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback
from fastai.interpret import ClassificationInterpretation
from fastai.metrics import accuracy
from fastai.vision.augment import Resize
from fastai.vision.data import ImageDataLoaders
from fastai.vision.learner import vision_learner
from fastai.vision.models import resnet34


def setup_and_train():
def setup_and_train_genai_classifier():
dataset_path = 'dataset\\train'
model_path = 'models'

Expand All @@ -14,37 +22,37 @@ def setup_and_train():
item_tfms=Resize(224)
)

learn = vision_learner(dls, resnet34, metrics=accuracy)
genai_classifier_learn = vision_learner(dls, resnet34, metrics=accuracy)

early_stop_callback = EarlyStoppingCallback(monitor='valid_loss', min_delta=0.01, patience=3)
save_model_callback = SaveModelCallback(fname='best_model', every_epoch=False, with_opt=True)

learn.lr_find()
learn.fit_one_cycle(20, cbs=[early_stop_callback, save_model_callback])
genai_classifier_learn.lr_find()
genai_classifier_learn.fit_one_cycle(20, cbs=[early_stop_callback, save_model_callback])

learn.save(f'resnet34_image_classifier_v4.pth')
learn.export(f'resnet34_image_classifier_v4.pkl')
print("Saved GenAI image classifier model to disk")
genai_classifier_learn.save('resnet34_image_classifier_v4.pth')
genai_classifier_learn.export('resnet34_image_classifier_v4.pkl')
print('Saved GenAI image classifier model to disk')

return learn, model_path
return genai_classifier_learn, model_path


def plot_results(learn):
def test_and_plot_results(genai_classifier_learn):
plt.figure(figsize=(10, 4))
learn.recorder.plot_lr_find()
genai_classifier_learn.recorder.plot_lr_find()
plt.title("Learning Rate Finder")
plt.xlabel("Learning Rate")
plt.ylabel("Loss")
plt.savefig(f'learning_rate_finder.png')

plt.figure(figsize=(10, 4))
learn.recorder.plot_loss()
genai_classifier_learn.recorder.plot_loss()
plt.title("Training and Validation Loss")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.savefig(f'training_validation_loss.png')

interp = ClassificationInterpretation.from_learner(learn)
interp = ClassificationInterpretation.from_learner(genai_classifier_learn)

interp.plot_top_losses(6, figsize=(15, 11))
plt.savefig(f'top_losses.png')
Expand All @@ -54,5 +62,5 @@ def plot_results(learn):


if __name__ == '__main__':
learn, model_path = setup_and_train()
plot_results(learn)
genai_classifier_learn, model_path = setup_and_train_genai_classifier()
test_and_plot_results(genai_classifier_learn)
Binary file modified ai-image-classifier/requirements.txt
Binary file not shown.

0 comments on commit afc381e

Please sign in to comment.