import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from functools import lru_cache

# Pre-selected small models
MODELS = {
    "SmolLM2-135M-Instruct": "HuggingFaceTB/SmolLM2-135M-Instruct",
    "GPT-2 (Small)": "gpt2",
    "DistilGPT-2": "distilgpt2",
    "Facebook OPT-125M": "facebook/opt-125m",
}


# Cache the most recently used model to avoid reloading it on every message.
# maxsize=1 keeps memory use low; raise it to keep several models resident.
@lru_cache(maxsize=1)
def load_model_cached(model_name):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception as e:
        # Return the error as a string; the caller checks for this.
        return f"Error loading model: {e}"


# Generate a response from the selected model.
# The `progress=gr.Progress()` default lets Gradio inject a progress tracker
# when this function runs as an event handler (gr.Progress is not a context
# manager, so it cannot be used in a `with` block).
def chat(selected_model, user_input, chat_history, system_prompt="", progress=gr.Progress()):
    chat_history = chat_history or []

    if not selected_model:
        return "Please select a model from the dropdown.", chat_history

    # Map the dropdown label to the Hugging Face model ID
    model_name = MODELS.get(selected_model)
    if not model_name:
        return "Invalid model selected.", chat_history

    # Load the model (cached)
    generator = load_model_cached(model_name)
    if isinstance(generator, str):  # loading failed; propagate the error string
        return generator, chat_history

    # Prepend the optional system prompt
    full_input = f"{system_prompt}\n\n{user_input}" if system_prompt else user_input

    try:
        # Budget the context window: reserve room for the reply, then truncate
        # the prompt so prompt + reply never exceed the model's limit.
        max_context_length = getattr(generator.model.config, "max_position_embeddings", 1024)
        max_new_tokens = min(200, max_context_length // 2)
        prompt_budget = max_context_length - max_new_tokens

        # Truncate the input if it's too long
        inputs = generator.tokenizer(
            full_input,
            return_tensors="pt",
            max_length=prompt_budget,
            truncation=True,
        )
        # Text-generation pipelines expect text, not tensors, so decode the
        # (possibly truncated) prompt back to a string.
        truncated_input = generator.tokenizer.decode(
            inputs["input_ids"][0], skip_special_tokens=True
        )

        # Generate the response, reporting progress to the UI
        progress(0.5, desc="Generating response...")
        response = generator(
            truncated_input,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            top_p=0.95,
            top_k=60,
            return_full_text=False,  # don't echo the prompt back into the chat
        )[0]["generated_text"]

        # Append the interaction to the chat history and clear the textbox
        chat_history.append((user_input, response))
        return "", chat_history
    except Exception as e:
        return f"Error generating response: {e}", chat_history


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Small Language Models")
    with gr.Row():
        selected_model = gr.Dropdown(
            label="Select a Model",
            choices=list(MODELS.keys()),
            value="SmolLM2-135M-Instruct",  # default model
        )
    chatbot = gr.Chatbot(label="Chat")
    user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    system_prompt = gr.Textbox(
        label="System Prompt (Optional)",
        placeholder="e.g., You are a helpful AI assistant.",
        lines=2,
    )
    clear_button = gr.Button("Clear Chat")

    # Wire up the events: submitting the textbox runs chat(); error strings
    # land in the textbox, successful replies go to the chatbot.
    user_input.submit(
        chat,
        [selected_model, user_input, chatbot, system_prompt],
        [user_input, chatbot],
    )
    clear_button.click(lambda: [], None, chatbot, queue=False)

# Launch the app
demo.launch()
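
# ---------------------------------------------------------------------------
# Optional smoke test: a minimal sketch (not part of the app) for checking
# that a model loads and generates before relying on the UI. It reuses
# load_model_cached() above; the prompt string is an arbitrary example.
# Uncomment and run it in place of demo.launch().
#
# gen = load_model_cached("HuggingFaceTB/SmolLM2-135M-Instruct")
# if isinstance(gen, str):
#     print(gen)  # the error string returned by a failed load
# else:
#     out = gen("The capital of France is", max_new_tokens=20, do_sample=False)
#     print(out[0]["generated_text"])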