Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Load the pre-trained moderation classifier and its tokenizer once at import
# time, so every call to the classifier reuses the same objects.
model_name = "KoalaAI/Text-Moderation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Label names ordered by class index, so that labels[i] names the i-th logit.
# Sorting the (index, name) pairs explicitly avoids depending on the order in
# which id2label's keys happen to be stored; classify_text pairs these names
# with logits purely by position.
labels = [name for _, name in sorted(model.config.id2label.items())]
def classify_text(text):
    """Score ``text`` against every label in the module-level ``labels`` list.

    Args:
        text: String to classify. Anything beyond 512 tokens is truncated
            by the tokenizer.

    Returns:
        dict mapping each label name to its softmax probability as a Python
        float. The i-th entry of ``labels`` is paired with the i-th logit.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # A single text is passed in, so row 0 holds its per-class scores.
    first_row = probabilities[0]
    return {name: float(first_row[idx]) for idx, name in enumerate(labels)}
# Build the Gradio UI: a green/emerald Soft theme, a multi-line text box as
# input, and a label widget listing every class probability as output.
custom_theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.green,
    secondary_hue=gr.themes.colors.emerald,
)

_text_input = gr.Textbox(placeholder="Enter text to classify...", lines=5)
# num_top_classes=len(labels) makes the widget show every class, not a subset.
_score_output = gr.Label(num_top_classes=len(labels))

demo = gr.Interface(
    fn=classify_text,
    inputs=_text_input,
    outputs=_score_output,
    title="KoalaAI - Text-Moderation Demo",
    description="This model determines whether or not there is potentially harmful content in a given text",
    theme=custom_theme,
)
def _launch_app():
    """Start the Gradio server for the module-level ``demo`` interface."""
    demo.launch()


# Launch only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    _launch_app()