Last active
June 23, 2023 19:47
-
-
Save Iamgoofball/18b7b01501d1a5b0d1276b2a7eb8675d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from TTS.api import TTS | |
| import os | |
| import signal | |
| import sys | |
| import io | |
| import time | |
| import json | |
| import gc | |
| import random | |
| import numpy as np | |
| import ffmpeg | |
| from typing import * | |
| from modules import models | |
| from modules.utils import load_audio | |
| from flask import Flask, request, send_file, abort, make_response | |
| from numpy import interp | |
| from pydub import AudioSegment | |
| from pydub.silence import split_on_silence, detect_leading_silence | |
| from fairseq import checkpoint_utils | |
| from fairseq.models.hubert.hubert import HubertModel | |
| from modules.shared import ROOT_DIR, device, is_half | |
| import requests | |
| import librosa | |
| import threading | |
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'garbage_collection_threshold:0.4,max_split_size_mb:128'
# NOTE(review): torch is already imported above, so setting this env var here
# may be too late to influence the CUDA allocator — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ''
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:cudaMallocAsync'
#os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
# Lazily-created TTS engine; built by maketts() on first use under ttslock.
tts = None
# RVC checkpoint filename -> JSON file mapping speaker names to speaker ids.
vc_models = {
    "TGStation_Crepe_1.pth": "./models/checkpoints/speakers_tgstation_1.json",
    "TGStation_Crepe_2.pth": "./models/checkpoints/speakers_tgstation_2.json",
    "TGStation_Crepe_3.pth": "./models/checkpoints/speakers_tgstation_3.json",
}
# Helpers that trim silence from the ends of a pydub AudioSegment.
trim_leading_silence = lambda x: x[detect_leading_silence(x) :]
trim_trailing_silence = lambda x: trim_leading_silence(x.reverse()).reverse()
strip_silence = lambda x: trim_trailing_silence(trim_leading_silence(x))
# Serializes access to the shared TTS engine (and GPU) across request threads.
ttslock = threading.Lock()
# Characters that get per-voice blip samples generated for /generate-tts-blips.
letters_to_use = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
# Maximum random extra pitch (in octaves) applied to each blip.
random_factor = 0.35
os.makedirs('samples', exist_ok=True)
app = Flask(__name__)
# Rolling service statistics, consumed by the /health-check watchdog.
request_count = 0
tts_errors = 0
rvc_errors = 0
last_request_time = time.time()
avg_request_time = 0
avg_tts_time = 0
avg_rvc_time = 0
# Set to a timestamp once the worker has been marked for recycling.
timeofdeath = 0
# External (display) voice names -> internal model speaker names.
voice_name_mapping = {}
use_voice_name_mapping = True
with open("./tts_voices_mapping.json", "r") as file:
    voice_name_mapping = json.load(file)
if len(voice_name_mapping) == 0:
    use_voice_name_mapping = False
voice_name_mapping_reversed = {v: k for k, v in voice_name_mapping.items()}
def load_embedder():
    """Load the HuBERT embedder checkpoint onto `device` and cache it globally.

    Sets the module globals `embedder_model` and `loaded_embedder_model`,
    converts the model to half precision when `is_half` is set, puts it in
    eval mode, and returns it.
    """
    global embedder_model, loaded_embedder_model
    emb_file = "./models/embeddings/hubert_base.pt"
    # Fix: the local result was named `models`, shadowing the
    # `from modules import models` import used elsewhere in this file.
    ensemble, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        [emb_file],
        suffix="",
    )
    embedder_model = ensemble[0]
    embedder_model = embedder_model.to(device)
    if is_half:
        embedder_model = embedder_model.half()
    else:
        embedder_model = embedder_model.float()
    embedder_model.eval()
    loaded_embedder_model = "hubert_base"
    return embedder_model
# Load the HuBERT embedder, every RVC checkpoint, and each checkpoint's
# speaker-name -> speaker-id lookup table.
loaded_models = []
embedder_model: Optional[HubertModel] = load_embedder()
voice_lookup = {}
for model in vc_models.keys():
    print(model)
    # Fix: json.load(open(...)) leaked the file handle; use a context manager.
    with open(vc_models[model], "r") as speakers_file:
        voice_lookup[model] = json.load(speakers_file)
    vc_model = models.get_vc_model(model)
    loaded_models.append(vc_model)
    print("Loaded model " + str(model))
# Which HuBERT layer to take embeddings from for voice conversion.
embedding_output_layer = 12
def tts_log(text):
    """Append *text* to the debug run log — currently a no-op.

    The early return disables logging; the write below is intentionally dead
    code, kept so logging can be re-enabled quickly while debugging.
    """
    return
    with open('ttsd-gpu-runlog.txt', 'a') as log_file:
        log_file.write(f"{text}\n")
