-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredictor.py
executable file
·30 lines (23 loc) · 1.29 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from model import predict
from process import post_process, extract_text
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch
from starlette.responses import JSONResponse
class PythonPredictor:
    """Extractive question-answering predictor over an HTML article.

    The tokenizer and model are loaded once in ``__init__``. Each ``predict``
    call answers ``payload['question']`` against text extracted from
    ``payload['article']`` and returns the (mutated) payload as JSON.
    """

    # Hugging Face checkpoint used for both the tokenizer and the model.
    # A dict ``config`` may override it with a ``"model_name"`` key.
    MODEL_NAME = "distilbert-base-uncased-distilled-squad"

    def __init__(self, config):
        """Load the tokenizer and question-answering model.

        Args:
            config: Deployment configuration. If it is a dict containing
                ``"model_name"``, that checkpoint is loaded instead of
                ``MODEL_NAME``; otherwise it is not used.
        """
        device = 0 if torch.cuda.is_available() else -1
        # NOTE(review): the model below is never moved with ``.to(...)``, so this
        # line only reports CUDA availability — confirm whether ``predict``
        # (from the ``model`` module) handles device placement.
        print(f"using device: {'cuda' if device == 0 else 'cpu'}")

        model_name = self.MODEL_NAME
        if isinstance(config, dict):
            model_name = config.get("model_name", model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)

    @staticmethod
    def _unescape(text):
        """Undo JSON-style ``\\/`` escapes and decode backslash escape sequences.

        Literal escapes such as ``\\n`` or ``\\u00e9`` are turned into the
        characters they denote, while characters that are already non-ASCII
        are preserved unchanged.
        """
        text = text.replace("\\/", "/")
        # ``unicode_escape`` interprets its input bytes as Latin-1. Encoding with
        # latin-1 + backslashreplace keeps Latin-1 chars as single bytes and
        # turns anything wider into ``\uXXXX`` escapes, so the decode
        # reproduces the original characters instead of UTF-8 mojibake.
        return text.encode("latin-1", "backslashreplace").decode("unicode_escape")

    def predict(self, payload):
        """Answer the payload's question and attach snippets to the payload.

        Args:
            payload: Request dict; reads ``'question'`` and ``'article'``
                (article is treated as HTML) and, when an answer is found,
                ``'html_url'``.

        Returns:
            A ``JSONResponse`` wrapping ``payload``. ``payload['reader']`` is
            set to 1 when a non-empty HTML snippet was produced (together with
            ``'html_snippet'``, ``'text_snippet'`` and ``'images'``), else 0.
        """
        question = self._unescape(payload['question'])
        html_article = self._unescape(payload['article'])
        context = extract_text(html_article)
        # ``predict`` here is the function imported from ``model`` (module
        # global), not this method — methods are not in the lookup scope.
        answer = predict(question, context, tokenizer=self.tokenizer, model=self.model)

        payload['reader'] = 0
        # presumably answer[0] is the answer span/text; empty means no answer
        # — TODO confirm against model.predict.
        if len(answer[0]) > 0:
            html_snippet, text_snippet, images = post_process(
                html_article, answer, payload['html_url'], tokenizer=self.tokenizer
            )
            if html_snippet != '':
                payload['html_snippet'] = html_snippet
                payload['text_snippet'] = text_snippet
                payload['images'] = images
                payload['reader'] = 1
        return JSONResponse(content=payload)