-
Notifications
You must be signed in to change notification settings - Fork 0
/
whisper_api.py
204 lines (160 loc) · 5.51 KB
/
whisper_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import logging
from typing import Optional
from uuid import uuid4
from fastapi import BackgroundTasks, FastAPI, HTTPException, status, Response
from asr import run
from whisper import load_model
from enum import Enum
from pydantic import BaseModel
from config import (
model_base_dir,
w_device,
w_model,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# The FastAPI application; every route decorator in this file registers on it.
api = FastAPI()
# NOTE(review): when this file is run as a script, this line executes before
# logging.basicConfig()/setLevel() in the __main__ section, so the message is
# presumably not emitted — confirm.
logger.info(f"Loading model on device {w_device}")
# Load the model once, at import time; try_whisper passes this same instance
# to run() for every task.
model = load_model(model_base_dir, w_model, w_device)
class Status(Enum):
    """Lifecycle states of a Task."""

    # Default for a new Task; also (re)assigned by create_task.
    CREATED = "CREATED"
    # Set by try_whisper right before it calls run().
    PROCESSING = "PROCESSING"
    # Set by try_whisper when run() returned a falsy error message.
    DONE = "DONE"
    # Set by try_whisper when run() returned an error message or raised.
    ERROR = "ERROR"
# HTTP status code that GET /tasks/{task_id} answers with for each task status.
StatusToHTTP = {
    Status.CREATED: status.HTTP_201_CREATED,
    Status.PROCESSING: status.HTTP_202_ACCEPTED,
    Status.DONE: status.HTTP_200_OK,
    Status.ERROR: status.HTTP_500_INTERNAL_SERVER_ERROR,
}
class Task(BaseModel):
    """One unit of work for the worker; also the request body of POST /tasks."""

    # Passed to run() as its first argument.
    input_uri: str
    # Passed to run() as its second argument.
    output_uri: str
    # Current lifecycle state; updated by create_task and try_whisper.
    status: Status = Status.CREATED
    # Overwritten by create_task with a fresh uuid4 string.
    id: str | None = None
    # Filled by try_whisper with whatever run() returned.
    error_msg: str | None = None
# In-memory list of all known tasks, each stored as a plain dict.
# Entries added by create_task/update_task come from Task.dict().
# It starts out with one hard-coded example entry; note that this entry has
# no "status" key, unlike the dicts produced from Task.
all_tasks = [
    {
        "input_uri": "http://modelhosting.beng.nl/whisper-asr.mp3",
        "output_uri": "http://modelhosting.beng.nl/assets/whisper-asr",
        "id": "test1",
    }
]
# The task most recently accepted by create_task; None until the first one.
current_task: Optional[Task] = None
def get_task_by_id(task_id: str) -> Optional[dict]:
    """Return the first entry of all_tasks whose "id" equals task_id.

    Entries without an "id" key are treated as having the id "".
    Returns None when no entry matches.
    """
    for candidate in all_tasks:
        if candidate.get("id", "") == task_id:
            return candidate
    return None
def get_task_index(task_id: str) -> int:
    """Return the position in all_tasks of the first entry with this id.

    Entries without an "id" key are treated as having the id "".
    Returns -1 when no entry matches.
    """
    matching_positions = (
        position
        for position, entry in enumerate(all_tasks)
        if entry.get("id", "") == task_id
    )
    return next(matching_positions, -1)
def delete_task(task_id) -> bool:
    """Remove the task with the given id from all_tasks.

    Returns True when an entry was removed, False when the id is unknown.
    """
    position = get_task_index(task_id)
    found = position != -1
    if found:
        all_tasks.pop(position)
    return found
def update_task(task: Task) -> bool:
    """Replace the stored dict of a task with a fresh snapshot of `task`.

    Returns False (after logging a warning) when the task is falsy or has no
    id, and False when no stored entry carries that id. Returns True once the
    entry in all_tasks has been overwritten with task.dict().
    """
    if not (task and task.id):
        logger.warning("Tried to update task without ID")
        return False

    position = get_task_index(task.id)
    known = position != -1
    if known:
        all_tasks[position] = task.dict()
    return known
def try_whisper(task: Task):
    """Run Whisper for a task and record the outcome on the task.

    The task is marked PROCESSING, then run() is called with the task's
    input and output URIs and the module-level model. A truthy return value
    of run() is treated as an error message (status ERROR); a falsy one means
    success (status DONE). An exception raised along the way is logged and
    also results in status ERROR. The stored snapshot in all_tasks is
    refreshed before and after the run.

    Args:
        task: the task to process; its status and error_msg are mutated.
    """
    logger.info(f"Trying to call Whisper for task {task.id}")
    try:
        task.status = Status.PROCESSING
        update_task(task)
        error_msg = run(task.input_uri, task.output_uri, model)
        task.status = Status.ERROR if error_msg else Status.DONE
        task.error_msg = error_msg
    except Exception as exc:
        logger.exception("Failed to run whisper")
        task.status = Status.ERROR
        # Keep the reason with the task, so clients polling the task see why
        # it failed instead of an ERROR status with an empty error_msg.
        task.error_msg = str(exc) or type(exc).__name__
    # update_task returns False when the entry is gone (e.g. deleted while the
    # task was running); make that visible instead of dropping the result.
    if not update_task(task):
        logger.warning(f"Could not store the result of task {task.id}")
    logger.info(f"Done running Whisper for task {task.id}")
@api.get("/tasks")
def get_all_tasks():
    # Respond with every stored task dict, wrapped under the "data" key.
    payload = {"data": all_tasks}
    return payload
@api.get("/status")
def get_status(response: Response):
    """Report whether the worker can accept a new task.

    Responds with 503 while the most recently accepted task is waiting to
    start or being processed, and with 200 otherwise.
    """
    # current_task is only read here, so no `global` statement is needed.
    # CREATED counts as busy as well: it means a task was accepted and its
    # background job has not switched it to PROCESSING yet.
    busy = current_task is not None and current_task.status in (
        Status.CREATED,
        Status.PROCESSING,
    )
    if busy:
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
        return {"msg": "The worker is currently processing a task. Try again later!"}
    response.status_code = status.HTTP_200_OK
    return {"msg": "The worker is available!"}
@api.post("/tasks", status_code=status.HTTP_201_CREATED)
async def create_task(
    task: Task, background_tasks: BackgroundTasks, response: Response
):
    """Accept a new task and schedule it to be processed in the background.

    Responds with 503 while the previously accepted task is still waiting to
    start or being processed. Otherwise the task gets a fresh id and status
    CREATED (any id or status sent by the client is overwritten), becomes the
    current task, is stored in all_tasks and is handed to try_whisper as a
    background job; the response is 201 with the stored task.
    """
    global current_task
    # CREATED counts as busy as well: the background job of a just-accepted
    # task may not have started yet, and accepting another task in that window
    # would overwrite current_task and run two jobs side by side.
    if current_task and current_task.status in (Status.CREATED, Status.PROCESSING):
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
        return {"msg": "The worker is currently processing a task. Try again later!"}

    task.id = str(uuid4())
    task.status = Status.CREATED
    current_task = task
    task_dict = task.dict()
    all_tasks.append(task_dict)
    # Schedule the job only once the task has its id and is registered.
    background_tasks.add_task(try_whisper, task)
    return {"data": task_dict, "msg": "Successfully added task", "task_id": task.id}
@api.get("/tasks/{task_id}")
async def get_task(task_id: str, response: Response):
    """Return a stored task; the HTTP status code reflects the task status.

    The response code is looked up in StatusToHTTP.

    Raises:
        HTTPException: 404 when no task with this id is stored.
    """
    task = get_task_by_id(task_id)
    if not task:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Task {task_id} not found"
        )
    # Not every stored dict has a "status" key (the hard-coded example entry
    # does not), so fall back to the Task default instead of failing with a
    # KeyError. Status(...) accepts both a Status member and its string value.
    task_status = Status(task.get("status", Status.CREATED))
    response.status_code = StatusToHTTP[task_status]
    return {"data": task}
