import numpy as np

import streamlit as st
import librosa
import soundfile as sf
import librosa.display
from config import CONFIG
import torch
from dataset import MaskGenerator
import onnxruntime, onnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from pystoi import stoi
from pesq import pesq
import pandas as pd
import torchaudio


from torchmetrics.audio import ShortTimeObjectiveIntelligibility as STOI
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ


from PLCMOS.plc_mos import PLCMOSEstimator
from speechmos import dnsmos
from speechmos import plcmos

import speech_recognition as speech_r
from jiwer import wer
import time

@st.cache
def load_model(model):
    path = 'lightning_logs/version_0/checkpoints/' + str(model)
    onnx_model = onnx.load(path)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names

def inference(re_im, session, onnx_model, input_names, output_names):
    inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim],
                                       dtype=np.float32)
              for i, _input in enumerate(onnx_model.graph.input)
              }

    output_audio = []
    for t in range(re_im.shape[0]):
        inputs[input_names[0]] = re_im[t]
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        output_audio.append(out)

    output_audio = torch.tensor(np.concatenate(output_audio, 0))
    output_audio = output_audio.permute(1, 0, 2).contiguous()
    output_audio = torch.view_as_complex(output_audio)
    output_audio = torch.istft(output_audio, window, stride, window=hann)
    return output_audio.numpy()

def visualize(hr, lr, recon, sr):
    sr = sr
    window_size = 1024
    window = np.hanning(window_size)

    stft_hr = librosa.core.spectrum.stft(hr, n_fft=window_size, hop_length=512, window=window)
    stft_hr = 2 * np.abs(stft_hr) / np.sum(window)

    stft_lr = librosa.core.spectrum.stft(lr, n_fft=window_size, hop_length=512, window=window)
    stft_lr = 2 * np.abs(stft_lr) / np.sum(window)

    stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window)
    stft_recon = 2 * np.abs(stft_recon) / np.sum(window)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
    ax1.title.set_text('Оригинальный сигнал')
    ax2.title.set_text('Сигнал с потерями')
    ax3.title.set_text('Улучшенный сигнал')

    canvas = FigureCanvas(fig)
    p = librosa.display.specshow(librosa.amplitude_to_db(stft_hr), ax=ax1, y_axis='log', x_axis='time', sr=sr)
    p = librosa.display.specshow(librosa.amplitude_to_db(stft_lr), ax=ax2, y_axis='log', x_axis='time', sr=sr)
    p = librosa.display.specshow(librosa.amplitude_to_db(stft_recon), ax=ax3, y_axis='log', x_axis='time', sr=sr)

    ax1.set_xlabel('Время, с')
    ax1.set_ylabel('Частота, Гц')
    ax2.set_xlabel('Время, с')
    ax2.set_ylabel('Частота, Гц')
    ax3.set_xlabel('Время, с')
    ax3.set_ylabel('Частота, Гц')
    return fig



def waveplot(hr, lr, recon, sr):
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
    ax1.title.set_text('Оригинальный сигнал')
    ax2.title.set_text('Сигнал с потерями')
    ax3.title.set_text('Улучшенный сигнал')

    canvas = FigureCanvas(fig)
    p = librosa.display.waveshow(hr, ax=ax1, sr=sr)
    p = librosa.display.waveshow(lr, ax=ax2, sr=sr)
    p = librosa.display.waveshow(recon, ax=ax3, sr=sr)

    ax1.set_xlabel('Время, с')
    #ax1.set_ylabel('Частота, Гц')
    ax2.set_xlabel('Время, с')
    #ax2.set_ylabel('Частота, Гц')
    ax3.set_xlabel('Время, с')
    #ax3.set_ylabel('Частота, Гц')
    return fig

def sign_x_y(x,y):
    if x>y:
        return '-'
    else:
        return ''

packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

title = 'Маскировка потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

st.subheader('1. Загрузка аудио')
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц")

is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    uploaded_file = 'sample.wav'

