import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchaudio
from datetime import datetime
from lang_id import identify_languages
from whisper import transcribe
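# Note: lang_id and whisper are assumed to be local project modules exposing
# identify_languages(audio, language_set) and transcribe(audio, language=...);
# transcribe here is not the top-level API of the PyPI openai-whisper package.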

# Variables holding the application state
data = []
data_df = pd.DataFrame()
current_chunk = []

SAMPLING_RATE = 16000
CHUNK_DURATION = 5  # default chunk length: 5 seconds
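# Note: 16 kHz is the input sampling rate Whisper-style speech models expect.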


def normalize_audio(audio):
    # Normalize volume (scale so the peak amplitude is 1); skip all-zero
    # input to avoid dividing by zero
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak
    return audio
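
# A quick sanity check of the peak scaling (hypothetical values):
#   normalize_audio(np.array([0.1, -0.2, 0.4]))  # -> array([0.25, -0.5, 1.0])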


def resample_audio(audio, orig_sr, target_sr=16000):
    # Mix stereo down to mono; Gradio delivers stereo as (samples, channels)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    # Work in float32 throughout: squaring int16 samples (e.g. for RMS)
    # would silently overflow
    audio = audio.astype(np.float32)
    if orig_sr != target_sr:
        print(f"Resampling audio from {orig_sr} to {target_sr}")
        resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
        audio = resampler(torch.from_numpy(audio).unsqueeze(0)).squeeze(0).numpy()
    return audio
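
# Lengths scale by target_sr / orig_sr; e.g. one second of 44.1 kHz audio
# (44100 samples) comes back as 16000 samples:
#   resample_audio(np.zeros(44100, dtype=np.float32), 44100).shape  # (16000,)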


def process_chunk(chunk, language_set) -> pd.DataFrame:
    print(f"Processing audio chunk of length {len(chunk)}")
    rms = np.sqrt(np.mean(chunk**2))
    db_level = 20 * np.log10(rms + 1e-9)  # small epsilon avoids log(0) = -inf
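    # Scale note: Gradio's type="numpy" audio arrives with int16-range values,
    # so a full-scale sine has rms ~= 32767 / sqrt(2) ~= 23170, i.e. about
    # 20 * log10(23170) ~= 87 dB, while near-silence sits close to 0 dB;
    # the `db_level > 50` silence check below relies on this scale.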

    # Normalize volume before language ID / transcription
    chunk = normalize_audio(chunk)

    length = len(chunk) / SAMPLING_RATE  # audio length in seconds
    s = datetime.now()
    selected_scores, all_scores = identify_languages(chunk, language_set)
    lang_id_time = (datetime.now() - s).total_seconds()

    # Get the probability scores for Japanese and English
    ja_prob = selected_scores['Japanese']
    en_prob = selected_scores['English']

    ja_en = 'ja' if ja_prob > en_prob else 'en'

    # Get the top-3 languages with their scores
    top3 = sorted(all_scores, key=all_scores.get, reverse=True)[:3]
    top3_languages = ", ".join(f"{lang} ({all_scores[lang]:.2f})" for lang in top3)

    # Transcribe the chunk
    s = datetime.now()
    transcription = transcribe(chunk, language=ja_en)
    transcribe_time = (datetime.now() - s).total_seconds()

    return pd.DataFrame({
        "Length (s)": [length],
        "db_level": [db_level],
        "Japanese_English": [f"{ja_en} ({ja_prob:.2f}, {en_prob:.2f})"] if db_level > 50 else ["Silent"],
        "Language": [top3_languages],
        "Lang ID Time": [lang_id_time],
        "Transcribe Time": [transcribe_time],
        "Text": [transcription],
    })
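
# One call produces a single-row DataFrame; illustrative values only:
#   Length (s)=5.0, dB Level=72.3, Japanese_English="ja (0.94, 0.05)",
#   Language="Japanese (0.94), English (0.05), ...", Text="<transcription>"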


def process_audio_stream(audio, chunk_duration, language_set):
    global data_df, current_chunk
    print("Process_audio_stream")

    if audio is None:
        return None, data_df

    sr, audio_data = audio

    # Parse the comma-separated language set
    language_set = [lang.strip() for lang in language_set.split(",")]
    print(audio_data.shape, audio_data.dtype)
    # Align the sampling rate (and dtype) up front
    audio_data = resample_audio(audio_data, sr, target_sr=SAMPLING_RATE)

    current_chunk.append(audio_data)

    total_chunk = np.concatenate(current_chunk)

    # Process once the buffered audio reaches chunk_duration seconds;
    # cast to int because gr.Number may deliver the duration as a float
    chunk_samples = int(SAMPLING_RATE * chunk_duration)
    if len(total_chunk) >= chunk_samples:
        chunk = total_chunk[:chunk_samples]
        total_chunk = total_chunk[chunk_samples:]

        df = process_chunk(chunk, language_set)
        data_df = pd.concat([data_df, df], ignore_index=True)

        current_chunk = [total_chunk]
        return (SAMPLING_RATE, chunk), data_df
    else:
        return (SAMPLING_RATE, total_chunk), data_df
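
# Caveat: the buffering state (current_chunk, data_df) lives in module-level
# globals, so it is shared across every connected session; per-session state
# (e.g. gr.State) would be the usual fix for a multi-user deployment.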


def process_audio(audio, chunk_duration, language_set):
    global data, data_df, current_chunk
    # reset state
    data = []
    data_df = pd.DataFrame()
    current_chunk = []

    print("Process_audio")
    print(audio)
    if audio is None:
        return

    sr, audio_data = audio

    # Parse the comma-separated language set
    language_set = [lang.strip() for lang in language_set.split(",")]

    print(audio_data.shape, audio_data.dtype)
    # Align the sampling rate (and dtype) up front
    audio_data = resample_audio(audio_data, sr, target_sr=SAMPLING_RATE)

    # Log the overall level of the upload for debugging; per-chunk levels
    # and normalization happen inside process_chunk (normalizing the whole
    # signal here would flatten every chunk's dB level to ~0 and make the
    # silence check in process_chunk misfire)
    rms = np.sqrt(np.mean(audio_data**2))
    db_level = 20 * np.log10(rms + 1e-9)  # small epsilon avoids log(0) = -inf
    print(db_level)

    # Append the new data to the running buffer
    current_chunk.append(audio_data)
    total_chunk = np.concatenate(current_chunk)

    # Cast to int because gr.Number may deliver the duration as a float
    chunk_samples = int(SAMPLING_RATE * chunk_duration)
    while len(total_chunk) >= chunk_samples:
        chunk = total_chunk[:chunk_samples]
        total_chunk = total_chunk[chunk_samples:]  # drop the processed part

        print(f"Processing audio chunk of length {len(chunk)}")
        df = process_chunk(chunk, language_set)
        data_df = pd.concat([data_df, df], ignore_index=True)

        yield (SAMPLING_RATE, chunk), data_df

    # Keep the unprocessed remainder
    current_chunk = [total_chunk]
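
# Because process_audio is a generator, Gradio updates its outputs after each
# yield, so results appear chunk by chunk even for an uploaded file.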


# Component factories: each gr.Interface needs its own component instances,
# since rendering the same instance in two places raises a DuplicateBlockError
RESULT_HEADERS = ["Length (s)", "dB Level", "Japanese_English", "Language",
                  "Lang ID Time", "Transcribe Time", "Text"]


def make_inputs(audio_component):
    return [audio_component,
            gr.Number(value=5, label="Chunk Duration (seconds)"),
            gr.Textbox(value="Japanese,English", label="Language Set (comma-separated)")]


def make_outputs():
    return [gr.Audio(type="numpy"), gr.DataFrame(headers=RESULT_HEADERS)]

with gr.Blocks() as demo:
    with gr.TabItem("Upload"):
        gr.Interface(
            fn=process_audio,
            inputs=make_inputs(gr.Audio(sources=["upload"], type="numpy")),
            outputs=make_outputs(),
            live=False,
            title="File Audio Processing",
            description="Upload an audio file to see the processing results."
        )

    with gr.TabItem("Microphone"):
        gr.Interface(
            fn=process_audio_stream,
            inputs=make_inputs(gr.Audio(sources=["microphone"], type="numpy", streaming=True)),
            outputs=make_outputs(),
            live=True,
            title="Real-time Audio Processing",
            description="Speak into the microphone and see real-time audio processing results."
        )

if __name__ == "__main__":
    demo.launch()