import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    pipeline,
)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

title = "Cascaded STST"

description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in Hindi. The demo chains three models: OpenAI's [Whisper Base](https://huggingface.co/openai/whisper-base) for speech translation to English, Helsinki-NLP's MarianMT [opus-mt-en-hi](https://huggingface.co/Helsinki-NLP/opus-mt-en-hi) model for translation from English to Hindi, and finally a [SpeechT5 TTS](https://huggingface.co/navodit17/speecht5_finetuned_indic_tts_hi) model (microsoft/speecht5_tts fine-tuned for Hindi on the IndicTTS dataset) for text-to-speech:

![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")

### NOTE: The goal is not to generate perfect Hindi speech or translation, but to demonstrate the cascaded STST approach using multiple models. The models may give poor results for very short sentences (1-2 words or so); try sending longer audio in that case.

---
"""

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

# load speech translation checkpoint
asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)

# load text-to-speech checkpoint and speaker embeddings
processor = SpeechT5Processor.from_pretrained("navodit17/speecht5_finetuned_indic_tts_hi")
model = SpeechT5ForTextToSpeech.from_pretrained("navodit17/speecht5_finetuned_indic_tts_hi").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

# load English-to-Hindi translation checkpoint
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
model_en_hi = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-hi")

normalizer = BasicTextNormalizer()


def translate_en_hi(text):
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model_en_hi.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def translate(audio):
    # Whisper's "translate" task maps source speech in any language to English text
    outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "translate"})
    print(f"Translated text - English: {outputs['text']}")
    translated_text = translate_en_hi(outputs["text"])
    print(f"Translated text - Hindi: {translated_text}")
    return translated_text


def synthesise(text):
    # the fine-tuned SpeechT5 checkpoint expects romanised input, so transliterate
    # Devanagari to ITRANS before normalisation
    text = normalizer(transliterate(text, sanscript.DEVANAGARI, sanscript.ITRANS))
    print(f"Normalized Text: {text}")
    inputs = processor(text=text, return_tensors="pt")
    print(f"Inputs: {inputs['input_ids'].shape}")
    speech = model.generate_speech(
        input_ids=inputs["input_ids"].to(device),
        speaker_embeddings=speaker_embeddings.to(device),
        vocoder=vocoder,
    )
    return speech.cpu()


def speech_to_speech_translation(audio):
    translated_text = translate(audio)
    synthesised_speech = synthesise(translated_text)
    print(f"Generated speech shape: {synthesised_speech.shape}")
    # convert the float waveform in [-1, 1] to 16-bit PCM for Gradio's numpy audio output
    synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
    return 16000, synthesised_speech
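# Optional sanity check for the text stages in isolation (the sentence below is
# an arbitrary example, not part of the demo); uncomment to verify the MT model
# before wiring up the UI:
# print(translate_en_hi("Hello, how are you?"))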
demo = gr.Blocks()

file_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    title=title,
    description=description,
)

mic_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    title=title,
    description=description,
)

with demo:
    gr.TabbedInterface([file_translate, mic_translate], ["Audio File", "Microphone"])

demo.launch()
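# Optional local smoke test of the full cascade without the UI. "sample.wav" is a
# hypothetical path, not a file shipped with this demo; since demo.launch() above
# blocks, comment it out before uncommenting these lines:
# sample_rate, waveform = speech_to_speech_translation("sample.wav")
# print(sample_rate, waveform.dtype, waveform.shape)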