# OSINT_Tool / app.py
import streamlit as st
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)

# Define the model names and their corresponding Hugging Face models
MODEL_MAPPING = {
    "text2shellcommands": "t5-small",  # example seq2seq model for generating shell commands
    "pentest_ai": "bert-base-uncased",  # example classification model for pentesting tasks
}
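# NOTE: both checkpoints above are generic base models used as placeholders.
# In particular, loading "bert-base-uncased" through
# AutoModelForSequenceClassification attaches a freshly initialized (random)
# classification head, so its predictions are meaningless until the model is
# fine-tuned on a real pentesting dataset.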


# Function to create a sidebar for model selection
def select_model():
    """
    Adds a dropdown to the Streamlit sidebar for selecting a model.

    Returns:
        str: The selected model key from MODEL_MAPPING.
    """
    st.sidebar.header("Model Configuration")
    selected_model = st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))
    return selected_model


# Function to load the model and tokenizer with caching
@st.cache_resource
def load_model_and_tokenizer(model_name):
    """
    Loads the tokenizer and model for the specified Hugging Face model name.
    Uses caching to optimize performance.

    Args:
        model_name (str): The name of the Hugging Face model to load.

    Returns:
        tuple: A (tokenizer, model) pair, or (None, None) if loading fails.
    """
    try:
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Determine the correct model class to use
        if "t5" in model_name or "seq2seq" in model_name:
            # Load a sequence-to-sequence model
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        else:
            # Load a sequence classification model
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
        return tokenizer, model
    except Exception as e:
        # Display an error message in the Streamlit app
        st.error(f"An error occurred while loading the model or tokenizer: {e}")
        return None, None
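
# A minimal optional sketch (left commented out so the app stays CPU-only by
# default): move the loaded model to a GPU when one is available. If enabled,
# the tensors built by the tokenizer in predict_with_model would also need a
# matching .to(device) call, since the model and its inputs must share a device.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = model.to(device)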


# Function to handle predictions based on the selected model
def predict_with_model(user_input, model, tokenizer, model_choice):
    """
    Handles predictions using the loaded model and tokenizer.

    Args:
        user_input (str): Text input from the user.
        model: Loaded Hugging Face model.
        tokenizer: Loaded Hugging Face tokenizer.
        model_choice (str): Selected model key from MODEL_MAPPING.

    Returns:
        dict: A dictionary containing the prediction results.
    """
    # Both tasks tokenize the input the same way
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
    if model_choice == "text2shellcommands":
        # Generate shell commands (seq2seq task)
        with torch.no_grad():
            outputs = model.generate(**inputs)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}
    else:
        # Perform classification
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        return {
            "Predicted Class": predicted_class,
            "Logits": logits.tolist(),
        }
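
# For a fine-tuned classifier, the integer class index returned above can be
# mapped to a human-readable label via the model config, e.g.
#     label = model.config.id2label[predicted_class]
# The untuned bert-base-uncased placeholder only carries the generic
# LABEL_0/LABEL_1 names.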


# Function to process uploaded files
def process_uploaded_file(uploaded_file):
    """
    Reads and processes the uploaded file. Supports text and CSV files.

    Args:
        uploaded_file: The uploaded file.

    Returns:
        str: The content of the file as a string, or None on failure.
    """
    try:
        if uploaded_file is not None:
            file_type = uploaded_file.type
            # CSV file processing. Check CSV first: browsers report CSVs with
            # the MIME type "text/csv", so a bare `"text" in file_type` test
            # would otherwise capture them before this branch is reached.
            if "csv" in file_type or uploaded_file.name.endswith(".csv"):
                import pandas as pd  # imported lazily; only needed for CSV uploads

                df = pd.read_csv(uploaded_file)
                return df.to_string()  # convert the dataframe to a string
            # Text file processing
            elif "text" in file_type:
                return uploaded_file.read().decode("utf-8")
            else:
                st.error("Unsupported file type. Please upload a text or CSV file.")
                return None
    except Exception as e:
        st.error(f"Error processing file: {e}")
        return None
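
# Keep in mind that predict_with_model tokenizes with truncation=True, so only
# the first chunk of a long file (up to the model's maximum input length,
# 512 tokens for the BERT placeholder) actually reaches the model; the rest is
# silently dropped.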


# Main function to define the Streamlit app
def main():
    st.title("AI Model Inference Dashboard")
    st.markdown(
        """
        This dashboard allows you to interact with different AI models for inference tasks,
        such as generating shell commands or performing text classification.
        """
    )

    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)
    # Halt cleanly if loading failed; load_model_and_tokenizer has already
    # shown an error message in that case.
    if tokenizer is None or model is None:
        st.stop()

    # Input text area or file upload
    input_choice = st.radio("Choose Input Method", ("Text Input", "Upload File"))

    if input_choice == "Text Input":
        user_input = st.text_area("Enter your text input:", placeholder="Type your text here...")
        # Handle prediction after submit
        submit_button = st.button("Submit")
        if submit_button and user_input:
            st.write("### Prediction Results:")
            result = predict_with_model(user_input, model, tokenizer, model_choice)
            for key, value in result.items():
                st.write(f"**{key}:** {value}")
    elif input_choice == "Upload File":
        uploaded_file = st.file_uploader("Choose a text or CSV file", type=["txt", "csv"])
        # Handle prediction after submit
        submit_button = st.button("Submit")
        if submit_button and uploaded_file:
            file_content = process_uploaded_file(uploaded_file)
            if file_content:
                st.write("### File Content:")
                st.write(file_content)
                result = predict_with_model(file_content, model, tokenizer, model_choice)
                st.write("### Prediction Results:")
                for key, value in result.items():
                    st.write(f"**{key}:** {value}")
            else:
                st.info("No valid content found in the file.")


if __name__ == "__main__":
    main()
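
# To launch the dashboard locally (assuming streamlit, torch, transformers, and
# pandas are installed, e.g. `pip install streamlit torch transformers pandas`):
#
#     streamlit run app.py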