import os from pathlib import Path from huggingface_hub import hf_hub_download import gradio as gr from llama_index.core import Settings from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.llms.llama_cpp import LlamaCPP from .parse_tabular import create_symptom_index # Use relative import import json import psutil from typing import Tuple, Dict import torch from gtts import gTTS import io import base64 import numpy as np from transformers.pipelines import pipeline # Changed from transformers import pipeline from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor import torchaudio import torchaudio.transforms as T # Model options mapped to their requirements MODEL_OPTIONS = { "tiny": { "name": "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf", "repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", "vram_req": 2, # GB "ram_req": 4 # GB }, "small": { "name": "phi-2.Q4_K_M.gguf", "repo": "TheBloke/phi-2-GGUF", "vram_req": 4, "ram_req": 8 }, "medium": { "name": "mistral-7b-instruct-v0.1.Q4_K_M.gguf", "repo": "TheBloke/Mistral-7B-Instruct-v0.1-GGUF", "vram_req": 6, "ram_req": 16 } } # Initialize Whisper components globally (these are lightweight) feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en") tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en") processor = WhisperProcessor(feature_extractor, tokenizer) def get_asr_pipeline(): """Lazy load ASR pipeline with proper configuration.""" global transcriber if "transcriber" not in globals(): transcriber = pipeline( "automatic-speech-recognition", model="openai/whisper-base.en", chunk_length_s=30, stride_length_s=5, device="cpu", torch_dtype=torch.float32 ) return transcriber # Audio preprocessing function def process_audio(audio_array, sample_rate): """Pre-process audio for Whisper.""" if audio_array.ndim > 1: audio_array = audio_array.mean(axis=1) # Normalize audio audio_array = audio_array.astype(np.float32) audio_array /= np.max(np.abs(audio_array)) # Resample to 16kHz if needed if sample_rate != 16000: resampler = T.Resample(orig_freq=sample_rate, new_freq=16000) audio_tensor = torch.FloatTensor(audio_array) audio_tensor = resampler(audio_tensor) audio_array = audio_tensor.numpy() # Process with correct input format inputs = processor( audio_array, sampling_rate=16000, return_tensors="pt" ) return { "input_features": inputs.input_features, "attention_mask": inputs.attention_mask } # Update transcriber configuration transcriber = pipeline( "automatic-speech-recognition", model="openai/whisper-base.en", chunk_length_s=30, stride_length_s=5, device="cpu", torch_dtype=torch.float32, feature_extractor=feature_extractor, generate_kwargs={ "use_cache": True, "return_timestamps": True } ) def get_system_specs() -> Dict[str, float]: """Get system specifications.""" # Get RAM ram_gb = psutil.virtual_memory().total / (1024**3) # Get GPU info if available gpu_vram_gb = 0 if torch.cuda.is_available(): try: # Query GPU memory in bytes and convert to GB gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) except Exception as e: print(f"Warning: Could not get GPU memory: {e}") return { "ram_gb": ram_gb, "gpu_vram_gb": gpu_vram_gb } def select_best_model() -> Tuple[str, str]: """Select the best model based on system specifications.""" specs = get_system_specs() print(f"\nSystem specifications:") print(f"RAM: {specs['ram_gb']:.1f} GB") print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB") # Prioritize GPU if available if specs['gpu_vram_gb'] >= 4: # You have 6GB, so this should work model_tier = "small" # phi-2 should work well on RTX 2060 elif specs['ram_gb'] >= 8: model_tier = "small" else: model_tier = "tiny" selected = MODEL_OPTIONS[model_tier] print(f"\nSelected model tier: {model_tier}") print(f"Model: {selected['name']}") return selected['name'], selected['repo'] # Set up model paths MODEL_NAME, REPO_ID = select_best_model() BASE_DIR = os.path.dirname(os.path.dirname(__file__)) MODEL_DIR = os.path.join(BASE_DIR, "models") MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME) from typing import Optional def ensure_model(model_name: Optional[str] = None, repo_id: Optional[str] = None) -> str: """Ensures model is available, downloading only if needed.""" # Determine environment and set cache directory if os.path.exists("/home/user"): # HF Space environment cache_dir = "/home/user/.cache/models" else: # Local development environment cache_dir = os.path.join(BASE_DIR, "models") # Create cache directory if it doesn't exist try: os.makedirs(cache_dir, exist_ok=True) except Exception as e: print(f"Warning: Could not create cache directory {cache_dir}: {e}") # Fall back to temporary directory if needed cache_dir = os.path.join("/tmp", "models") os.makedirs(cache_dir, exist_ok=True) # Get model details if not model_name or not repo_id: model_option = MODEL_OPTIONS["small"] # default to small model model_name = model_option["name"] repo_id = model_option["repo"] # Ensure model_name and repo_id are not None if model_name is None: raise ValueError("model_name cannot be None") if repo_id is None: raise ValueError("repo_id cannot be None") # Check if model already exists in cache model_path = os.path.join(cache_dir, model_name) if os.path.exists(model_path): print(f"\nUsing cached model: {model_path}") return model_path print(f"\nDownloading model {model_name} from {repo_id}...") try: model_path = hf_hub_download( repo_id=repo_id, filename=model_name, cache_dir=cache_dir, local_dir=cache_dir ) print(f"Model downloaded successfully to {model_path}") return model_path except Exception as e: print(f"Error downloading model: {str(e)}") raise # Ensure model is downloaded model_path = ensure_model() # Configure local LLM with LlamaCPP print("\nInitializing LLM...") llm = LlamaCPP( model_path=model_path, temperature=0.7, max_new_tokens=256, context_window=2048, verbose=False # Reduce logging # n_batch and n_threads are not valid parameters for LlamaCPP and should not be used. # If you encounter segmentation faults, try reducing context_window or check your system resources. ) print("LLM initialized successfully") # Configure global settings print("\nConfiguring settings...") Settings.llm = llm Settings.embed_model = HuggingFaceEmbedding( model_name="sentence-transformers/all-MiniLM-L6-v2" ) print("Settings configured") # Create the index at startup print("\nCreating symptom index...") symptom_index = create_symptom_index() print("Index created successfully") # --- System prompt --- SYSTEM_PROMPT = """ You are a medical assistant helping a user narrow down to the most likely ICD-10 code. At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?") or, if you have enough info, output a final JSON with fields: {"diagnoses":[…], "confidences":[…]}. """ def process_speech(audio_data, history): """Process speech input and convert to text.""" try: if not audio_data: return [] if isinstance(audio_data, tuple) and len(audio_data) == 2: sample_rate, audio_array = audio_data # Audio preprocessing if audio_array.ndim > 1: audio_array = audio_array.mean(axis=1) audio_array = audio_array.astype(np.float32) audio_array /= np.max(np.abs(audio_array)) # Ensure correct sampling rate if sample_rate != 16000: resampler = T.Resample(sample_rate, 16000) audio_tensor = torch.FloatTensor(audio_array) audio_tensor = resampler(audio_tensor) audio_array = audio_tensor.numpy() sample_rate = 16000 # Transcribe with error handling # Format dictionary correctly with required keys input_features = { "raw": audio_array, "sampling_rate": sample_rate } result = transcriber(input_features) # Handle different result types if isinstance(result, dict) and "text" in result: transcript = result["text"].strip() elif isinstance(result, str): transcript = result.strip() else: print(f"Unexpected transcriber result type: {type(result)}") return [] if not transcript: print("No transcription generated") return [] # Query symptoms with transcribed text diagnosis_query = f""" Given these symptoms: '{transcript}' Identify the most likely ICD-10 diagnoses and key questions. Focus on clinical implications. """ response = symptom_index.as_query_engine().query(diagnosis_query) return [ {"role": "user", "content": transcript}, {"role": "assistant", "content": json.dumps({ "diagnoses": [], "confidences": [], "follow_up": str(response) })} ] else: print(f"Invalid audio format: {type(audio_data)}") return [] except Exception as e: print(f"Processing error: {str(e)}") return [] def update_transcription(audio_path): """Update transcription box with speech recognition results.""" if not audio_path: return "" # Extract transcription from audio result transcript = audio_path[1] if isinstance(audio_path, tuple) else audio_path return transcript # Build enhanced Gradio interface with gr.Blocks( theme="default", css=""" * { font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Ubuntu, 'Helvetica Neue', Arial, sans-serif; } code, pre { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace; } """ ) as demo: gr.Markdown(""" # 🏥 Medical Symptom to ICD-10 Code Assistant ## About This application is part of the Agents+MCP Hackathon. It helps medical professionals and patients understand potential diagnoses based on described symptoms. ### How it works: 1. Either click the record button and describe your symptoms or type them into the textbox 2. The AI will analyze your description and suggest possible diagnoses 3. Answer follow-up questions to refine the diagnosis """) with gr.Row(): with gr.Column(scale=2): # Add text input above microphone with gr.Row(): text_input = gr.Textbox( label="Type your symptoms", placeholder="Or type your symptoms here...", lines=3 ) submit_btn = gr.Button("Submit", variant="primary") # Existing microphone row with gr.Row(): microphone = gr.Audio( sources=["microphone"], streaming=True, type="numpy", label="Describe your symptoms" ) transcript_box = gr.Textbox( label="Transcribed Text", interactive=False, show_label=True ) clear_btn = gr.Button("Clear Chat", variant="secondary") chatbot = gr.Chatbot( label="Medical Consultation", height=500, container=True, type="messages" # This is now properly supported by our message format ) with gr.Column(scale=1): with gr.Accordion("Advanced Settings", open=False): api_key = gr.Textbox( label="OpenAI API Key (optional)", type="password", placeholder="sk-..." ) model_selector = gr.Dropdown( choices=list(MODEL_OPTIONS.keys()), label="Model Tier", value="small", interactive=True ) temperature = gr.Slider( minimum=0.1, maximum=1.0, value=0.7, label="Temperature" ) # Event handlers clear_btn.click(lambda: None, None, chatbot, queue=False) def format_response_for_user(response_dict): """Format the assistant's response dictionary into a user-friendly string.""" diagnoses = response_dict.get("diagnoses", []) confidences = response_dict.get("confidences", []) follow_up = response_dict.get("follow_up", "") result = "" if diagnoses: result += "Possible Diagnoses:\n" for i, diag in enumerate(diagnoses): conf = f" ({confidences[i]*100:.1f}%)" if i < len(confidences) else "" result += f"- {diag}{conf}\n" if follow_up: result += f"\nFollow-up: {follow_up}" return result.strip() def enhanced_process_speech(audio_path, history, api_key=None, model_tier="small", temp=0.7): """Handle streaming speech processing and chat updates.""" if not audio_path: return history try: if isinstance(audio_path, tuple) and len(audio_path) == 2: sample_rate, audio_array = audio_path # Audio preprocessing if audio_array.ndim > 1: audio_array = audio_array.mean(axis=1) audio_array = audio_array.astype(np.float32) audio_array /= np.max(np.abs(audio_array)) # Ensure correct sampling rate if sample_rate != 16000: resampler = T.Resample( orig_freq=sample_rate, new_freq=16000 ) audio_tensor = torch.FloatTensor(audio_array) audio_tensor = resampler(audio_tensor) audio_array = audio_tensor.numpy() sample_rate = 16000 # Format input dictionary exactly as required transcriber_input = { "raw": audio_array, "sampling_rate": sample_rate } # Get transcription from Whisper result = transcriber(transcriber_input) # Extract text from result transcript = "" if isinstance(result, dict): transcript = result.get("text", "").strip() elif isinstance(result, str): transcript = result.strip() if not transcript: return history # Process the symptoms diagnosis_query = f""" Based on these symptoms: '{transcript}' Provide relevant ICD-10 codes and diagnostic questions. """ response = symptom_index.as_query_engine().query(diagnosis_query) # Format and return chat messages return history + [ {"role": "user", "content": transcript}, {"role": "assistant", "content": format_response_for_user({ "diagnoses": [], "confidences": [], "follow_up": str(response) })} ] except Exception as e: print(f"Streaming error: {str(e)}") return history microphone.stream( fn=enhanced_process_speech, inputs=[microphone, chatbot, api_key, model_selector, temperature], outputs=chatbot, show_progress="hidden", api_name=False, queue=True # Enable queuing for better stream handling ) def process_audio(audio_array, sample_rate): """Pre-process audio for Whisper.""" if audio_array.ndim > 1: audio_array = audio_array.mean(axis=1) # Convert to tensor for resampling audio_tensor = torch.FloatTensor(audio_array) # Resample to 16kHz if needed if sample_rate != 16000: resampler = T.Resample(sample_rate, 16000) audio_tensor = resampler(audio_tensor) # Normalize audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor)) # Use feature extractor with correct sampling rate features = feature_extractor( audio_tensor.numpy(), sampling_rate=16000, # Always use 16kHz return_tensors="pt" ) return { "input_features": features.input_features, "sampling_rate": 16000 # Return resampled rate } # Update transcription handler def update_live_transcription(audio): """Real-time transcription updates.""" if not audio or not isinstance(audio, tuple): return "" sample_rate, audio_array = audio features = process_audio(audio_array, sample_rate) # Get pipeline and transcribe asr = get_asr_pipeline() result = asr(features) if isinstance(result, dict): return result.get("text", "").strip() elif isinstance(result, str): return result.strip() return "" microphone.stream( fn=update_live_transcription, inputs=[microphone], outputs=transcript_box, show_progress="hidden", queue=True ) clear_btn.click( fn=lambda: (None, "", ""), outputs=[chatbot, transcript_box, text_input], queue=False ) def cleanup_memory(): """Release unused memory (placeholder for future memory management).""" import gc gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def process_text_input(text, history): """Process text input with memory management.""" if not text: return history # Limit input length if len(text) > 500: text = text[:500] + "..." # Process the symptoms diagnosis_query = f""" Based on these symptoms: '{text}' Provide relevant ICD-10 codes and diagnostic questions. Focus on clinical implications. Limit response to 1000 characters. """ response = symptom_index.as_query_engine().query(diagnosis_query) # Clean up memory cleanup_memory() return history + [ {"role": "user", "content": text}, {"role": "assistant", "content": format_response_for_user({ "diagnoses": [], "confidences": [], "follow_up": str(response)[:1000] # Limit response length })} ] submit_btn.click( fn=process_text_input, inputs=[text_input, chatbot], outputs=chatbot, queue=True ) # Add footer with social links gr.Markdown(""" --- ### 👋 About the Creator Hi! I'm Graham Paasch, an experienced technology professional! 🎥 **Check out my YouTube channel** for more tech content: [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ) 💼 **Looking for a skilled developer?** I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/) ⭐ If you found this tool helpful, please consider: - Subscribing to my YouTube channel - Connecting on LinkedIn - Sharing this tool with others in healthcare tech """) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, mcp_server=True, allowed_paths=["*"] )