from nemo.collections.asr.models import EncDecCTCModelBPE
from omegaconf import open_dict
#import yt_dlp as youtube_dl
import os
import tempfile
import torch
import gradio as gr
from pydub import AudioSegment
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME="ayymen/stt_zgh_fastconformer_ctc_small"
YT_LENGTH_LIMIT_S=3600

model = EncDecCTCModelBPE.from_pretrained(model_name=MODEL_NAME).to(device)

with open_dict(model.cfg):
    model.cfg.decoding.strategy = "beam"
    model.cfg.decoding.beam.beam_size = 256 # Desired Beam Size
    model.cfg.decoding.beam.beam_alpha = 1.5 # Desired Beam Alpha
    model.cfg.decoding.beam.beam_beta = 1.5 # Desired Beam Beta
    model.cfg.decoding.beam.kenlm_path = "kenlm.bin" # Path to KenLM binary file

model.change_decoding_strategy(model.cfg.decoding)

model.eval()

def get_transcripts(audio_path):
    audio = AudioSegment.from_file(audio_path)
    # check if audio is mono 16kHz
    if audio.channels != 1 or audio.frame_rate != 16000:
        audio = audio.set_channels(1).set_frame_rate(16000) # convert to mono 16kHz
        with tempfile.TemporaryDirectory() as tmpdirname:
            audio_path = os.path.join(tmpdirname, "audio.wav")
            audio.export(audio_path, format="wav")
            text = model.transcribe([audio_path])[0]
    else:
        text = model.transcribe([audio_path])[0]
    return text

'''
article = (
    "<p style='text-align: center'>"
    "<a href='https://huggingface.co/nvidia/parakeet-rnnt-1.1b' target='_blank'>🎙️ Learn more about Parakeet model</a> | "
    "<a href='https://arxiv.org/abs/2305.05084' target='_blank'>📚 FastConformer paper</a> | "
    "<a href='https://github.com/NVIDIA/NeMo' target='_blank'>🧑‍💻 Repository</a>"
    "</p>"
)
'''

EXAMPLES = [
    ["135.wav"],
    ["common_voice_zgh_37837257.mp3"]
]

"""
YT_EXAMPLES = [
    ["https://www.youtube.com/shorts/CSgTSE50MHY"],
    ["https://www.youtube.com/shorts/OxQtqOyAFLE"]
]
"""

def _return_yt_html_embed(yt_url):
    video_id = yt_url.split("?v=")[-1]
    if "youtube.com/shorts/" in video_id:
        video_id = video_id.split("/")[-1]
    HTML_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
        " </center>"
    )
    return HTML_str

def download_yt_audio(yt_url, filename):
    info_loader = youtube_dl.YoutubeDL()
    
    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        raise gr.Error(str(err))
    
    file_length = info["duration_string"]
    file_h_m_s = file_length.split(":")
    file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
    
    if len(file_h_m_s) == 1:
        file_h_m_s.insert(0, 0)
    if len(file_h_m_s) == 2:
        file_h_m_s.insert(0, 0)
    file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
    
    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
    
    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
    
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except youtube_dl.utils.ExtractorError as err:
            raise gr.Error(str(err))


def yt_transcribe(yt_url, max_filesize=75.0):
    html_embed_str = _return_yt_html_embed(yt_url)

    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")
        download_yt_audio(yt_url, filepath)
        audio = AudioSegment.from_file(filepath)
        audio = audio.set_channels(1).set_frame_rate(16000) # convert to mono 16kHz
        wav_filepath = os.path.join(tmpdirname, "audio.wav")
        audio.export(wav_filepath, format="wav")
        text = get_transcripts(wav_filepath)

    return html_embed_str, text


demo = gr.Blocks()

mf_transcribe = gr.Interface(
    fn=get_transcripts,
    inputs=[
        gr.Audio(sources="microphone", type="filepath")
    ],
    outputs="text",
    title="Transcribe Audio",
    description=(
        "Transcribe microphone or audio inputs with the click of a button! Demo uses the"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) to transcribe audio files"
        " of arbitrary length."
    ),
    allow_flagging="never",
)

file_transcribe = gr.Interface(
    fn=get_transcripts,
    inputs=[
        gr.Audio(sources="upload", type="filepath", label="Audio file"),
    ],
    outputs="text",
    examples=EXAMPLES,
    title="Transcribe Audio",
    description=(
        "Transcribe microphone or audio inputs with the click of a button! Demo uses the"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) to transcribe audio files"
        " of arbitrary length."
    ),
    allow_flagging="never",
)

"""
youtube_transcribe = gr.Interface(
    fn=yt_transcribe,
    inputs=[
        gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
    ],
    outputs=["html", "text"],
    examples=YT_EXAMPLES,
    title="Transcribe Audio",
    description=(
        "Transcribe microphone or audio inputs with the click of a button! Demo uses the"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) to transcribe audio files"
        " of arbitrary length."
    ),
    allow_flagging="never",
)
"""

with demo:
    gr.TabbedInterface(
        [
            mf_transcribe,
            file_transcribe,
            #youtube_transcribe
        ],
        [
            "Microphone",
            "Audio file",
            #"Youtube Video"
        ]
    )

demo.launch(server_name="0.0.0.0")