-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
95 lines (71 loc) · 2.9 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import warnings
warnings.simplefilter("ignore")
from argparse import ArgumentParser
from enum import StrEnum
import uvicorn
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import DirectoryPath
from rich.progress import Progress
from model import Model
from schema import Input, Settings, Speaker
# The ASGI application object; the route handlers below register on it and
# the script entry point hands it to uvicorn.
app = FastAPI()

# Fully permissive CORS: any origin, method and header, credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive setup possible -- confirm this is intended for anything
# beyond local use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.post("/set_tts_settings")
def set(settings: Settings):
    # Endpoint: replace the model's active settings with the validated
    # request body. Nothing is computed, so the handler returns None.
    #
    # Documented with comments instead of a docstring so the OpenAPI
    # description FastAPI generates for this route stays unchanged.
    #
    # NOTE(review): `model` is only bound by the `__main__` section of this
    # file, so this handler works only when the file is run as a script.
    model.settings = settings
    return None
@app.get("/speakers")
def get():
    # Endpoint: list every speaker known to the model, one dict per speaker.
    # "voice_id" is the speaker name as stored on the model, "name" is the
    # same string capitalized for display, and "preview_url" is always empty.
    #
    # NOTE(review): `model` is only bound by the `__main__` section of this
    # file, so this handler works only when the file is run as a script.
    def describe(voice_id):
        return {
            "name": voice_id.capitalize(),
            "voice_id": voice_id,
            "preview_url": "",
        }

    return [describe(voice_id) for voice_id in model.speakers]
async def _relay_until_disconnect(request, synthesize, payload):
    # Shared body of the two audio endpoints: forward every chunk produced by
    # `synthesize(payload)` until the source is exhausted or the client has
    # gone away.
    #
    # `synthesize` is called here, on first iteration, rather than in the
    # route handler, so synthesis starts only once the response is streamed.
    outputs = synthesize(payload)
    try:
        async for chunk in outputs:
            # Checked after each chunk is produced and before it is sent, so
            # nothing more is pulled from the model once the client is gone.
            if await request.is_disconnected():
                break
            yield chunk
    finally:
        # Leaving an `async for` early does not close the source by itself;
        # close it explicitly so its cleanup runs now instead of whenever it
        # is finalized. Guarded because only async generators are guaranteed
        # to have `aclose()`, and the type of `outputs` is not visible here.
        aclose = getattr(outputs, "aclose", None)
        if aclose is not None:
            await aclose()


@app.get("/tts_stream")
async def stream(request: Request, input: Input = Depends()):
    # Endpoint: synthesize speech for the request given as query parameters
    # and stream whatever `model.stream` yields back as "audio/ogg".
    # NOTE(review): `input` shadows the builtin; kept because the parameter
    # name is part of the handler's signature.
    return StreamingResponse(
        _relay_until_disconnect(request, model.stream, input),
        media_type="audio/ogg",
    )


@app.post("/tts_to_audio")
async def generate(request: Request, input: Input):
    # Endpoint: same as `stream`, but the request arrives as the POST body and
    # the chunks come from `model.generate`.
    return StreamingResponse(
        _relay_until_disconnect(request, model.generate, input),
        media_type="audio/ogg",
    )
if __name__ == "__main__":
    # Script-local imports: only the command-line entry point needs these.
    from argparse import ArgumentTypeError
    from pathlib import Path

    def _existing_dir(value: str) -> Path:
        """Argparse ``type`` callback: return *value* as a Path to a directory.

        Raises:
            ArgumentTypeError: if *value* does not name an existing directory;
                argparse reports this as a usage error.
        """
        path = Path(value)
        if not path.is_dir():
            raise ArgumentTypeError(f"not an existing directory: {value!r}")
        return path

    parser = ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8020)
    # NOTE(review): these used `type=DirectoryPath`; as far as I can tell that
    # only builds a Path when called by argparse and never checks that the
    # directory exists, so the check is done explicitly here.
    parser.add_argument("-m", "--model", type=_existing_dir, required=True)
    parser.add_argument("-s", "--speakers", type=_existing_dir, required=True)
    parser.add_argument("-D", "--device", type=str, default="cuda")
    parser.add_argument("-d", "--deepspeed", action="store_true")
    parser.add_argument("-o", "--offload", action="store_true")
    parser.add_argument("-r", "--recache", action="store_true")
    args = parser.parse_args()

    # `model` must stay a module-level name: the route handlers read it.
    with Progress(transient=True) as progress:
        # total=None gives an indeterminate task while the model loads.
        progress.add_task("Loading model", total=None)
        model = Model(args.model, args.device, args.offload, args.deepspeed)

    with Progress(transient=True) as progress:
        suffixes = {".flac", ".mp3", ".ogg", ".wav"}
        # Match extensions case-insensitively so files such as "voice.WAV"
        # are picked up, and sort because glob order depends on the
        # filesystem -- this keeps the speaker order stable between runs.
        speakers = sorted(
            path
            for path in args.speakers.glob("*.*")
            if path.suffix.lower() in suffixes
        )
        caching = progress.add_task("Caching speakers", total=len(speakers))
        for speaker in speakers:
            model.add(speaker, args.recache)
            progress.advance(caching)

    # Build an enum whose member names and values are the loaded speaker
    # names, then graft its member tables onto the `Speaker` enum imported
    # from `schema`.
    # NOTE(review): presumably this makes `Speaker` accept exactly the loaded
    # voices at request time -- confirm against schema.py; it relies on
    # private Enum attributes.
    Speakers = StrEnum("Speakers", ((s, s) for s in model.speakers))
    Speaker._member_map_ = Speakers._member_map_
    Speaker._member_names_ = Speakers._member_names_
    Speaker._value2member_map_ = Speakers._value2member_map_

    uvicorn.run(app, host=args.host, port=args.port)