Skip to content

Commit

Permalink
Merge pull request #3 from rogermiranda1000/feature/different-voices
Browse files Browse the repository at this point in the history
Feature/different voices
  • Loading branch information
miranda1000 authored Oct 2, 2023
2 parents 08ffeac + 462c9e0 commit c5ea4d1
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 46 deletions.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,23 @@ You'll have to edit `config.json`:
- Set the bot name in `nick` and `owner`
- Set the desired secret key in a new entry: `"secret": "admin"`. This will be needed for the audio player, as you'll have to enter to `localhost:7890?token=<secret>`
- Set the desired redeem name in a new entry: `"redeem": "Custom TTS"`
- Set the `RVC-based TTS` model name in a new entry: `"model": "<model name>"`
- Set the app id in `client_id`
- Set the OAuth token in `"oauth": "oauth:<token>"`
- Set the PubSub token in a new entry: `"pubsub": "<token>"`
- Set the `RVC-based TTS` model name in a new entry:
```
"voices": {
"<model name>": {
"model-name": "<model name>",
"model-voice": "en-US-AriaNeural-Female"
}
}
```

Optional additional properties:

- You can add a character limit by setting `"input_limit": 450`
- You can add a segments limit by setting `"segments_limit": 12`. Note: one segment is one change of model, or a sound being played.
- To add sounds to be replaced, create an `audios` folder and place the .wav files there; then create a new `audios` entry listing each audio and:
```
"audios": {
Expand Down
98 changes: 80 additions & 18 deletions TwitchTTSBot/bot_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
import os, json, uuid, random, re
from typing import List
from bot import TwitchTTSBot
import shutil
from synthesizers.synthesizer import TTSSynthesizer
from synthesizers.rvc_synthesizer import RVCTTSSynthesizer
from synthesizers.rvc_synthesizer import RVCTTSSynthesizer,RVCModel
from functools import cache

import pypeln as pl
import asyncio
from tts_queue import TTSQueueEntry, TTSSegment, GeneratedTTSSegment
from tts_queue import TTSQueueEntry, TTSSegment, GeneratedTTSSegment, PregeneratedTTSSegment

@cache
def _get_config_json() -> json:
Expand All @@ -30,6 +29,9 @@ def _get_model_name() -> str:
def _get_character_limit() -> int:
    """Maximum characters allowed in a TTS request, or None if no limit is configured."""
    # dict.get() returns None when the key is absent — one lookup instead of two
    return _get_config_json().get('input_limit')

def _get_segments_limit() -> int:
    """Maximum number of segments per request, or None if no limit is configured.

    One segment is one change of voice/model, or one sound being played.
    """
    # dict.get() returns None when the key is absent — one lookup instead of two
    return _get_config_json().get('segments_limit')

def _get_token() -> str:
    """Return the secret token that protects the audio-server web endpoint."""
    config = _get_config_json()
    return config['secret']

Expand Down Expand Up @@ -59,33 +61,73 @@ def _generate_splits(text: str, find: List[str]) -> List[str]:

return r

def _splits_to_segments(found: List[str], audios: dict, segment: GeneratedTTSSegment) -> List[TTSSegment]:
    """Convert the split chunks of one text segment into TTS segments.

    Chunks that match an `audios` keyword become PregeneratedTTSSegment (a
    pre-recorded .wav copied at generation time); everything else becomes a
    GeneratedTTSSegment that keeps the original segment's synthesizer.

    NOTE(review): this span of the scraped diff interleaved old and new lines
    (dead `target_path` / `shutil.copyfile` code from the pre-refactor version);
    this is the reconstructed new-side implementation.
    """
    into = []
    synthesizer = segment.synthesizer

    for f in found:
        if f not in audios:
            # plain text: synthesize with the same voice as the parent segment
            e = GeneratedTTSSegment(text=f, synthesizer=synthesizer)
        else:
            # audio keyword: resolve aliases, then pick one of the entry's files
            audio_entry = audios[f]
            while 'alias' in audio_entry:
                audio_entry = audios[ audio_entry['alias'] ]
            audio_file = random.choice( audio_entry['files'] )
            audio_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'audios/' + audio_file)

            e = PregeneratedTTSSegment(copy_from=audio_path)

        into.append(e)
    return into

def _get_tts_models() -> List[RVCTTSSynthesizer]:
    """Instantiate one RVC synthesizer per voice declared in the config's `voices` map.

    Raises ValueError if a configured model has no weights folder on disk.
    """
    # discover which model weight folders actually exist
    models_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../rvc-tts-webui/weights')
    models = [ entry.name for entry in os.scandir(models_folder) if entry.is_dir() ]
    print(f"[v] Found models in folder: {models}")

    synthesizers = []
    for voice, data in _get_config_json()['voices'].items():
        model_name = data['model-name']
        if model_name not in models:
            # the config references a model whose weights are not on disk
            raise ValueError(f"Model {model_name} not found in models folder")

        synthesizers.append(RVCTTSSynthesizer(RVCModel(voice, model_name, data['model-voice'])))

    return synthesizers

def _generate_voice_splits(segment: TTSSegment, models: List[RVCTTSSynthesizer]) -> List[TTSSegment]:
    """Split one text segment into per-voice segments.

    For each configured model, the text is scanned for its `<alias>: ` marker.
    Text to the left of a marker keeps the synthesizer it already had; text to
    the right is assigned the matched model's synthesizer.

    NOTE(review): the scraped diff had a stray deleted line (`def instantiate...`)
    interleaved inside this body, which made it syntactically invalid; removed here.
    """
    if not isinstance(segment, GeneratedTTSSegment):
        raise ValueError("This function was meant to be used with a single text segment")

    r = [ segment ]

    for model in models:
        prev = r
        r = []

        # find the alias (sanitized), preceded by a space or start-of-string, followed by ': '
        regex_pattern = re.compile(r'(?: |^)' + re.escape(model.model.alias) + r': ')
        for seg in prev:  # renamed from `segment` to avoid shadowing the parameter
            split = re.split(regex_pattern, seg.text)

            head = split[0].strip()
            if len(head) > 0:
                # left of the first match: keep the segment's previous synthesizer
                r.append(GeneratedTTSSegment(text=head, synthesizer=seg.synthesizer))
            for chunk in split[1:]:  # skip the first (already handled above)
                # from the 2nd chunk onwards, a match was found right before it
                chunk = chunk.strip()
                if len(chunk) > 0:  # a match at the end of the string produces an empty chunk
                    r.append(GeneratedTTSSegment(text=chunk, synthesizer=model))

    return r

def instantiate(default_synthesizer: TTSSynthesizer = None) -> TwitchTTSBot:
queue_pre_inference = []
queue_post_inference = []

Expand All @@ -97,19 +139,39 @@ async def truncate_input(e: TTSQueueEntry):
return e

queue_pre_inference.append(pl.task.map(truncate_input))

models = _get_tts_models()

# change voices
async def voice_delimiter(e: TTSQueueEntry):
e.segments = _generate_voice_splits(e.segments[0], models)
return e
queue_pre_inference.append(pl.task.map(voice_delimiter))

# replace audios
audios = _get_audios()
if len(audios) > 0:
async def segment_input(e: TTSQueueEntry):
if len(e.segments) != 1 or not isinstance(e.segments[0], GeneratedTTSSegment):
raise ValueError("This code wasn't mean to be used with different than 1 text segment")
segments = e.segments[:]
e.segments.clear() # remove all previous segments

splits = _generate_splits(e.segments[0].text, list(audios.keys()))
_splits_to_segments(splits, audios, e.segments) # replace the segment for the new ones
for segment in segments:
if not isinstance(segment, GeneratedTTSSegment):
raise ValueError("This code wasn't mean to be used with different than text segments")

splits = _generate_splits(segment.text, list(audios.keys()))
e.segments += _splits_to_segments(splits, audios, segment) # replace the segment for the new ones
return e

queue_pre_inference.append(pl.task.map(segment_input))

segments_limit = _get_segments_limit()
if segments_limit is not None:
async def truncate_input(e: TTSQueueEntry):
e.segments = e.segments[:segments_limit]
return e

queue_pre_inference.append(pl.task.map(truncate_input))

# return the instance
return TwitchTTSBot.instance(WebServer(secret_token=_get_token()), synthesizer if synthesizer is not None else RVCTTSSynthesizer(model=_get_model_name()), queue_pre_inference=queue_pre_inference, queue_post_inference=queue_post_inference)
return TwitchTTSBot.instance(WebServer(secret_token=_get_token()), default_synthesizer if default_synthesizer is not None else models[0], queue_pre_inference=queue_pre_inference, queue_post_inference=queue_post_inference)
30 changes: 25 additions & 5 deletions TwitchTTSBot/synthesizers/rvc_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,37 @@
import asyncio
from .synthesizer import TTSSynthesizer

class RVCModel:
    """Immutable description of one configured voice.

    Ties together the alias chat users type (`<alias>: text`), the RVC model
    (weights folder) name, and the base TTS voice fed into RVC.
    """

    def __init__(self, alias: str, model_name: str, model_voice: str):
        self._alias = alias
        self._model_name = model_name
        self._model_voice = model_voice

    def __repr__(self) -> str:
        # debugging aid; backward-compatible addition
        return (f"{type(self).__name__}(alias={self._alias!r}, "
                f"model_name={self._model_name!r}, model_voice={self._model_voice!r})")

    @property
    def alias(self) -> str:
        """Keyword used in chat to switch to this voice."""
        return self._alias

    @property
    def model_name(self) -> str:
        """Name of the RVC model (weights folder) to run inference with."""
        return self._model_name

    @property
    def model_voice(self) -> str:
        """Base TTS voice (e.g. 'en-US-AriaNeural-Female') converted by RVC."""
        return self._model_voice

# It will copy an audio file each time `synthesize` is called
class RVCTTSSynthesizer(TTSSynthesizer):
    """Synthesizer that shells out to rvc-tts-webui's infere.py for one RVCModel.

    NOTE(review): the scraped diff duplicated `await proc.communicate()` and
    interleaved the pre-refactor `__init__`/subprocess lines; reconstructed here.
    """

    def __init__(self, model: RVCModel):
        self._model = model

    @property
    def model(self) -> RVCModel:
        """The voice configuration this synthesizer runs inference with."""
        return self._model

    async def synthesize(self, text: str, out: str):
        """Generate speech for `text` into the `out` wav path via a subprocess."""
        # calling `infere` directly won't work for permissions issue (needs to be sudo)
        python_full_path = sys.executable
        infere_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../rvc-tts-webui/infere.py')
        proc = await asyncio.create_subprocess_exec('sudo',python_full_path,infere_path, '--model', self._model.model_name, '--text', text, '--out', out, '--voice', self._model.model_voice,
                                    stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
        stdout, stderr = await proc.communicate()
16 changes: 16 additions & 0 deletions TwitchTTSBot/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,21 @@ def test_skip(self):
# don't stop until done
self.sleep(15) # TODO get when bot is done

def test_audios(self):
    """Redeem whose whole text is an audio keyword ('[pop]'): the bot should
    replace it with the pre-recorded sound instead of synthesizing it."""
    print("[v] Launching custom event (audio)")
    data = PubSubData(BotTests._GetRedeem("[pop]"))
    forward_event(Event.on_pubsub_custom_channel_point_reward, data, PubSubPointRedemption(data))

    # don't stop until done
    self.sleep(15) # TODO get when bot is done

def test_multiple_voices(self):
    """Redeem containing '<alias>: ' voice markers ('glados:', 'nedia:'):
    each marked chunk should be synthesized with the corresponding voice."""
    print("[v] Launching custom event (multiple voices)")
    data = PubSubData(BotTests._GetRedeem("Does this work? glados: Yes, it seems to work just fine. nedia: Cool!"))
    forward_event(Event.on_pubsub_custom_channel_point_reward, data, PubSubPointRedemption(data))

    # don't stop until done
    self.sleep(20) # TODO get when bot is done

if __name__ == '__main__':
unittest.main()
59 changes: 38 additions & 21 deletions TwitchTTSBot/tts_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from synthesizers.synthesizer import TTSSynthesizer
from pydub import AudioSegment
import shutil

import os
from pathlib import Path
Expand All @@ -20,9 +21,9 @@
import multiprocessing as mp

class TTSQueue:
def __init__(self, serve_to: AudioServer, synthesizer: TTSSynthesizer, pre_inference: List[Partial[pl.task.Stage[T]]] = None, post_inference: List[Partial[pl.task.Stage[T]]] = None):
def __init__(self, serve_to: AudioServer, default_synthesizer: TTSSynthesizer, pre_inference: List[Partial[pl.task.Stage[T]]] = None, post_inference: List[Partial[pl.task.Stage[T]]] = None):
self._serve_to = serve_to
self._synthesizer = synthesizer
self._default_synthesizer = default_synthesizer

self._audios_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../audio-server/audios/')

Expand Down Expand Up @@ -75,7 +76,7 @@ async def start_queue():
# TODO stop thread when closing

async def enqueue(self, requested_by: str, text: str):
e = TTSQueueEntry(requested_by, GeneratedTTSSegment(text))
e = TTSQueueEntry(requested_by, GeneratedTTSSegment(text, self._default_synthesizer))
self._processing_input.put(e)

# notify the tts_queue loop
Expand All @@ -87,20 +88,11 @@ async def __infere(self, e: TTSQueueEntry) -> TTSQueueEntry:
Infere TTS
"""
for index, segment in enumerate(e.segments):
if not segment.processing:
continue # already processed

if not isinstance(segment, GeneratedTTSSegment):
raise ValueError("Unrecognised class " + segment.__class__.__name__ + " and not ready to process.")

target_file = str(uuid.uuid4().hex) + '.wav'
target_path = os.path.join(self._audios_path, target_file)
print(f"[v] Synthesizing segment {index+1}/{len(e.segments)} into {target_file}...")

await self._synthesizer.synthesize(segment.text, target_path)
segment.path = target_path

print(f"[v] '{segment.text}' synthetized.")
await segment.generate(target_path)

return e

Expand Down Expand Up @@ -170,7 +162,7 @@ async def __ended_web_streaming(self):
class TTSQueueEntry:
def __init__(self, requested_by: str, *segments: Tuple[TTSSegment, ...], path: str = None):
self._requested_by = requested_by
self._segments = list(segments)
self.segments = list(segments)
self.path = path
self.invalidated = False

Expand All @@ -185,15 +177,19 @@ def path(self) -> str:
@path.setter
def path(self, path: str):
self._path = path

@property
def invalidated(self) -> bool:
return self._invalidated

@property
def segments(self) -> List[TTSSegment]:
return self._segments

@segments.setter
def segments(self, segments: List[TTSSegment]):
self._segments = segments

@property
def invalidated(self) -> bool:
return self._invalidated

@invalidated.setter
def invalidated(self, invalidated: bool):
self._invalidated = invalidated
Expand All @@ -207,7 +203,10 @@ def processing(self) -> bool:
return self._path is None

class TTSSegment:
def __init__(self, path: str = None):
def __init__(self):
self.path = None

async def generate(self, path: str):
self.path = path

@property
Expand All @@ -226,10 +225,24 @@ def file_name(self) -> str:
def processing(self) -> bool:
return self._path is None

class PregeneratedTTSSegment(TTSSegment):
    # Segment whose audio already exists on disk (e.g. a configured sound
    # effect): "generating" it just copies the source .wav into place.
    def __init__(self, copy_from: str):
        super().__init__()
        self._copy_from = copy_from  # path of the pre-recorded .wav to copy

    async def generate(self, path: str):
        """Copy the pre-recorded audio into `path` (recorded as self.path by the base class)."""
        await super().generate(path)
        shutil.copyfile(self._copy_from, path)

class GeneratedTTSSegment(TTSSegment):
def __init__(self, text: str, path: str = None):
super().__init__(path)
def __init__(self, text: str, synthesizer: TTSSynthesizer):
super().__init__()
self._text = text
self._synthesizer = synthesizer

async def generate(self, path: str):
await super().generate(path)
await self._synthesizer.synthesize(self.text, self.path)

@property
def text(self) -> str:
Expand All @@ -238,3 +251,7 @@ def text(self) -> str:
@text.setter
def text(self, text: str):
self._text = text

@property
def synthesizer(self) -> TTSSynthesizer:
return self._synthesizer
3 changes: 2 additions & 1 deletion rvc-tts-webui/infere.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ def main():
argParser.add_argument("-m", "--model", required=True)
argParser.add_argument("-t", "--text", required=True)
argParser.add_argument("-o", "--out", default='out.wav')
argParser.add_argument("-v", "--voice", default='en-US-AriaNeural-Female')

args = argParser.parse_args()

infere(args.model, args.text, args.out, speaker='en-US-MichelleNeural-Female')
infere(args.model, args.text, args.out, speaker=args.voice)

if __name__ == '__main__':
main()

0 comments on commit c5ea4d1

Please sign in to comment.