Alexa-NLU-Clone / app.py
qanastek's picture
Update app.py
d66e935
raw
history blame
3.77 kB
import gradio as gr
import os
import torch
import librosa
from glob import glob
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, AutoModelForTokenClassification, TokenClassificationPipeline, Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
# ASR
# Speech-to-text stage: a wav2vec2 CTC checkpoint and its matching processor,
# both fetched by checkpoint name via `from_pretrained`.
# NOTE: `model_name` is rebound for every stage below, so each
# `from_pretrained` call must stay next to the assignment it depends on.
model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
processor_asr = Wav2Vec2Processor.from_pretrained(model_name)
model_asr = Wav2Vec2ForCTC.from_pretrained(model_name)
# Classifier Intent
# Sequence-classification model wrapped in a text-classification pipeline;
# `process` reads the "label" of its first result as the intent.
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer_intent = AutoTokenizer.from_pretrained(model_name)
model_intent = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier_intent = TextClassificationPipeline(model=model_intent, tokenizer=tokenizer_intent)
# Classifier Language
# Same pipeline shape as the intent classifier, with a different checkpoint;
# `process` reads the "label" of its first result as the language.
model_name = 'qanastek/51-languages-classifier'
tokenizer_langs = AutoTokenizer.from_pretrained(model_name)
model_langs = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier_language = TextClassificationPipeline(model=model_langs, tokenizer=tokenizer_langs)
# NER Extractor
# Token-classification pipeline; its per-token output is regrouped by
# `getUniform`, which reads the "entity" and "word" fields of each item.
model_name = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
tokenizer_ner = AutoTokenizer.from_pretrained(model_name)
model_ner = AutoModelForTokenClassification.from_pretrained(model_name)
predict_ner = TokenClassificationPipeline(model=model_ner, tokenizer=tokenizer_ner)
# Every `.wav` file directly inside EXAMPLE_DIR, sorted by path; the list is
# handed to the Gradio interface as its `examples`.
EXAMPLE_DIR = './'
examples = sorted(glob(os.path.join(EXAMPLE_DIR, '*.wav')))
def transcribe(audio_path):
    """Transcribe one audio file to text with the module-level ASR model.

    Args:
        audio_path: Path of the audio file, passed straight to
            ``librosa.load`` with ``sr=16_000``.

    Returns:
        The first (and only) decoded string produced by
        ``processor_asr.batch_decode`` for this recording.
    """
    waveform, _ = librosa.load(audio_path, sr=16_000)
    features = processor_asr(
        waveform,
        sampling_rate=16_000,
        return_tensors="pt",
        padding=True,
    )
    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        output = model_asr(
            features.input_values,
            attention_mask=features.attention_mask,
        )
        # Pick the highest-scoring vocabulary entry at every position.
        token_ids = torch.argmax(output.logits, dim=-1)
    transcripts = processor_asr.batch_decode(token_ids)
    return transcripts[0]
def getUniform(text):
    """Group per-token NER predictions into ``(label, words)`` entities.

    Each item of *text* is a mapping with an ``"entity"`` tag (``B-<label>``
    to open an entity, anything else to continue one) and a ``"word"`` piece.
    The ``▁`` marker is stripped from every word piece.

    The previous implementation incremented its group counter immediately
    after opening a ``B-`` group, so every continuation token looked up a
    key that had never been created and raised ``KeyError``. Groups are now
    kept in an ordered list and continuation tokens are appended to the most
    recently opened group.

    Args:
        text: Iterable of token dicts with ``"entity"`` and ``"word"`` keys.

    Returns:
        A list of ``(label, [word, ...])`` tuples in order of appearance.
        A continuation token that has no open group with the same label
        (e.g. a leading ``I-`` tag) starts a new group instead of failing.
    """
    groups = []
    for token in text:
        tag = token["entity"]
        piece = token["word"].replace("▁", "")
        opens_entity = tag.startswith("B-")
        # Drop only the leading B-/I- marker; the rest of the tag is the label.
        label = tag[2:] if tag.startswith(("B-", "I-")) else tag
        if opens_entity or not groups or groups[-1][0] != label:
            groups.append((label, [piece]))
        else:
            groups[-1][1].append(piece)
    return groups
def process(path):
    """Run the whole speech pipeline on one audio file.

    The recording is transcribed first; the transcript is then passed, in
    this order, to the intent classifier, the language classifier and the
    NER pipeline (whose output is regrouped by ``getUniform``).

    Args:
        path: Path of the audio file to analyse.

    Returns:
        A dict with the keys ``"text"``, ``"language"``, ``"intent_class"``
        and ``"named_entities"``.
    """
    transcript = transcribe(path)
    # Only the label of the first result from each classifier is kept.
    intent_label = classifier_intent(transcript)[0]["label"]
    language_label = classifier_language(transcript)[0]["label"]
    entities = getUniform(predict_ner(transcript))
    result = {}
    result["text"] = transcript
    result["language"] = language_label
    result["intent_class"] = intent_label
    result["named_entities"] = entities
    return result
# Absolute paths to sample recordings on one particular machine.
# NOTE(review): nothing else in the visible source reads this list (the
# interface takes its examples from `examples`) — presumably a leftover from
# local testing; confirm before relying on or removing it.
audio_paths = [
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/set-the-volume-to-low.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/tell-me-a-joke.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/tell me the artist of this song.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/order-a-pizza.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/TTS_1/tell-me-a-good-joke.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/TTS_1/order me a pizza.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/TTS_1/tell-me-the-artist-of-this-song.wav",
]
def predict(wav_file):
    """Gradio callback: return the transcript of the submitted recording.

    Args:
        wav_file: Path of the recorded/uploaded audio file, or ``None`` (or
            an empty string) when the form was submitted without audio.

    Returns:
        The ``"text"`` entry of ``process(wav_file)``, or an empty string
        when no audio was provided.
    """
    # Without this guard a submission with no audio reaches the audio loader
    # with a missing path and fails instead of returning a result.
    if not wav_file:
        return ""
    return process(wav_file)["text"]
# iface = gr.Interface(fn=predict, inputs="text", outputs="text")
# Web demo: one microphone audio input (delivered to `predict` as a file
# path) and one text box showing the value `predict` returns.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Audio(source='microphone', type='filepath', label='wav file'),
    ],
    outputs=[
        gr.outputs.Textbox(label='decoding result'),
    ],
    examples=examples,
    title='Alexa NLU Clone',
    description='Upload your wav file to test the model',
    article='Made with ❤️ by Yanis Labrak thanks to 🤗',
)
# Start serving the interface as soon as the module is executed.
iface.launch()