Spaces:
Running
Running
import os
from typing import List, Tuple, Dict, TypedDict, Optional, Any

import boto3
import gradio as gr
from langchain_aws import ChatBedrock
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import LLM
from langchain_openai.chat_models import ChatOpenAI

from ask_candid.base.config.rest import OPENAI
from ask_candid.base.config.models import Name2Endpoint
from ask_candid.base.config.data import ALL_INDICES
from ask_candid.utils import format_chat_ag_response
from ask_candid.chat import run_chat

try:
    from feedback import FeedbackApi
except ImportError:
    from demos.feedback import FeedbackApi
# Absolute directory containing this module; static assets (logo, CSS) are resolved relative to it.
ROOT = os.path.split(os.path.abspath(__file__))[0]
class LoggedComponents(TypedDict, total=False):
    """Gradio components whose values are submitted together as one feedback record.

    The mapping is filled in two stages: ``build_rag_chat`` sets ``context``
    and ``build_feedback`` adds the form fields. That is why the keys are not
    all required (``total=False``).
    """

    # The chatbot component whose conversation is attached to the feedback.
    # A single component (not a list): it is passed as one entry of the
    # submit button's ``inputs`` list.
    context: gr.components.Component
    # Radio: "Did you find what you were looking for?"
    found_helpful: gr.components.Component
    # Radio: "Will you recommend this Chatbot to others?"
    will_recommend: gr.components.Component
    # Free-text comments box.
    comments: gr.components.Component
    # Optional contact email box.
    email: gr.components.Component
def send_feedback(
    chat_context,
    found_helpful,
    will_recommend,
    comments,
    email=None,
):
    """Submit one feedback record to the feedback API.

    This is the click handler of the feedback form. The form passes five
    inputs: chat context, the two yes/no answers, comments and email.

    Args:
        chat_context: Value of the chatbot component (the conversation being rated).
        found_helpful: Answer to "Did you find what you were looking for?".
        will_recommend: Answer to "Will you recommend this Chatbot to others?".
        comments: Free-text comments (may be empty).
        email: Optional contact email. Defaults to ``None`` so four-argument
            callers keep working.

    Returns:
        The ``"response"`` field of the API reply, or 0 if the reply has no
        such field. Presumably this is the running submission count; confirm
        against ``FeedbackApi``.

    Raises:
        gr.Error: If the API call or reading its reply fails. The original
            exception is chained as the cause.
    """
    api = FeedbackApi()
    try:
        response = api(
            context=chat_context,
            found_helpful=found_helpful,
            will_recommend=will_recommend,
            comments=comments,
            email=email,
        )
        total_submissions = response.get("response", 0)
    except Exception as ex:
        # Surface the failure in the UI, but keep the traceback for debugging.
        raise gr.Error(f"Error submitting feedback: {ex}") from ex
    gr.Info("Thank you for submitting feedback")
    return total_submissions
# UI model names served through AWS Bedrock (all other names except "gpt-4o" are rejected).
_BEDROCK_MODELS = frozenset(
    {"claude-3.5-haiku", "llama-3.1-70b-instruct", "mistral-large", "mixtral-8x7B"}
)


def select_foundation_model(model_name: str, max_new_tokens: int) -> BaseChatModel:
    """Instantiate the chat model selected in the UI.

    Args:
        model_name: UI name of the model. It is looked up in ``Name2Endpoint``
            to get the provider-side model id.
        max_new_tokens: Upper bound on generated tokens, forwarded as ``max_tokens``.

    Returns:
        A streaming ``ChatOpenAI`` for ``"gpt-4o"``, or a ``ChatBedrock``
        (us-east-1) for the Bedrock-hosted models. Both use temperature 0.
        Both are chat models (``BaseChatModel``), not ``LLM`` subclasses.

    Raises:
        gr.Error: If ``model_name`` is not a supported model.
    """
    if model_name == "gpt-4o":
        return ChatOpenAI(
            model_name=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            api_key=OPENAI["key"],
            temperature=0.0,
            streaming=True,
        )
    if model_name in _BEDROCK_MODELS:
        return ChatBedrock(
            client=boto3.client("bedrock-runtime", region_name="us-east-1"),
            model=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            temperature=0.0,
        )
    raise gr.Error(f"Base model `{model_name}` is not supported")
def execute(
    thread_id: str,
    user_input: Dict[str, Any],
    history: List[Dict],
    model_name: str,
    max_new_tokens: int,
    indices: Optional[List[str]] = None,
):
    """Answer one chat turn using the model chosen in the UI.

    Args:
        thread_id: Current conversation thread id (empty string for a new chat).
        user_input: Value of the multimodal message box.
        history: Current chatbot messages.
        model_name: UI model name, resolved by ``select_foundation_model``.
        max_new_tokens: Generation budget for the model.
        indices: Search indices to draw sources from, or ``None``.

    Returns:
        Whatever ``run_chat`` returns. The submit handler maps it onto the
        message box, the chatbot and the thread-id field.
    """
    chosen_llm = select_foundation_model(model_name=model_name, max_new_tokens=max_new_tokens)
    return run_chat(
        thread_id=thread_id,
        user_input=user_input,
        history=history,
        llm=chosen_llm,
        indices=indices,
    )
