-
Notifications
You must be signed in to change notification settings - Fork 7
/
server.py
70 lines (50 loc) · 2.39 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import logging
from waitress import serve
from flask import Flask, Response, request
import yaml
from transformers import AutoTokenizer
import torch
import traceback
import json
from pyeurovoc import EuroVocBERT
# WSGI application; its routes are registered below and it is served by
# waitress from the ``__main__`` section of this file.
app = Flask(__name__)


def _json_error(message, status=400):
    """Return a Flask ``Response`` whose JSON body is ``{"message": message}``.

    The body is produced with ``json.dumps`` so that quotes inside *message*
    are escaped and the payload is always valid JSON.
    """
    return Response(response=json.dumps({"message": message}),
                    status=status, mimetype='application/json')


@app.route('/predict', methods=['POST'])
def predict():
    """Predict EuroVoc labels for the text in the request's ``"data"`` field.

    Expects a JSON object body of the form ``{"data": "<text>"}``.

    Relies on module-level globals that are only created when this file is run
    as a script: ``model`` (an ``EuroVocBERT`` instance), ``config`` (the parsed
    server YAML) and ``dict_mt_labels`` (loaded from ``config["mt_labels_path"]``).

    Returns:
        On success, a dict (serialised to JSON by Flask) with keys
        ``"id_labels"``, ``"mt_labels"`` and ``"do_labels"``, each truncated to
        the corresponding ``num_*_labels`` config value.
        On any failure, a JSON ``{"message": ...}`` response with status 400.
    """
    # silent=True yields None (instead of raising) for a missing/non-JSON
    # content type or an unparsable body, so the client gets the specific
    # message below rather than the generic error.
    content = request.get_json(silent=True)
    if content is None:
        return _json_error("The POST request does not contain a JSON payload.")
    # A JSON array/number/string payload cannot carry a "data" field.
    if not isinstance(content, dict) or "data" not in content:
        return _json_error('The field "data" was not found in the sent JSON payload.')
    text = content["data"]
    if not isinstance(text, str):
        return _json_error('The field "data" must be a string.')

    try:
        # NOTE(review): id_labels is treated as a mapping keyed by id label
        # (presumably ordered by score) — confirm against pyeurovoc.
        id_labels = model(text, num_labels=config["num_id_labels"])
        # Map each id label to its MT label, skipping ids with no mapping.
        mt_labels = [dict_mt_labels[key] for key in map(str, id_labels)
                     if key in dict_mt_labels]
        # The DO label is the first two characters of the MT label.
        # NOTE(review): both lists may contain duplicates when several ids share
        # an MT/DO label; they are kept so do_labels[i] == mt_labels[i][:2].
        do_labels = [mt[:2] for mt in mt_labels]
        return {
            "id_labels": list(id_labels),
            "mt_labels": mt_labels[:config["num_mt_labels"]],
            "do_labels": do_labels[:config["num_do_labels"]]
        }
    except Exception:
        logging.exception("Prediction failed")
        # NOTE(review): an unexpected server-side failure would normally be a
        # 500; 400 is kept for compatibility with existing clients.
        return _json_error("An unexpected error occurred during prediction.")
if __name__ == "__main__":
    # Script entry point: load configuration and label mapping, build the
    # model, then serve ``app``. ``config``, ``dict_mt_labels`` and ``model``
    # are module-level globals read by predict().
    parser = argparse.ArgumentParser()
    parser.add_argument("--server_config", type=str, default="configs/server.yml",
                        help="Path to the server YAML configuration file.")
    args = parser.parse_args()

    # Install a root handler at INFO level. Only setting the root level leaves
    # no handler, and Python's last-resort handler emits WARNING and above, so
    # the INFO messages below were dropped. basicConfig() does nothing if the
    # root logger already has handlers, so later calls elsewhere are harmless.
    logging.basicConfig(level=logging.INFO)

    # The config is plain YAML data; safe_load avoids constructing arbitrary
    # Python objects from tags.
    with open(args.server_config, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    # Maps id labels (as strings) to MT labels; see predict().
    with open(config["mt_labels_path"], "r", encoding="utf-8") as file:
        dict_mt_labels = json.load(file)

    # NOTE(review): device is computed but never used — the model below is
    # constructed without it. Confirm whether the model should be placed on it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info('Setting up server...')
    model = EuroVocBERT(lang=config["model"]["language"])
    logging.info('Server initialised')
    serve(app, host=config["server"]["host"], port=config["server"]["port"])