# File-viewer metadata: author MonkeyDLLLLLLuffy; commit a33bb2e "Update app.py"
# (verified); size 5.01 kB. (viewer links: raw | history | blame)
import os
import re
import tempfile
from collections import Counter
from difflib import SequenceMatcher

import numpy as np
import streamlit as st
import torch
import torchaudio
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
# Device setup: use the GPU when torch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper checkpoint fine-tuned for Cantonese, and the language code used to
# build its forced decoder prompt below.
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"


@st.cache_resource(show_spinner=False)
def _load_asr_pipeline():
    """Build the Whisper speech-recognition pipeline once per server process.

    Streamlit re-executes this script on every widget interaction; caching the
    pipeline keeps the model weights from being reloaded on each rerun.
    ``show_spinner=False`` keeps the cached call from emitting any UI element
    before ``st.set_page_config`` runs in ``main()``.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        chunk_length_s=30,  # window length (seconds) for the pipeline's own long-audio chunking
        device=device,
        generate_kwargs={
            "no_repeat_ngram_size": 3,
            "repetition_penalty": 1.15,
            "temperature": 0.7,
            "top_p": 0.97,
            "top_k": 40,
            # Kept well under the 448-token decoder limit once prompt tokens are added.
            "max_new_tokens": 400,
            "do_sample": True,  # required for temperature / top_p / top_k to take effect
        },
    )
    # Force "transcribe in `language`" decoding instead of automatic language detection.
    # NOTE(review): newer transformers releases prefer passing language/task via
    # generate_kwargs; setting config.forced_decoder_ids may be warned about or ignored.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )
    return asr


pipe = _load_asr_pipeline()
# Similarity check to remove repeated phrases.
def is_similar(a, b, threshold=0.8):
    """Return True when strings *a* and *b* are near-duplicates.

    Similarity is ``difflib.SequenceMatcher(None, a, b).ratio()`` (0.0-1.0);
    a ratio strictly above *threshold* counts as similar.
    """
    return SequenceMatcher(None, a, b).ratio() > threshold


# Zero-width split point right after a sentence-ending mark:
# U+3002 (ideographic full stop), U+FF01 (full-width !), U+FF1F (full-width ?).
# Written as \u escapes so the pattern survives any re-encoding of this file.
_SENTENCE_END = re.compile(r"(?<=[\u3002\uff01\uff1f])")


def remove_repeated_phrases(text):
    """Collapse runs of consecutive near-duplicate sentences.

    The text is split after each full-width sentence-ending mark. Each stripped
    sentence is kept only if it is not similar (see ``is_similar``) to the
    previously kept sentence. Kept sentences are joined with single spaces.

    Args:
        text: Transcript text, still containing its sentence punctuation.

    Returns:
        The de-duplicated text; ``""`` for empty input.
    """
    cleaned_sentences = []
    for sentence in _SENTENCE_END.split(text):
        sentence = sentence.strip()
        if not sentence:
            # split() yields an empty tail when text ends with a mark; drop it
            # rather than emitting a trailing separator.
            continue
        if not cleaned_sentences or not is_similar(sentence, cleaned_sentences[-1]):
            cleaned_sentences.append(sentence)
    return " ".join(cleaned_sentences)
def remove_punctuation(text):
    r"""Drop every character that is neither a word character nor whitespace.

    ``\w`` is Unicode-aware for ``str`` patterns, so letters (including CJK),
    digits and underscores survive, as does whitespace. ASCII and full-width
    punctuation and other symbols are removed.
    """
    non_word_or_space = re.compile(r'[^\w\s]')
    return non_word_or_space.sub('', text)
def transcribe_audio(audio_path):
    """Transcribe an audio file to punctuation-free text with the Whisper pipeline.

    Audio longer than 60 s is cut into 55 s windows that start every 50 s
    (5 s overlap between neighbours), and each window is transcribed
    separately. Shorter audio goes to the pipeline in a single call.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The transcript, with consecutive near-duplicate sentences collapsed
        and all punctuation removed.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    # Down-mix multi-channel audio to mono by averaging the channels (dim 0).
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Drop the channel dim and hand the pipeline a 1-D NumPy array.
    waveform = waveform.squeeze(0).numpy()
    num_samples = waveform.shape[0]
    duration = num_samples / sample_rate

    if duration > 60:
        chunk_size = sample_rate * 55  # 55 s window
        step_size = sample_rate * 50   # advance 50 s per window -> 5 s overlap
        results = []
        for start in range(0, num_samples, step_size):
            chunk = waveform[start:start + chunk_size]
            transcript = pipe({"sampling_rate": sample_rate, "raw": chunk})["text"]
            # Keep punctuation for now: remove_repeated_phrases splits on
            # sentence-ending marks, so stripping it here would disable de-duplication.
            results.append(transcript)
            if start + chunk_size >= num_samples:
                # This window already reaches the end; another would only
                # re-transcribe the overlapping tail.
                break
        text = " ".join(results)
    else:
        text = pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]
    return remove_punctuation(remove_repeated_phrases(text))