target, sr = librosa.load(uploaded_file, sr=48000)
target = target[:packet_size * (len(target) // packet_size)]

st.text('Ваше аудио')
st.audio(uploaded_file)

model_ver = st.selectbox(
     'Веса оригинальной модели выбраны по умолчанию. Выберите модель',
     ('frn.onnx', 'frn_out_QInt16.onnx', 'frn_out_QInt8.onnx', 'frn_out_QUInt8.onnx', 'frn_out_QUInt16.onnx', 'frn_fp16.onnx'))

st.write('Вы выбрали:', model_ver)

lang = st.selectbox(
     'Выберите язык вашего аудио для корректной работы распознавания речи',
     ('ru-RU', 'en-EN'))

st.write('Вы выбрали:', lang)


st.subheader('2. Выберите желаемый процент потерь')
slider = [st.slider("Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1)]
loss_percent = float(slider[0])/100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
lossy_input = target.copy().reshape(-1, packet_size)
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)
hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)

session, onnx_model, input_names, output_names = load_model(model_ver)

with st.sidebar:
    st.title('Full-band Reccurent Network', help = 'https://arxiv.org/abs/2211.04071')
    authors_c = st.container()
    authors_c.write('Авторы модели: Viet-Anh Nguyen and Anh H. T. Nguyen and Andy W. H. Khong')
    st.link_button("Github авторов", "https://github.com/Crystalsound/FRN", help = 'Кликни на меня')
    description_c = st.container()
    description_c.write("Это дополненный space оригинальной FRN модели. К исходной сети были применены методы квантования onnxruntime для уменьшения размера .onnx файла и повышения скорости обработки аудио при некотором ухудшении результата. В этом space вы можете сгенерировать потери пакетов и оценить работу модели визуально, на слух, и по нескольким метрикам.")
    st.header("Packet Loss Concealment", help = 'https://arxiv.org/abs/2204.05222')
    PLC_c = st.container()
    PLC_c.write("PLC (Packet Loss Concealment) - это технологии, созданные для борьбы с потерей пакетов при передаче речи в IP сети. Для ознакомления с данной темой рекомендуется статья INTERSPEECH 2022 Audio Deep Packet Loss Concealment Challenge с результатами одноимённого конкурса.")
    st.header("Метрики")
    Metrcs_c = st.container()
    Metrcs_c.write("Для оценивания речи были выбраны следующие метрики: PESQ, STOI, PLCMOS разных версий и WER. С каждой из них вы можете ознакомиться, перейдя по ссылке рядом с заголовком.")
    st.subheader("PESQ", help = 'https://ieeexplore.ieee.org/document/941023')
    st.write('Перцептивная оценка качества речи')
    st.subheader("STOI", help = 'https://ieeexplore.ieee.org/document/5495701')
    st.write('Индекс объективной кратковременной разборчивости')
    st.subheader("PLCMOS", help = 'https://arxiv.org/abs/2305.15127')
    PLCMOS_c=st.container()
    PLCMOS_c.write("Использованы две версии данной метрики (v1, v2). v1 - это первая версия, разработанная для INTERSPEECH 2022 Audio Deep Packet Loss Concealment Challenge. v2 - улучшенная версия метрики, вышедшая в 2023 году. Особенность - неэталонная метрика, которая выдаёт оценку, опираясь только на аудио с потерями без использования информации о исходном (оригинальном). Поставляется как часть пакета speechmos.")
    st.subheader("WAcc", help = 'https://docs.speechmatics.com/tutorials/accuracy-benchmarking')
    WAcc_c=st.container()
    WAcc_c.write('Первоначально использовалась метрика WER (Word Error Rate). Она выражает долю ошибочно распознанных слов. Я считаю, что для восприятия будет проще обратная ей - WAcc (Word Accuracy), то есть доля слов, которые распознаны верно. Для распознавания используется пакет jiwer')

if st.button('Сгенерировать потери'):
    with st.spinner('Ожидайте...'):
        start_time = time.time()
        output = inference(re_im, session, onnx_model, input_names, output_names)
        st.text(str(time.time() - start_time))
        st.subheader('3. Визуализация аудио')
        fig_1 = visualize(target, lossy_input, output, sr)
        fig_2 = waveplot(target, lossy_input, output, sr)
        tab1, tab2 = st.tabs(["Частотная область", "Временная область"])

        with tab1:
           st.header("Частотная область")
           st.pyplot(fig_1)

        with tab2:
           st.header("Временная область")
           st.pyplot(fig_2)

    #st.success('Сделано!')
    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)
    st.text('Оригинальное аудио')
    st.audio('target.wav')
    st.text('Аудио с потерями')
    st.audio('lossy.wav')
    st.text('Улучшенное аудио')
    st.audio('enhanced.wav')






    #data_clean, samplerate = torchaudio.load('target.wav')
    #data_lossy, samplerate = torchaudio.load('lossy.wav')
    #data_enhanced, samplerate = torchaudio.load('enhanced.wav')

    #min_len = min(data_clean.shape[1], data_lossy.shape[1], data_enhanced.shape[1])
    #data_clean = data_clean[:, :min_len]
    #data_lossy = data_lossy[:, :min_len]
    #data_enhanced = data_enhanced[:, :min_len]


    #stoi = STOI(samplerate)

    #stoi_orig = round(float(stoi(data_clean, data_clean)),3)
    #stoi_lossy = round(float(stoi(data_clean, data_lossy)),5)
    #stoi_enhanced = round(float(stoi(data_clean, data_enhanced)),5)

    #stoi_mass=[stoi_orig, stoi_lossy, stoi_enhanced]


    #pesq = PESQ(8000, 'nb')

    #data_clean = data_clean.cpu().numpy()
    #data_lossy = data_lossy.cpu().numpy()
    #data_enhanced = data_enhanced.cpu().numpy()

    #if samplerate != 8000:
        #data_lossy = librosa.resample(data_lossy, orig_sr=48000, target_sr=8000)
        #data_clean = librosa.resample(data_clean, orig_sr=48000, target_sr=8000)
        #data_enhanced = librosa.resample(data_enhanced, orig_sr=48000, target_sr=8000)

    #pesq_orig = float(pesq(torch.tensor(data_clean), torch.tensor(data_clean)))
    #pesq_lossy = float(pesq(torch.tensor(data_lossy), torch.tensor(data_clean)))
    #pesq_enhanced = float(pesq(torch.tensor(data_enhanced), torch.tensor(data_clean)))

    #psq_mas=[pesq_orig, pesq_lossy, pesq_enhanced]





    #_____________________________________________
    data_clean, samplerate = sf.read('target.wav')
    data_lossy, samplerate = sf.read('lossy.wav')
    data_enhanced, samplerate = sf.read('enhanced.wav')
    min_len = min(data_clean.shape[0], data_lossy.shape[0], data_enhanced.shape[0])
    data_clean = data_clean[:min_len]
    data_lossy = data_lossy[:min_len]
    data_enhanced = data_enhanced[:min_len]


    stoi_orig = round(stoi(data_clean, data_clean, samplerate, extended=False),5)
    stoi_lossy  = round(stoi(data_clean, data_lossy , samplerate, extended=False),5)
    stoi_enhanced = round(stoi(data_clean, data_enhanced, samplerate, extended=False),5)
    
    stoi_mass=[stoi_orig, stoi_lossy, stoi_enhanced]


    #def get_power(x, nfft):
    #    S = librosa.stft(x, n_fft=nfft)
    #    S = np.log(np.abs(S) ** 2 + 1e-8)
    #    return S
    #def LSD(x_hr, x_pr):
    #    S1 = get_power(x_hr, nfft=2048)
    #    S2 = get_power(x_pr, nfft=2048)
    #    lsd = np.mean(np.sqrt(np.mean((S1 - S2) ** 2, axis=-1)), axis=0)
    #    return lsd

    #lsd_orig = LSD(data_clean,data_clean)
    #lsd_lossy = LSD(data_lossy,data_clean)
    #lsd_enhanced = LSD(data_enhanced,data_clean)

    #lsd_mass=[lsd_orig, lsd_lossy, lsd_enhanced]
        
    if samplerate != 16000:
        data_lossy = librosa.resample(data_lossy, orig_sr=48000, target_sr=16000)
        data_clean = librosa.resample(data_clean, orig_sr=48000, target_sr=16000)
        data_enhanced = librosa.resample(data_enhanced, orig_sr=48000, target_sr=16000)
    


    pesq_orig = pesq(fs = 16000, ref = data_clean, deg = data_clean, mode='wb')
    pesq_lossy = pesq(fs = 16000, ref = data_clean, deg = data_lossy, mode='wb')
    pesq_enhanced = pesq(fs = 16000, ref = data_clean, deg = data_enhanced, mode='wb')

    psq_mas=[pesq_orig, pesq_lossy, pesq_enhanced]

    

    data_clean, fs = sf.read('target.wav')
    data_lossy, fs = sf.read('lossy.wav')
    data_enhanced, fs = sf.read('enhanced.wav')

    if fs!= 16000:
        data_lossy = librosa.resample(data_lossy, orig_sr=48000, target_sr=16000)
        data_clean = librosa.resample(data_clean, orig_sr=48000, target_sr=16000)
        data_enhanced = librosa.resample(data_enhanced, orig_sr=48000, target_sr=16000)

    PLC_example=PLCMOSEstimator()
    PLC_org = PLC_example.run(audio_degraded=data_clean, audio_clean=data_clean)[0]
    PLC_lossy = PLC_example.run(audio_degraded=data_lossy, audio_clean=data_clean)[0]
    PLC_enhanced = PLC_example.run(audio_degraded=data_enhanced, audio_clean=data_clean)[0]

    PLC_massv1 = [PLC_org, PLC_lossy, PLC_enhanced]


    
    df_1 = pd.DataFrame(columns=['Audio', 'PESQ', 'STOI', 'PLCMOSv1'])

    df_1['Аудио'] = ['Оригинал', 'С потерями', 'Улучшенное']

    df_1['PESQ'] = psq_mas

    df_1['STOI'] = stoi_mass

    #df['LSD'] = lsd_mass
    df_1['PLCMOSv1'] = PLC_massv1
    #new_columns = pd.MultiIndex.from_tuples([('', 'Audio'), ('Эталонные метрики', 'PESQ'), ('Эталонные метрики', 'STOI'), ('Эталонные метрики', 'PLCMOSv1')])
                          
    # Присваиваем новый мультииндекс столбцам
    #df_1.columns = new_columns


    PLC_massv2 = [plcmos.run("target.wav", sr=16000)['plcmos'], plcmos.run("lossy.wav", sr=16000)['plcmos'], plcmos.run("enhanced.wav", sr=16000)['plcmos']]

    #DNS = [dnsmos.run("target.wav", sr=16000)['ovrl_mos'], dnsmos.run("lossy.wav", sr=16000)['ovrl_mos'], dnsmos.run("enhanced.wav", sr=16000)['ovrl_mos']]

    df_1['PLCMOSv2'] = PLC_massv2
    #df_1['DNSMOS'] = DNS
    

    #df_2 = pd.DataFrame(columns=['DNSMOS', 'PLCMOSv2'])

    #df_2['DNSMOS'] = DNS

    #df_2['PLCMOSv2'] = PLC_massv2

    #new_columns = pd.MultiIndex.from_tuples([('Неэталонные метрики', 'DNSMOS'), ('Неэталонные метрики', 'PLCMOSv2')])
                          
    # Присваиваем новый мультииндекс столбцам
    #df_2.columns = new_columns
    #df_merged = df_1.merge(df_2, left_index=True, right_index=True)


    r = speech_r.Recognizer()



    
    harvard = speech_r.AudioFile('target.wav')
    with harvard as source:
        audio = r.record(source)

    orig = r.recognize_google(audio, language = str(lang))





    
    harvard = speech_r.AudioFile('lossy.wav')
    #with harvard as source:
    #    audio = r.record(source)
    #lossy = r.recognize_google(audio, language = "ru-RU")

    try:
        with harvard as source:
            audio = r.record(source)
        lossy = r.recognize_google(audio, language = str(lang))
        #print("Распознанный текст:", text)
    except speech_r.UnknownValueError:
        #st.text("Система не смогла распознать аудио")
        lossy = ''
    #except speech_r.RequestError as e:
        #st.text("Ошибка при запросе к сервису распознавания речи; {0}".format(e))




    
    harvard = speech_r.AudioFile('enhanced.wav')
    #with harvard as source:
    #    audio = r.record(source)
    #enhanced = r.recognize_google(audio, language = "ru-RU")

    try:
        with harvard as source:
            audio = r.record(source)
        enhanced = r.recognize_google(audio, language = str(lang))
        #print("Распознанный текст:", text)
    except speech_r.UnknownValueError:
        #st.text("Система не смогла распознать улучшенное аудио")
        enhanced = ''
    #except speech_r.RequestError as e:
        #st.text("Ошибка при запросе к сервису распознавания речи; {0}".format(e))
        
    error1 = wer(orig, orig)
    error2 = wer(orig, lossy)
    error3 = wer(orig, enhanced)
    WAcc_mass=[(1-error1)*100, (1-error2)*100, (1-error3)*100]

    df_1['WAcc'] = WAcc_mass

    st.subheader('4. Метрики аудио')
    #st.dataframe(df_1)
    st.write("#### "+"Оригинал")
    col1, col2, col3, col4, col5 = st.columns(5)
    col1.metric("PESQ", value = round(psq_mas[0],3))
    col2.metric("STOI", value = round(stoi_mass[0],3))
    col3.metric("PLCMOSv1", value = round(PLC_massv1[0],3))
    col4.metric("PLCMOSv2", value = round(PLC_massv2[0],3))
    col5.metric("WAcc", value = round(WAcc_mass[0],3))


    st.write("#### "+"С потерями")
    col1, col2, col3, col4, col5 = st.columns(5)
    col1.metric("PESQ", value = round(psq_mas[1],3), delta = str(round(-(abs(psq_mas[1] - psq_mas[0]) / psq_mas[0]) * 100.0,3))+'%')
    col2.metric("STOI", value = round(stoi_mass[1],3), delta = str(round(-(abs(stoi_mass[1] - stoi_mass[0]) / stoi_mass[0]) * 100.0,3))+'%')
    col3.metric("PLCMOSv1", value = round(PLC_massv1[1],3), delta = str(round(-(abs(PLC_massv1[1] - PLC_massv1[0]) / PLC_massv1[0]) * 100.0,3))+'%')
    col4.metric("PLCMOSv2", value = round(PLC_massv2[1],3), delta = str(round(-(abs(PLC_massv2[1] - PLC_massv2[0]) / PLC_massv2[0]) * 100.0,3))+'%')
    col5.metric("WAcc", value = round(WAcc_mass[1],3), delta = str(round(-(abs(WAcc_mass[1] - WAcc_mass[0]) / WAcc_mass[0]) * 100.0,3))+'%')

    
    st.write("#### "+"Улучшенное")
    col1, col2, col3, col4, col5 = st.columns(5)
    PESQ_s = sign_x_y(psq_mas[1], psq_mas[2])
    col1.metric("PESQ", value = round(psq_mas[2],3), delta = PESQ_s + str(round((abs(psq_mas[2] - psq_mas[1]) / psq_mas[1]) * 100.0,3))+'%')
    STOI_s = sign_x_y(stoi_mass[1], stoi_mass[2])
    col2.metric("STOI", value = round(stoi_mass[2],3), delta = STOI_s + str(round((abs(stoi_mass[2] - stoi_mass[1]) / stoi_mass[1]) * 100.0,3))+'%')
    PLCv1_s = sign_x_y(PLC_massv1[1], PLC_massv1[2])
    col3.metric("PLCMOSv1", value = round(PLC_massv1[2],3), delta = PLCv1_s + str(round((abs(PLC_massv1[2] - PLC_massv1[1]) / PLC_massv1[1]) * 100.0,3))+'%')
    PLCv2_s = sign_x_y(PLC_massv2[1], PLC_massv2[2])
    col4.metric("PLCMOSv2", value = round(PLC_massv2[2],3), delta = PLCv2_s + str(round((abs(PLC_massv2[2] - PLC_massv2[1]) / PLC_massv2[1]) * 100.0,3))+'%')
    WER_s = sign_x_y(WAcc_mass[1], WAcc_mass[2])
    if WAcc_mass[1]==0:
        if WAcc_mass[2]!=0:
            col5.metric("WAcc", value = round(WAcc_mass[2],3), delta = WER_s + str(round((abs(WAcc_mass[2] - 0.001) / 0.001) * 100.0,3))+'%')
        else:
            col5.metric("WAcc", value = round(WAcc_mass[2],3))
    else:
        col5.metric("WAcc", value = round(WAcc_mass[2],3), delta = WER_s + str(round((abs(WAcc_mass[2] - WAcc_mass[1]) / WAcc_mass[1]) * 100.0,3))+'%')

    tab1, tab2, tab3, tab4, tab5 = st.tabs(["PESQ", "STOI", "PLCMOSv1", "PLCMOSv2", "WAcc"])

    with tab1:
        st.header("PESQ")
        st.bar_chart(df_1, x="Аудио", y="PESQ")
    with tab2:
        st.header("STOI")
        st.bar_chart(df_1, x="Аудио", y="STOI")
    with tab3:
        st.header("PLCMOSv1")
        st.bar_chart(df_1, x="Аудио", y="PLCMOSv1")
    with tab4:
        st.header("PLCMOSv2")
        st.bar_chart(df_1, x="Аудио", y="PLCMOSv2")
    with tab5:
        st.header("WAcc")
        st.bar_chart(df_1, x="Аудио", y="WAcc")
    #st.bar_chart(df_1, x="Audio", y="PESQ")
    #st.bar_chart(df_1, x="Audio", y="STOI")
    #st.bar_chart(df_1, x="Audio", y="PLCMOSv1")
    #st.bar_chart(df_1, x="Audio", y="PLCMOSv2")
    #st.bar_chart(df_1, x="Audio", y="WER")



    #col1.metric("PESQ", value = psq_mas[-1], delta = psq_mas[-1] - psq_mas[-2])
    #col2.metric("STOI", value = stoi_mass[-1], delta = stoi_mass[-1] - stoi_mass[-2])
    #col3.metric("PLCMOSv1", value = PLC_massv1[-1], delta = PLC_massv1[-1] - PLC_massv1[-2])
    #col4.metric("PLCMOSv2", value = PLC_massv2[-1], delta = PLC_massv2[-1] - PLC_massv2[-2])
    #col5.metric("WER", value = WER_mass[-1], delta = WER_mass[-1] - WER_mass[-2], delta_color="inverse")