"""Gradio front end for the Ask Candid chat assistant.

Builds the chat UI, maps the model name selected in the UI to a LangChain chat
model client, and launches the app behind username/password authentication
when the module is run as a script.
"""

from typing import List, Tuple, Dict, TypedDict, Optional, Any
import os

import gradio as gr
from langchain_core.language_models.llms import LLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai.chat_models import ChatOpenAI
from langchain_aws import ChatBedrock
import boto3

from ask_candid.base.config.rest import OPENAI
from ask_candid.base.config.models import Name2Endpoint
from ask_candid.base.config.data import ALL_INDICES
from ask_candid.utils import format_chat_ag_response
from ask_candid.chat import run_chat

# Directory containing this file; static assets are resolved relative to it.
ROOT = os.path.dirname(os.path.abspath(__file__))
BUCKET = "candid-data-science-reporting"
PREFIX = "Assistant"


class LoggedComponents(TypedDict):
    """Gradio components grouped under the keys declared below."""

    context: List[gr.components.Component]
    found_helpful: gr.components.Component
    will_recommend: gr.components.Component
    comments: gr.components.Component
    email: gr.components.Component


def select_foundation_model(model_name: str, max_new_tokens: int) -> BaseChatModel:
    """Return a chat model client for the requested base model.

    Args:
        model_name: Key into ``Name2Endpoint`` identifying the model.
        max_new_tokens: Passed to the client as ``max_tokens``.

    Returns:
        A ``ChatOpenAI`` client (streaming enabled) for ``"gpt-4o"``, or a
        ``ChatBedrock`` client for the supported Bedrock-hosted models. Both
        are configured with ``temperature=0.0``.

    Raises:
        gr.Error: If ``model_name`` is not one of the supported names.
    """
    if model_name == "gpt-4o":
        llm = ChatOpenAI(
            model_name=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            api_key=OPENAI["key"],
            temperature=0.0,
            streaming=True,
        )
    elif model_name in {"claude-3.5-haiku", "llama-3.1-70b-instruct", "mistral-large", "mixtral-8x7B"}:
        llm = ChatBedrock(
            client=boto3.client("bedrock-runtime"),
            model=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            temperature=0.0
        )
    else:
        # gr.Error is used so the failure is raised through Gradio's own error type.
        raise gr.Error(f"Base model `{model_name}` is not supported")
    return llm


def execute(
    thread_id: str,
    user_input: Dict[str, Any],
    history: List[Dict],
    model_name: str,
    max_new_tokens: int,
    indices: Optional[List[str]] = None,
):
    """Run one chat turn by delegating to ``run_chat``.

    Args:
        thread_id: Value of the hidden ``thread_id`` text component.
        user_input: Value of the multimodal message box.
        history: Current chatbot message history.
        model_name: Selected base model; see ``select_foundation_model``.
        max_new_tokens: Token limit forwarded to the selected model.
        indices: Optional list of sources to include, forwarded unchanged.

    Returns:
        Whatever ``run_chat`` returns. The UI wires this to three outputs
        (message box, chatbot, thread id), so a matching 3-item result is
        presumably expected -- confirm against ``run_chat``.

    Raises:
        gr.Error: If ``model_name`` is not supported.
    """
    return run_chat(
        thread_id=thread_id,
        user_input=user_input,
        history=history,
        llm=select_foundation_model(model_name=model_name, max_new_tokens=max_new_tokens),
        indices=indices
    )


def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
    """Build the chat interface.

    Returns:
        A tuple ``(logged, demo)`` where ``logged`` holds the thread id and
        chatbot components under ``context`` and ``demo`` is the assembled
        ``gr.Blocks`` application.
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="Chat") as demo:
        # The literal is kept flush-left so the rendered text is unchanged.
        gr.Markdown(
            """

Ask Candid

Please read the guide to get started.


"""
        )

        with gr.Accordion(label="Advanced settings", open=False):
            es_indices = gr.CheckboxGroup(
                choices=list(ALL_INDICES),
                value=list(ALL_INDICES),
                label="Sources to include",
                interactive=True,
            )
            llmname = gr.Radio(
                label="Language model",
                value="gpt-4o",
                choices=list(Name2Endpoint.keys()),
                interactive=True,
            )
            max_new_tokens = gr.Slider(
                value=256 * 3,
                minimum=128,
                maximum=2048,
                step=128,
                label="Max new tokens",
                interactive=True,
            )

        with gr.Column():
            chatbot = gr.Chatbot(
                label="AskCandid",
                elem_id="chatbot",
                bubble_full_width=True,
                avatar_images=(
                    None,
                    os.path.join(ROOT, "static", "candid_logo_yellow.png"),
                ),
                height="45vh",
                type="messages",
                show_label=False,
                show_copy_button=True,
                show_share_button=True,
                show_copy_all_button=True,
            )
            msg = gr.MultimodalTextbox(label="Your message", interactive=True)
            # Hidden component that carries the thread id between turns.
            thread_id = gr.Text(visible=False, value="", label="thread_id")
            gr.ClearButton(components=[msg, chatbot, thread_id], size="sm")

        # pylint: disable=no-member
        chat_msg = msg.submit(
            fn=execute,
            inputs=[thread_id, msg, chatbot, llmname, max_new_tokens, es_indices],
            outputs=[msg, chatbot, thread_id],
        )
        # After each turn, post-process the chatbot history in place.
        chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")

        # Only `context` is populated here; the other keys declared on
        # LoggedComponents are not set by this function.
        logged = LoggedComponents(context=[thread_id, chatbot])
    return logged, demo


def build_app() -> gr.TabbedInterface:
    """Wrap the chat interface in a single-tab app styled with ``chatStyle.css``.

    Returns:
        A ``gr.TabbedInterface`` containing the chat UI under the
        ``"AskCandid"`` tab.
    """
    _, candid_chat = build_rag_chat()

    with open(os.path.join(ROOT, "static", "chatStyle.css"), "r", encoding="utf8") as f:
        css_chat = f.read()

    demo = gr.TabbedInterface(
        interface_list=[
            candid_chat,
        ],
        tab_names=[
            "AskCandid",
        ],
        theme=gr.themes.Soft(),
        css=css_chat,
    )
    return demo


if __name__ == "__main__":
    app = build_app()
    # Credentials are read from the environment; two accounts are accepted.
    app.queue(max_size=5).launch(
        show_api=False,
        auth=[
            (os.getenv("APP_USERNAME"), os.getenv("APP_PASSWORD")),
            (os.getenv("APP_PUBLIC_USERNAME"), os.getenv("APP_PUBLIC_PASSWORD")),
        ],
        auth_message="Login to Candid's AI assistant",
        ssr_mode=False
    )