# Sentiment analysis model.
@st.cache_resource(show_spinner=False)
def _load_sentiment_pipeline():
    """Build the text-classification pipeline once per server process.

    Streamlit re-executes this script on every widget interaction; caching the
    pipeline keeps the model from being reloaded on each rerun.
    ``show_spinner=False`` keeps the cached call from emitting any UI element
    before ``st.set_page_config`` runs in ``main()``.
    """
    return pipeline(
        "text-classification",
        model="MonkeyDLLLLLLuffy/CustomModel-multilingual-sentiment-analysis-enhanced",
        device=device,
    )


sentiment_pipe = _load_sentiment_pipeline()
# Rate sentiment with batch processing.
def rate_quality(text):
    """Rate customer-service quality from the sentiment of a transcript.

    The text is cut into 512-character pieces. The pieces are classified in
    batches of 4 by ``sentiment_pipe``, and each sentiment label is mapped to
    a quality rating (labels missing from the map become "Unknown").

    Args:
        text: Transcript text.

    Returns:
        The most frequent rating, with ties going to the rating seen first.
        Returns "Unknown" when *text* is empty or whitespace-only.
    """
    if not text.strip():
        # No text to classify; max() over no results would raise ValueError.
        return "Unknown"
    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
    # truncation=True: a 512-character piece may tokenize to more tokens than
    # the model accepts (plus special tokens); truncate instead of failing.
    results = sentiment_pipe(chunks, batch_size=4, truncation=True)
    label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
    ratings = [label_map.get(res["label"], "Unknown") for res in results]
    # Counter keeps first-seen order, so ties resolve deterministically
    # (max() over a set depended on per-process string hash randomization).
    return Counter(ratings).most_common(1)[0][0]
# Streamlit main interface.
def _render_header():
    """Inject the page CSS and draw the gradient title banner."""
    # Business-oriented CSS styling for the banner.
    st.markdown("""
    <style>
    .header {
    background: linear-gradient(90deg, #4B79A1, #283E51);
    border-radius: 10px;
    padding: 1.5rem;
    text-align: center;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    margin-bottom: 1.5rem;
    color: white;
    }
    </style>
    """, unsafe_allow_html=True)
    st.markdown("""
    <div class="header">
    <h1 style='margin:0;'>🎙️ Customer Service Quality Analyzer</h1>
    <p>Evaluate the service quality with simple uploading!</p>
    </div>
    """, unsafe_allow_html=True)


def main():
    """Render the app: upload audio, transcribe it, rate the service, offer a report."""
    st.set_page_config(page_title="Customer Service Analyzer", page_icon="🎙️")
    _render_header()

    uploaded_file = st.file_uploader("📤 Please upload your Cantonese customer service audio file", type=["wav", "mp3", "flac"])
    if uploaded_file is None:
        return

    # torchaudio.load needs a path, so write the upload to a unique temp file.
    # A fixed filename would let concurrent sessions overwrite each other's
    # audio. Keeping the real extension lets the decoder see the actual format.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_audio_path = tmp.name

    try:
        # Use the upload's own MIME type; mp3/flac are not audio/wav.
        st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")
        with st.spinner('🔄 Processing your audio, please wait...'):
            transcript = transcribe_audio(temp_audio_path)
            quality_rating = rate_quality(transcript)
    finally:
        # Remove the temp file even if transcription or rating raises.
        os.remove(temp_audio_path)

    st.write("**Transcript:**", transcript)
    st.write("**Sentiment Analysis Result:**", quality_rating)
    result_text = f"Transcript:\n{transcript}\n\nSentiment Analysis Result: {quality_rating}"
    st.download_button(label="📥 Download Analysis Report", data=result_text, file_name="analysis_report.txt")
    st.markdown("❓If you encounter any issues, please contact customer support: 📧 **example@hellotoby.com**")
# Entry point: `streamlit run` executes this file with __name__ == "__main__",
# so main() renders the UI; a plain import of this module does not call main().
if __name__ == "__main__":
    main()