import streamlit as st
import torch
from transformers import pipeline
import torchaudio
import os
import re
import numpy as np
from difflib import SequenceMatcher

# -----------------------------
# 1) Model loading and utility functions
# -----------------------------

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Whisper model for Cantonese ASR
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,  # Adjust chunk size for memory handling
    device=device,
    generate_kwargs={
        "no_repeat_ngram_size": 3,
        "repetition_penalty": 1.15,
        "temperature": 0.7,
        "top_p": 0.97,
        "top_k": 40,
        "max_new_tokens": 400,
        "do_sample": True
    }
)
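# Force the decoder to transcribe in Chinese ("zh") rather than translate or auto-detect the language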
asr_pipe.model.config.forced_decoder_ids = asr_pipe.tokenizer.get_decoder_prompt_ids(
    language=language, task="transcribe"
)

# Remove repeated sentences that are highly similar
def remove_repeated_phrases(text):
    def is_similar(a, b):
        return SequenceMatcher(None, a, b).ratio() > 0.9

    sentences = re.split(r'(?<=[。!?])', text)
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        # Keep the sentence only if it is not nearly identical to the previous one
        if not cleaned_sentences or not is_similar(sentence, cleaned_sentences[-1]):
            cleaned_sentences.append(sentence)
    return " ".join(cleaned_sentences)

# Remove punctuation from text
def remove_punctuation(text):
    return re.sub(r'[^\w\s]', '', text)

# Transcribe the audio using Whisper
def transcribe_audio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert multi-channel audio to mono if necessary
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    waveform = waveform.squeeze(0).numpy()
    duration = waveform.shape[0] / sample_rate

    # For audio longer than 60 seconds, process in overlapping chunks
    # (55-second windows advanced by 50 seconds, i.e. 5 seconds of overlap)
    if duration > 60:
        chunk_size = sample_rate * 55
        step_size = sample_rate * 50
        results = []
        for start in range(0, waveform.shape[0], step_size):
            chunk = waveform[start:start + chunk_size]
            if chunk.shape[0] == 0:
                break
            transcript = asr_pipe({"sampling_rate": sample_rate, "raw": chunk})["text"]
            results.append(transcript)
        # Deduplicate near-identical sentences first (the check relies on sentence-final
        # punctuation), then strip punctuation from the joined result
        return remove_punctuation(remove_repeated_phrases(" ".join(results)))
    else:
        transcript = asr_pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]
        return remove_punctuation(remove_repeated_phrases(transcript))

# Load sentiment analysis model
sentiment_pipe = pipeline(
    "text-classification",
    model="MonkeyDLLLLLLuffy/CustomModel-multilingual-sentiment-analysis-enhanced",
    device=device
)

# Perform sentiment analysis in 512-character chunks, truncating each to the model's max length
def rate_quality(text):
    if not text.strip():
        return "Unknown"
    chunks = [text[i:i+512] for i in range(0, len(text), 512)]
    results = sentiment_pipe(chunks, batch_size=4, truncation=True)

    label_map = {
        "Very Negative": "Very Poor",
        "Negative": "Poor",
        "Neutral": "Neutral",
        "Positive": "Good",
        "Very Positive": "Very Good"
    }
    processed_results = [label_map.get(res["label"], "Unknown") for res in results]

    # Use majority voting to determine the final sentiment
    return max(set(processed_results), key=processed_results.count)

# -----------------------------
# 2) Main Streamlit application
# -----------------------------
def main():
    st.set_page_config(page_title="Customer Service Analyzer", page_icon="🎙️")

    # Custom CSS styling
    st.markdown("""
    <style>
    .header {
        background: linear-gradient(90deg, #4B79A1, #283E51);
        border-radius: 10px;
        padding: 1.5rem;
        text-align: center;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 1.5rem;
        color: white;
    }
    </style>
    """, unsafe_allow_html=True)

    st.markdown("""
    <div class="header">
        <h1 style='margin:0;'>🎙️ Customer Service Quality Analyzer</h1>
        <p>Evaluate service quality with a simple upload!</p>
    </div>
    """, unsafe_allow_html=True)

    # Initialize session state to store results
    if "transcript" not in st.session_state:
        st.session_state["transcript"] = ""
    if "quality_rating" not in st.session_state:
        st.session_state["quality_rating"] = ""
    if "uploaded_filename" not in st.session_state:
        st.session_state["uploaded_filename"] = ""

    # File uploader
    uploaded_file = st.file_uploader(
        "📤 Please upload your Cantonese customer service audio file",
        type=["wav", "mp3", "flac"]
    )

    if uploaded_file is not None:
        # Display audio player
        st.audio(uploaded_file, format=uploaded_file.type)

        # Only run the model again if a new file is uploaded
        if st.session_state["uploaded_filename"] != uploaded_file.name:
            st.session_state["uploaded_filename"] = uploaded_file.name

            # Save the uploaded file to a temporary path, keeping its original
            # extension so torchaudio selects the right decoder
            temp_audio_path = f"uploaded_audio{os.path.splitext(uploaded_file.name)[1]}"
            with open(temp_audio_path, "wb") as f:
                f.write(uploaded_file.getbuffer())

            # Process the audio
            with st.spinner('🔄 Processing your audio, please wait...'):
                transcript = transcribe_audio(temp_audio_path)
                quality_rating = rate_quality(transcript)

            # Store results in session state
            st.session_state["transcript"] = transcript
            st.session_state["quality_rating"] = quality_rating

            # Remove the temporary file
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)

    # Display results if available
    if st.session_state["transcript"]:
        st.write("**Transcript:**", st.session_state["transcript"])
        st.write("**Sentiment Analysis Result:**", st.session_state["quality_rating"])

        # Prepare download content
        result_text = (
            f"Transcript:\n{st.session_state['transcript']}\n\n"
            f"Sentiment Analysis Result: {st.session_state['quality_rating']}"
        )
        # Download button for the analysis report
        st.download_button(
            label="📥 Download Analysis Report",
            data=result_text,
            file_name="analysis_report.txt"
        )

    st.markdown(
        "❓If you encounter any issues, please contact customer support: "
        "📧 **example@hellotoby.com**"
    )

if __name__ == "__main__":
    main()