def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
    """Build the chat tab of the app.

    The tab contains:

    - an "Advanced settings" accordion (source indices, language model, token budget);
    - the chat window, a multimodal message box and a clear button.

    Submitting a message runs ``execute`` and then post-processes the chatbot
    contents with ``format_chat_ag_response``.

    Returns:
        ``(logged, demo)``, where ``demo`` is the chat ``gr.Blocks`` and
        ``logged`` holds only ``context`` (the chatbot). The feedback tab
        adds the remaining keys later.
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="Candid's AI assistant") as demo:
        gr.Markdown(
            """
            <h1>Candid's AI assistant</h1>
            <p>
            Please read the <a
            href='https://info.candid.org/chatbot-reference-guide'
            target="_blank"
            rel="noopener noreferrer"
            >guide</a> to get started.
            </p>
            <hr>
            """
        )

        with gr.Accordion(label="Advanced settings", open=False):
            es_indices = gr.CheckboxGroup(
                choices=list(ALL_INDICES),
                value=list(ALL_INDICES),
                label="Sources to include",
                interactive=True,
            )
            llmname = gr.Radio(
                label="Language model",
                value="claude-3.5-haiku",
                choices=list(Name2Endpoint.keys()),
                interactive=True,
            )
            max_new_tokens = gr.Slider(
                value=256 * 3,
                minimum=128,
                maximum=2048,
                step=128,
                label="Max new tokens",
                interactive=True,
            )

        with gr.Column():
            chatbot = gr.Chatbot(
                label="AskCandid",
                elem_id="chatbot",
                avatar_images=(
                    None,
                    os.path.join(ROOT, "static", "candid_logo_yellow.png"),
                ),
                height="45vh",
                type="messages",
                show_label=False,
                show_copy_button=True,
                show_share_button=None,
                show_copy_all_button=False,
                autoscroll=True,
                layout="panel",
            )
            msg = gr.MultimodalTextbox(label="Your message", interactive=True)
            # Hidden field. It is both an input and an output of the submit
            # handler, so the thread id round-trips between turns.
            thread_id = gr.Text(visible=False, value="", label="thread_id")
            gr.ClearButton(components=[msg, chatbot, thread_id], size="sm")

        # pylint: disable=no-member
        chat_msg = msg.submit(
            fn=execute,
            inputs=[thread_id, msg, chatbot, llmname, max_new_tokens, es_indices],
            outputs=[msg, chatbot, thread_id],
        )
        chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")

    logged = LoggedComponents(context=chatbot)
    return logged, demo
def build_feedback(components: LoggedComponents) -> gr.Blocks:
    """Build the feedback tab and wire its submit button to ``send_feedback``.

    The new form components are stored into ``components`` (mutated in
    place). The click handler then reads them in a fixed order together with
    the chat ``context`` that the caller already registered.

    Args:
        components: Logged components. ``context`` must already be set.

    Returns:
        The feedback ``gr.Blocks``.
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="Candid AI demo") as feedback_tab:
        gr.Markdown("<h1>Help us improve this tool with your valuable feedback</h1>")
        with gr.Row():
            with gr.Column():
                helpful_radio = gr.Radio(
                    [True, False], label="Did you find what you were looking for?"
                )
                recommend_radio = gr.Radio(
                    [True, False],
                    label="Will you recommend this Chatbot to others?",
                )
                comments_box = gr.Textbox(label="Additional comments (optional)", lines=4)
                email_box = gr.Textbox(label="Your email (optional)", lines=1)
                submit_button = gr.Button("Submit Feedback")

        components.update(
            found_helpful=helpful_radio,
            will_recommend=recommend_radio,
            comments=comments_box,
            email=email_box,
        )
        # Order must match send_feedback's positional parameters.
        handler_keys = ("context", "found_helpful", "will_recommend", "comments", "email")

        # pylint: disable=no-member
        submit_button.click(
            fn=send_feedback,
            inputs=[components[key] for key in handler_keys],
            outputs=None,
            show_api=False,
            api_name=False,
            preprocess=False,
        )
    return feedback_tab
def build_app():
    """Assemble the chat and feedback tabs into the top-level tabbed Gradio app.

    The chat stylesheet is loaded from ``static/chatStyle.css`` next to this module.
    """
    logged_components, chat_tab = build_rag_chat()
    feedback_tab = build_feedback(logged_components)

    css_path = os.path.join(ROOT, "static", "chatStyle.css")
    with open(css_path, "r", encoding="utf8") as css_file:
        chat_css = css_file.read()

    # Tab label -> tab contents, in display order.
    tabs = {
        "Candid's AI assistant": chat_tab,
        "Feedback": feedback_tab,
    }
    return gr.TabbedInterface(
        interface_list=list(tabs.values()),
        tab_names=list(tabs.keys()),
        title="Candid's AI assistant",
        theme=gr.themes.Soft(),
        css=chat_css,
    )
if __name__ == "__main__": | |
app = build_app() | |
app.queue(max_size=5).launch( | |
show_api=False, | |
auth=[ | |
(os.getenv("APP_USERNAME"), os.getenv("APP_PASSWORD")), | |
(os.getenv("APP_PUBLIC_USERNAME"), os.getenv("APP_PUBLIC_PASSWORD")), | |
], | |
auth_message="Login to Candid's AI assistant", | |
ssr_mode=False | |
) | |