@api.delete("/tasks/{task_id}")
async def remove_task(task_id: str):
    """Delete a stored task.

    Raises:
        HTTPException: 404 when no task with this id is stored.
    """
    if not delete_task(task_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Task {task_id} not found"
        )
    # Only the success case reaches this point; the failure case raised above.
    return {
        "msg": f"Successfully deleted task {task_id}",
        "task_id": task_id,
    }
@api.get("/ping")
async def ping():
    # Minimal endpoint that always answers with the fixed string "pong".
    return "pong"
# Script entry point: parse the CLI arguments, set up logging and serve the
# API with uvicorn.
if __name__ == "__main__":
    import sys
    import uvicorn
    from argparse import ArgumentParser
    from base_util import LOG_FORMAT

    # first read the CLI arguments
    parser = ArgumentParser(description="whisper-api")
    parser.add_argument("--port", action="store", dest="port", default="5333")
    parser.add_argument("--log", action="store", dest="loglevel", default="INFO")
    args = parser.parse_args()

    # initialises the root logger
    logging.basicConfig(
        stream=sys.stdout,  # configure a stream handler only for now (single handler)
        format=LOG_FORMAT,
    )

    # setting the loglevel; an unknown level name falls back to INFO instead
    # of crashing, in line with how an invalid --port is handled below
    log_level = args.loglevel.upper()
    try:
        logger.setLevel(log_level)
    except ValueError:
        rejected_level = log_level
        log_level = "INFO"
        logger.setLevel(log_level)
        logger.error(
            f"--log must be a valid log level (got {rejected_level}), "
            f"using default level {log_level}"
        )
    logger.info(f"Logger initialized (log level: {log_level})")
    logger.info(f"Got the following CMD line arguments: {args}")

    # an invalid --port keeps the default port
    port = 5333
    try:
        port = int(args.port)
    except ValueError:
        logger.error(
            f"--port must be a valid integer, starting with default port {port}"
        )

    uvicorn.run(api, port=port, host="0.0.0.0")