-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
31 lines (23 loc) · 994 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# main.py
import gradio as gr
from utils import load_model, predict_category
from news_dataset import NewsDataset # Importez NewsDataset depuis news_dataset.py
def launch_app(
    csv_file="./inshort_news_data.csv",
    model_path="./models/trained_model1.pth",
    max_length=100,
    share=True,
):
    """Build and launch the Gradio news-category classification app.

    The dataset is loaded here only to obtain its ``labels_dict``. That
    mapping determines how many classes the model is built with, and the
    same mapping is handed to ``predict_category`` at prediction time.

    Args:
        csv_file: Path to the news CSV passed to ``NewsDataset``. Relative
            paths are resolved against the current working directory.
        model_path: Path to the trained model weights passed to
            ``load_model``.
        max_length: ``max_length`` argument forwarded to ``NewsDataset``.
        share: Forwarded to ``gr.Interface.launch``. When true, Gradio
            creates a public share link, so the app is reachable from
            outside this machine.
    """
    dataset = NewsDataset(csv_file=csv_file, max_length=max_length)
    # Read the label mapping once; the model's class count must match it so
    # that predictions line up with the labels used by predict_category.
    labels_dict = dataset.labels_dict
    num_classes = len(labels_dict)
    model = load_model(model_path, num_classes)

    def predict_function(headline, article):
        """Classify one headline/article pair with the loaded model."""
        return predict_category(headline, article, model, labels_dict)

    iface = gr.Interface(
        fn=predict_function,
        inputs=["text", "text"],
        outputs="text",
        title="News Category Classification",
        description="Enter a headline and an article to classify its category.",
    )
    iface.launch(share=share)
if __name__ == "__main__":
    # Only start the web app when this file is run as a script, not when
    # it is imported by another module.
    launch_app()