def maketts(model_path="E:/model_output_3/model_no_disc.pth", config_path="E:/model_output_2/config.json"):
    """Construct a GPU-backed TTS engine.

    Generalized: the checkpoint and config locations were hard-coded; they are
    now keyword parameters whose defaults preserve the original behavior for
    existing callers.
    """
    tts_log("maketts()")
    return TTS(model_path=model_path, config_path=config_path, progress_bar=False, gpu=True)
    #return TTS("tts_models/en/vctk/vits", progress_bar=False, gpu=True)
def mc_avg(old, new):
    """Fold *new* into a slow exponential moving average (1% weight).

    A non-positive *old* is treated as "uninitialized" and replaced by *new*.
    """
    return new if old <= 0 else (old * 0.99) + (new * 0.01)
def two_way_round(number, ndigits=0):
    """Round *number* to *ndigits* places and render it with exactly that many decimals."""
    rounded = round(number, ndigits)
    fmt = "%0.{}f".format(ndigits)
    return fmt % rounded
def test_tts():
    """Run a short smoke-test synthesis to confirm the TTS engine is usable.

    Lazily builds the engine if needed; the audio is written to an in-memory
    buffer and discarded.
    """
    global tts
    tts_log("test_tts()")
    with ttslock:
        tts_log("test_tts():lock")
        if not tts:
            tts = maketts()
        with io.BytesIO() as data_bytes:
            with torch.inference_mode():
                tts_log("do test_tts()")
                # NOTE(review): synthesis happens while still holding ttslock,
                # serializing the smoke test with regular requests — presumably
                # intentional, confirm.
                tts.tts_to_file(text="The quick brown fox jumps over the lazy dog.", speaker="maleadult01default", file_path=data_bytes)
def readyup():
    """Warm the worker up at startup: smoke-test the TTS engine and kick the GC loop."""
    print("readyup()\n")
    tts_log("readyup()")
    tts_log("readyup(): watchdog")
    test_tts()
    gc_loop()
    tts_log("readyup(): getlock")
    # Acquiring the lock waits for any in-flight warm-up work before the
    # "ready" log lines are emitted.
    with ttslock:
        tts_log("readyup(): gotlock, sending ready")
        tts_log("readyup(): ready sent")
    # NOTE(review): gc_loop() is invoked a second time here — looks redundant
    # with the call above; confirm.
    gc_loop()
def gc_loop():
    """Self-rearming 60-second maintenance hook.

    Fix: the Timer was created but never started, so the loop never actually
    re-armed. The timer is marked daemon so it cannot block interpreter
    shutdown. The GC/CUDA-cache calls remain disabled, as in the original.
    """
    timer = threading.Timer(60, gc_loop)
    timer.daemon = True
    timer.start()
    #gc.collect()
    #torch.cuda.empty_cache()
@app.route("/generate-tts")
def text_to_speech():
    """Synthesize `text` in `voice`, run it through the matching RVC
    voice-conversion model, and return the result as a WAV response.

    Expects a JSON body with "text", "voice" and optional "pitch" keys.
    Aborts with 500 when the requested voice is not in any speaker list.
    """
    global request_count, last_request_time, avg_request_time, avg_tts_time, avg_rvc_time, tts_errors, rvc_errors, tts
    # Lazily build the TTS engine on first use (double-checked under the lock).
    if not tts:
        with ttslock:
            if not tts:
                tts = maketts()
    request_count += 1
    # Pre-increment the error counter; it is decremented again on success, so
    # a crash mid-request leaves the failure visible to /health-check.
    tts_errors += 1
    starttime = time.time()
    text = request.json.get("text", "")
    voice = request.json.get("voice", "")
    pitch_adjustment = request.json.get("pitch", "0")
    if use_voice_name_mapping:
        voice = voice_name_mapping_reversed[voice]
    # Find which RVC model knows this speaker. Fix: direct dict membership
    # instead of scanning every speaker name, and `is None` instead of `== None`.
    speaker_id = None
    model_to_use = None
    for model_name, speaker_list in voice_lookup.items():
        if voice in speaker_list:
            speaker_id = speaker_list[voice]
            model_to_use = loaded_models[list(voice_lookup.keys()).index(model_name)]
            break
    if speaker_id is None or model_to_use is None:
        abort(500)
    result = None
    with io.BytesIO() as data_bytes:
        with torch.inference_mode():
            with ttslock:
                tts_starttime = time.time()
                with io.BytesIO() as tts_data_bytes:
                    # Stage 1: base TTS synthesis into an in-memory WAV.
                    tts.tts_to_file(text=text, speaker=voice, file_path=tts_data_bytes)
                    avg_tts_time = mc_avg(avg_tts_time, time.time() - tts_starttime)
                    rvc_starttime = time.time()
                    rvc_errors += 1
                    # Stage 2: RVC voice conversion at 16 kHz.
                    audio, _ = librosa.load(io.BytesIO(tts_data_bytes.getvalue()), sr=16000)
                    audio_opt = model_to_use.vc(
                        embedder_model,
                        embedding_output_layer,
                        model_to_use.net_g,
                        speaker_id,
                        audio,
                        int(pitch_adjustment),
                        "crepe",
                        "",
                        0,
                        model_to_use.state_dict.get("f0", 1),
                        f0_file=None,
                    )
                    AudioSegment(
                        audio_opt,
                        frame_rate=model_to_use.tgt_sr,
                        sample_width=2,
                        channels=1,
                    ).export(data_bytes, format="wav")
                    rvc_errors -= 1
                    avg_rvc_time = mc_avg(avg_rvc_time, time.time() - rvc_starttime)
        result = send_file(io.BytesIO(data_bytes.getvalue()), mimetype="audio/wav")
    last_request_time = time.time()
    tts_errors -= 1
    avg_request_time = mc_avg(avg_request_time, last_request_time - starttime)
    return result
