-
Notifications
You must be signed in to change notification settings - Fork 1
/
hubconf.py
102 lines (86 loc) · 3.56 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Packages torch.hub checks for before loading entry points from this hubconf.
dependencies = [
    "torch",
    "torchaudio",
    "numpy",
]
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import logging
import json
from pathlib import Path
from wavlm.WavLM import WavLM, WavLMConfig
from hifigan.models import Generator as HiFiGAN
from hifigan.utils import AttrDict
from matcher import KNeighborsVC
from anonymizer import Anonymizer
def salt(pretrained=True, progress=True, base=True, device="cuda") -> Anonymizer:
    """Build a SALT anonymizer on top of a kNN-VC model.

    The underlying kNN-VC model is always created with ``prematched=True``.

    Args:
        pretrained: forwarded to ``knn_vc``; load pretrained weights.
        progress: forwarded to ``knn_vc``; show download progress bars.
        base: forwarded to ``knn_vc``; select the Base-sized WavLM/HiFiGAN
            checkpoints instead of the Large ones.
        device: forwarded to ``knn_vc``.

    Returns:
        An ``Anonymizer`` wrapping the constructed ``KNeighborsVC``.
    """
    return Anonymizer(
        knn_vc(
            pretrained=pretrained,
            progress=progress,
            prematched=True,
            base=base,
            device=device,
        )
    )
def knn_vc(
    pretrained=True, progress=True, prematched=True, base=False, device="cuda"
) -> KNeighborsVC:
    """Load kNN-VC (WavLM encoder and HiFiGAN decoder).

    Args:
        pretrained: load pretrained weights for both the encoder and the vocoder.
        progress: show download progress bars.
        prematched: use the vocoder trained on `prematched` data.
        base: use the Base-sized WavLM encoder and matching vocoder instead of
            the Large ones.
        device: target device; falls back to ``"cpu"`` (with a warning) when
            CUDA is unavailable.

    Returns:
        A ``KNeighborsVC`` built from the loaded WavLM and HiFiGAN models.
    """
    # Resolve the CPU fallback once, before anything is loaded, so the vocoder,
    # the encoder and the KNeighborsVC matcher all receive the same device.
    # hifigan_wavlm is loaded first, so a fallback applied only later would
    # come too late on a machine without a GPU.
    if not torch.cuda.is_available() and str(device) != "cpu":
        logging.warning(
            f"Overriding device {device} to cpu since no GPU is available."
        )
        device = "cpu"
    hifigan, hifigan_cfg = hifigan_wavlm(pretrained, progress, base, prematched, device)
    wavlm = wavlm_large(pretrained, progress, base, device)
    return KNeighborsVC(wavlm, hifigan, hifigan_cfg, base, device)
def hifigan_wavlm(
    pretrained=True, progress=True, base=False, prematched=True, device="cuda"
) -> "tuple[HiFiGAN, AttrDict]":
    """Load a HiFiGAN vocoder trained to vocode WavLM features.

    Args:
        pretrained: download and load pretrained generator weights.
        progress: show a download progress bar.
        base: configure the generator with ``hubert_dim = 768`` and, when
            ``prematched``, load the Base checkpoint.
        prematched: use weights trained on `prematched` data.
        device: target device; falls back to ``"cpu"`` (with a warning) when
            CUDA is unavailable.

    Returns:
        ``(generator, h)``: the generator in eval mode with weight norm
        removed, and the ``AttrDict`` config it was built from.
    """
    cp = Path(__file__).parent.absolute()
    # Explicit encoding so the config is read the same regardless of locale.
    with open(cp / "hifigan" / "config_v1_wavlm.json", encoding="utf-8") as f:
        h = AttrDict(json.load(f))
    if base:
        # The Base checkpoints use a 768-dim input feature size.
        h.hubert_dim = 768

    # Moving the generator / mapping the checkpoint to "cuda" fails on
    # machines without a GPU, so fall back to CPU like wavlm_large does.
    if not torch.cuda.is_available() and str(device) != "cpu":
        logging.warning(
            f"Overriding device {device} to cpu since no GPU is available."
        )
        device = "cpu"
    device = torch.device(device)
    generator = HiFiGAN(h).to(device)

    if pretrained:
        if prematched:
            if base:
                url = "https://github.com/BakerBunker/SALT/releases/download/1.0.0/base_g_02500000.pt"
            else:
                url = "https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt"
        else:
            # NOTE(review): there is no Base-specific non-prematched checkpoint;
            # with base=True this loads the default checkpoint into a generator
            # configured with hubert_dim=768 — confirm the weight shapes match.
            url = "https://github.com/bshall/knn-vc/releases/download/v0.1/g_02500000.pt"
        state_dict_g = torch.hub.load_state_dict_from_url(
            url, map_location=device, progress=progress
        )
        generator.load_state_dict(state_dict_g["generator"])
    generator.eval()
    generator.remove_weight_norm()
    print(
        f"[HiFiGAN] Generator loaded with {sum(p.numel() for p in generator.parameters()):,d} parameters."
    )
    return generator, h
def wavlm_large(pretrained=True, progress=True, base=False, device="cuda") -> WavLM:
    """Load a WavLM checkpoint (Large by default, Base when ``base=True``).

    See https://github.com/microsoft/unilm/tree/master/wavlm for details.

    Args:
        pretrained: load the checkpoint's model weights. The checkpoint is
            downloaded either way, because the model config (``cfg``) is read
            from it.
        progress: show a download progress bar.
        base: load WavLM-Base instead of WavLM-Large.
        device: target device; falls back to ``"cpu"`` (with a warning) when
            CUDA is unavailable.

    Returns:
        The ``WavLM`` model moved to ``device`` and put in eval mode.
    """
    if not torch.cuda.is_available() and str(device) != "cpu":
        logging.warning(
            f"Overriding device {device} to cpu since no GPU is available."
        )
        device = "cpu"
    url = (
        "https://github.com/BakerBunker/SALT/releases/download/1.0.0/WavLM-Base.pt"
        if base
        else "https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt"
    )
    # Always fetched: the config needed to build the model lives in the checkpoint.
    checkpoint = torch.hub.load_state_dict_from_url(
        url, map_location=device, progress=progress
    )
    cfg = WavLMConfig(checkpoint["cfg"])
    device = torch.device(device)
    model = WavLM(cfg)
    if pretrained:
        model.load_state_dict(checkpoint["model"])
    model = model.to(device)
    model.eval()
    print(
        f"WavLM-{'Base' if base else 'Large'} loaded with {sum(p.numel() for p in model.parameters()):,d} parameters."
    )
    return model