"""Gradio demo: classify a user prompt as on-topic or off-topic for a system prompt."""

import gradio as gr

from utils import (
    device,
    jina_tokenizer,
    jina_model,
    embeddings_predict_relevance,
    stsb_model,
    stsb_tokenizer,
    ms_model,
    ms_tokenizer,
    cross_encoder_predict_relevance,
)

# Maps every selectable model name to the (predict function, model, tokenizer)
# triple used to score a prompt pair. This is the single source of truth for
# both the dropdown choices and the dispatch inside `predict`.
_MODEL_REGISTRY = {
    "jinaai/jina-embeddings-v2-small-en": (
        embeddings_predict_relevance,
        jina_model,
        jina_tokenizer,
    ),
    "cross-encoder/stsb-roberta-base": (
        cross_encoder_predict_relevance,
        stsb_model,
        stsb_tokenizer,
    ),
    "cross-encoder/ms-marco-MiniLM-L-6-v2": (
        cross_encoder_predict_relevance,
        ms_model,
        ms_tokenizer,
    ),
}


def predict(system_prompt, user_prompt, selected_model):
    """Classify ``user_prompt`` against ``system_prompt`` with the chosen model.

    Args:
        system_prompt: Text of the system prompt.
        user_prompt: Text of the user prompt to check.
        selected_model: Name of the model to use; must be one of the keys of
            ``_MODEL_REGISTRY`` (the dropdown choices).

    Returns:
        A Markdown string with the predicted label and the off-topic
        probability as a percentage.

    Raises:
        gr.Error: If no model, or an unknown model, is selected. Previously
            this case fell through the if/elif chain and crashed with an
            ``UnboundLocalError``.
    """
    entry = _MODEL_REGISTRY.get(selected_model)
    if entry is None:
        # The dropdown has no default value, so the button can be clicked
        # before any model has been chosen.
        raise gr.Error("Please select a model before checking the content.")

    predict_fn, model, tokenizer = entry
    predicted_label, probabilities = predict_fn(
        system_prompt, user_prompt, model, tokenizer, device
    )

    # NOTE(review): presumably `probabilities` is indexed [sample][class] with
    # class 1 being "off-topic" -- confirm against utils.
    probability_off_topic = probabilities[0][1] * 100
    label = "Off-topic" if predicted_label == 1 else "On-topic"

    return (
        "**Prediction Summary**:\n"
        "\n"
        f"- **Predicted Label**: {label}\n"
        f"- **Probability of Off-topic**: {probability_off_topic:.3f}%\n"
    )


with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:
    gr.Markdown(
        "# Off-Topic Classification using Fine-tuned Embeddings and Cross-Encoder Models"
    )

    with gr.Row():
        system_prompt = gr.Textbox(label="System Prompt")
        user_prompt = gr.Textbox(label="User Prompt")

    with gr.Row():
        selected_model = gr.Dropdown(
            list(_MODEL_REGISTRY),
            label="Select a model",
        )

    # Button to run the prediction
    get_classfication = gr.Button("Check Content")

    output_result = gr.Markdown(label="Classification and Probabilities")

    get_classfication.click(
        fn=predict,
        inputs=[system_prompt, user_prompt, selected_model],
        outputs=output_result,
    )


if __name__ == "__main__":
    app.launch()