import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
import torch
import pandas as pd

# Map each task name to the Hugging Face checkpoint that backs it
MODEL_MAPPING = {
    "text2shellcommands": "t5-small",  # Example seq2seq model for generating shell commands
    "pentest_ai": "bert-base-uncased",  # Example classification model for pentesting tasks
}
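
# Both checkpoints above are generic placeholders rather than models trained
# for these tasks. To use purpose-trained models, point the mapping at your
# own checkpoints (hypothetical IDs shown):
#   MODEL_MAPPING["text2shellcommands"] = "your-org/shell-command-t5"
#   MODEL_MAPPING["pentest_ai"] = "your-org/pentest-classifier"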

def select_model():
    """
    Adds a dropdown to the Streamlit sidebar for selecting a model.

    Returns:
        str: The selected model key from MODEL_MAPPING.
    """
    st.sidebar.header("Model Configuration")
    return st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))

@st.cache_resource
def load_model_and_tokenizer(model_name):
    """
    Loads the tokenizer and model for the given Hugging Face model name.
    st.cache_resource keeps the loaded objects in memory across reruns, so
    each checkpoint is downloaded and initialized only once per session.

    Args:
        model_name (str): The name of the Hugging Face model to load.

    Returns:
        tuple: A (tokenizer, model) pair, or (None, None) on failure.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Pick the model class: seq2seq for T5-style checkpoints,
        # sequence classification for everything else
        if "t5" in model_name or "seq2seq" in model_name:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        else:
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
        return tokenizer, model
    except Exception as e:
        # Surface the failure in the app instead of crashing the script
        st.error(f"An error occurred while loading the model or tokenizer: {e}")
        return None, None
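
# A minimal sketch of the loader in isolation (assumes network access to the
# Hugging Face Hub on the first call; subsequent calls reuse the cache):
#   tokenizer, model = load_model_and_tokenizer("t5-small")
#   if model is None:  # (None, None) signals a failed download or load
#       ...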

def predict_with_model(user_input, model, tokenizer, model_choice):
    """
    Runs inference on the user input with the loaded model and tokenizer.

    Args:
        user_input (str): Text input from the user.
        model: Loaded Hugging Face model.
        tokenizer: Loaded Hugging Face tokenizer.
        model_choice (str): Selected model key from MODEL_MAPPING.

    Returns:
        dict: A dictionary containing the prediction results.
    """
    # Guard against an upstream load failure
    if model is None or tokenizer is None:
        return {"Error": "Model failed to load; see the message above."}

    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
    if model_choice == "text2shellcommands":
        # Generate shell commands (seq2seq task)
        with torch.no_grad():
            outputs = model.generate(**inputs)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}
    else:
        # Perform classification
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        return {
            "Predicted Class": predicted_class,
            "Logits": logits.tolist(),
        }
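
# For a single input, classification returns something like
#   {"Predicted Class": 1, "Logits": [[-0.12, 0.34]]}
# (values illustrative), while "text2shellcommands" returns one decoded
# string under "Generated Shell Command".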

def process_uploaded_file(uploaded_file):
    """
    Reads and processes an uploaded file. Supports text and CSV files.

    Args:
        uploaded_file: The file object returned by st.file_uploader.

    Returns:
        str: The file content as a string, or None if it cannot be read.
    """
    try:
        if uploaded_file is not None:
            file_type = uploaded_file.type
            # Check CSV first: its MIME type ("text/csv") also contains "text",
            # so the plain-text branch would otherwise swallow it
            if "csv" in file_type:
                df = pd.read_csv(uploaded_file)
                return df.to_string()  # Flatten the dataframe to plain text
            elif "text" in file_type:
                return uploaded_file.read().decode("utf-8")
            else:
                st.error("Unsupported file type. Please upload a text or CSV file.")
        return None
    except Exception as e:
        st.error(f"Error processing file: {e}")
        return None
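
# Note: uploaded_file.type is the browser-reported MIME type, which can vary
# (e.g. some Windows setups report CSVs as "application/vnd.ms-excel"); a
# fallback on the file extension in uploaded_file.name is omitted here.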

def main():
    st.title("AI Model Inference Dashboard")
    st.markdown(
        """
        This dashboard allows you to interact with different AI models for inference tasks,
        such as generating shell commands or performing text classification.
        """
    )

    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)

    # Input via text area or file upload
    input_choice = st.radio("Choose Input Method", ("Text Input", "Upload File"))

    if input_choice == "Text Input":
        user_input = st.text_area("Enter your text input:", placeholder="Type your text here...")
        if st.button("Submit") and user_input:
            st.write("### Prediction Results:")
            result = predict_with_model(user_input, model, tokenizer, model_choice)
            for key, value in result.items():
                st.write(f"**{key}:** {value}")
    elif input_choice == "Upload File":
        uploaded_file = st.file_uploader("Choose a text or CSV file", type=["txt", "csv"])
        if st.button("Submit") and uploaded_file:
            file_content = process_uploaded_file(uploaded_file)
            if file_content:
                st.write("### File Content:")
                st.write(file_content)
                result = predict_with_model(file_content, model, tokenizer, model_choice)
                st.write("### Prediction Results:")
                for key, value in result.items():
                    st.write(f"**{key}:** {value}")
            else:
                st.info("No valid content found in the file.")


if __name__ == "__main__":
    main()
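
# To run locally (assuming this file is saved as app.py):
#   pip install streamlit torch transformers pandas
#   streamlit run app.py
# On a Hugging Face Space with the Streamlit SDK, the same file is served
# automatically.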