@app.route("/generate-tts-blips")
def text_to_speech_blips():
    """Build "blip"-style speech for `text` in `voice` and return it as WAV.

    Each usable character is synthesized once and cached under samples/v2/,
    pitch-shifted through the RVC model, then randomly re-pitched blips are
    concatenated; spaces become short gaps.
    """
    global request_count, tts
    if not tts:
        with ttslock:
            if not tts:
                tts = maketts()
    text = request.json.get("text", "").upper()
    voice = request.json.get("voice", "")
    pitch_adjustment = request.json.get("pitch", "0")
    if use_voice_name_mapping:
        voice = voice_name_mapping_reversed[voice]
    # Fix: model_to_use / speaker_id were never assigned in this route, so
    # every request raised NameError. Resolve them the same way /generate-tts does.
    speaker_id = None
    model_to_use = None
    for model_name, speaker_list in voice_lookup.items():
        if voice in speaker_list:
            speaker_id = speaker_list[voice]
            model_to_use = loaded_models[list(voice_lookup.keys()).index(model_name)]
            break
    if speaker_id is None or model_to_use is None:
        abort(500)
    result = None
    result_sound = AudioSegment.empty()
    sample_dir = "samples/v2/" + voice
    pitch_dir = sample_dir + "/pitch_" + pitch_adjustment
    with torch.inference_mode():
        with ttslock:
            if not os.path.exists(sample_dir):
                # First request for this voice: synthesize and de-silence one
                # sample per character.
                os.makedirs(sample_dir, exist_ok=True)
                for value in letters_to_use:
                    wav_path = sample_dir + "/" + value + ".wav"
                    tts.tts_to_file(text=value + ".", speaker=voice, file_path=wav_path)
                    sound = AudioSegment.from_file(wav_path, format="wav")
                    strip_silence(sound).export(wav_path, format="wav")
            if not os.path.isdir(pitch_dir):
                # First request at this pitch: run each cached sample through RVC.
                os.makedirs(pitch_dir, exist_ok=True)
                for value in letters_to_use:
                    # Fix: loop variable was misnamed `letter` here, and
                    # librosa.load requires sr as a keyword argument.
                    audio, _ = librosa.load(sample_dir + "/" + value + ".wav", sr=16000)
                    audio_opt = model_to_use.vc(
                        embedder_model,
                        embedding_output_layer,
                        model_to_use.net_g,
                        speaker_id,
                        audio,
                        int(pitch_adjustment),
                        "crepe",
                        "",
                        0,
                        model_to_use.state_dict.get("f0", 1),
                        f0_file=None,
                    )
                    output_sound = AudioSegment(
                        audio_opt,
                        frame_rate=model_to_use.tgt_sr,
                        sample_width=2,
                        channels=1,
                    )
                    output_sound.export(pitch_dir + "/" + value + ".wav", format="wav")
    word_letter_count = 0
    for letter in text:
        # Fix: the original `not letter.isalpha() or letter.isnumeric() or
        # letter == " "` skipped digits and spaces outright, making the space
        # branch below unreachable. Keep letters, digits and spaces only.
        if not (letter.isalpha() or letter.isnumeric() or letter == " "):
            continue
        if letter == " ":
            word_letter_count = 0
            # Fix: a leading space used to reference letter_sound before
            # assignment; emit the ~167 ms gap directly instead
            # ((40000 // 3) bytes at 2 bytes/frame, 40 kHz).
            new_sound = AudioSegment.silent(duration=167, frame_rate=40000)
        else:
            if word_letter_count % 2 != 0:
                word_letter_count += 1
                continue  # Skip every other letter within a word.
            letter_path = pitch_dir + "/" + letter + ".wav"
            if not os.path.isfile(letter_path):
                continue
            word_letter_count += 1
            letter_sound = AudioSegment.from_file(letter_path)
            # Drop the edges of the sample, then shift it up by a random
            # 1..(1 + random_factor) octaves for blip variety.
            raw = letter_sound.raw_data[5000:-5000]
            octaves = 1 + random.random() * random_factor
            frame_rate = int(letter_sound.frame_rate * (2.0 ** octaves))
            new_sound = letter_sound._spawn(raw, overrides={'frame_rate': frame_rate})
            new_sound = new_sound.set_frame_rate(40000)
        result_sound += new_sound
    with io.BytesIO() as data_bytes:
        result_sound.export(data_bytes, format='wav')
        result = send_file(io.BytesIO(data_bytes.getvalue()), mimetype="audio/wav")
    request_count += 1
    return result
@app.route("/tts-voices")
def voices_list():
    """Return a JSON array of available voice names.

    Uses the sorted display names when the mapping is enabled; otherwise asks
    the TTS engine directly (under the lock).
    """
    global tts
    if not tts:
        with ttslock:
            if not tts:
                tts = maketts()
    #gc.collect()
    if use_voice_name_mapping:
        return json.dumps(sorted(voice_name_mapping.values()))
    else:
        with ttslock:
            return json.dumps(tts.voices)
@app.route("/health-check")
def tts_health_check():
    """Watchdog endpoint: 200 while healthy, 500 once the worker should be recycled.

    A supervisor polls this; any EXPIRED response (500) signals it to restart
    the process. `timeofdeath` latches the expired state.
    """
    global request_count, last_request_time, timeofdeath, tts
    #gc.collect()
    if not tts:
        with ttslock:
            if not tts:
                tts = maketts()
    # Already marked for death: keep reporting EXPIRED.
    if timeofdeath > 0:
        return f"EXPIRED: count:{request_count}({tts_errors}) t:{avg_request_time}s last:{last_request_time}", 500
    # Stalled (busy worker idle >60s), too slow on average, or too many
    # in-flight/failed TTS requests.
    if ((request_count > 100) and (time.time() > last_request_time+60)) or (avg_request_time > 2) or tts_errors >= 5:
        timeofdeath = time.time() + 3
        return f"EXPIRED: count:{request_count}({tts_errors}) t:{avg_request_time}s last:{last_request_time}", 500
    # Recycle after enough total requests (guards against slow leaks).
    if request_count > 4096:
        timeofdeath = time.time() + 3
        return f"EXPIRED: count:{request_count}({tts_errors}) t:{avg_request_time}s last:{last_request_time}", 500
    # Recycle after an hour with no requests.
    if (time.time() > last_request_time+(1*60*60)):
        timeofdeath = time.time() + 3
        return f"EXPIRED: count:{request_count}({tts_errors}) t:{avg_request_time}s last:{last_request_time}", 500
    # NOTE(review): last_request_time is initialized to time.time() at import,
    # so this self-test branch looks unreachable — confirm intent.
    if last_request_time < 1:
        request_count += 1
        test_tts()
        last_request_time = time.time()
    return f"OK count:{request_count}({tts_errors}) t:{avg_request_time}s last:{last_request_time}", 200
@app.route("/pitch-available")
def pitch_available():
    """Advertise that this server supports pitch adjustment."""
    return ("Pitch available", 200)
tts_log("START")
# Warm up at import time so the worker is ready before serving.
readyup()
if __name__ == "__main__":
    # NOTE(review): os.putenv does not update os.environ, and changing
    # LD_LIBRARY_PATH after the process has started does not affect libraries
    # already loaded — presumably this targets child processes; confirm.
    if os.getenv('TTS_LD_LIBRARY_PATH', "") != "":
        os.putenv('LD_LIBRARY_PATH', os.getenv('TTS_LD_LIBRARY_PATH'))
    from waitress import serve
    serve(app, host="localhost", port=5003, threads=2, backlog=2, connection_limit=128, cleanup_interval=1, channel_timeout=10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment