from chromadb.utils import embedding_functions
import chromadb
from openai import OpenAI
import gradio as gr
import json
import time
import random
import re

# Markdown help text rendered at the top of the demo page (via gr.Markdown in main()).
markdown_content = """
## PoliticalLLM

This application showcases how LLMs respond to statements from two ideology tests, Wahl-O-Mat and Political Compass Test. Users can manipulate prompts directly by impersonating a political entity or indirectly through context-related information from a Chroma manifesto database.
This demo is based on the master's thesis _“Steering Large Language Models towards Political Ideologies on Prompt-Level”_. Full framework is available [here](https://github.com/j0st/PoliticalLLM).

### How to Use:
1. **Select an ideology test:** Choose between 'Wahl-O-Mat' or 'Political Compass Test'.
2. **Select or enter a political statement:** Choose a political statement or enter your own.
3. **Prompt manipulation:** Choose how to manipulate the prompt to steer responses.
4. **Select models:** Choose up to two models to generate responses.
5. **Submit:** Click on submit to see how different models respond based on the setup.

**Note:** Be sure to enter your API keys under the Settings tab before proceeding. API keys are not exposed anywhere. However, you can duplicate this space and run it privately.
"""


# Page stylesheet: centres the <h1> title and styles the element with
# elem_id="duplicate-button". The empty first/last entries reproduce the
# leading and trailing newline of the stylesheet text.
# NOTE(review): Gradio only applies this if it is passed as gr.Blocks(css=...).
css = "\n".join([
    "",
    "h1 {",
    "  text-align: center;",
    "  display: block;",
    "}",
    "#duplicate-button {",
    "  margin: auto;",
    "  color: white;",
    "  background: #1565c0;",
    "  border-radius: 100vh;",
    "}",
    "",
])


# OpenAI-compatible endpoint used for every non-OpenAI model (see predict()).
togetherai_base_url = "https://api.together.xyz/v1"

# Model identifiers offered in both model dropdowns. Only "gpt-3.5-turbo-0125"
# is routed to OpenAI; all others are sent to Together.ai.
supported_models = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mixtral-8x22B-Instruct-v0.1",
    "Qwen/Qwen1.5-1.8B-Chat",
    "Qwen/Qwen1.5-14B-Chat",
    "Qwen/Qwen1.5-7B-Chat",
    "gpt-3.5-turbo-0125",
]

# Sentence-transformer embedding function attached to the Chroma manifesto
# collection; constructed once at import time.
multilingual_embeddings = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="jost/multilingual-e5-base-politics-de"
)

# German answer-option instruction per ideology test. The "None" entry maps to
# None, which would be rendered as the text "None" if formatted into a prompt.
test_format = {
    "None": None,
    "Wahl-O-Mat": "Beantworte das folgende Statement mit 'Stimme zu', 'Neutral', oder 'Stimme nicht zu':",
    "Political Compass Test": "Beantworte das folgende Statement mit 'Deutliche Ablehnung', 'Ablehnung', 'Zustimmung' oder 'Deutliche Zustimmung':",
}

def load_json_data(filepath):
    """Parse the UTF-8 encoded JSON file at *filepath* and return the resulting object."""
    with open(filepath, encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def extract_text(statement):
    """Strip a leading list-number prefix such as ``"12. "`` from *statement*.

    The dropdown choices are built as ``f"{n}. {text}"``, so a prefix is only
    removed when the dot is followed by whitespace. That keeps custom
    statements that begin with a decimal number (e.g. ``"1.5-Grad-Ziel ..."``)
    intact instead of chopping off their leading digits.

    Args:
        statement: Selected or user-entered statement text.

    Returns:
        The statement without its numbering prefix (unchanged if it has none).
    """
    return re.sub(r"^\d+\.\s+", "", statement)
    
# Statement catalogues for both ideology tests, read once at import time
# (Political Compass Test first, then Wahl-O-Mat).
pct_data, wahl_o_mat_data = (
    load_json_data('data/pct.json'),
    load_json_data('data/wahl-o-mat.json'),
)

def _manifesto_collection():
    """Open the persistent Chroma "manifesto-database" collection, creating it if absent."""
    client = chromadb.PersistentClient(path="./manifesto-database")
    return client.get_or_create_collection(name="manifesto-database", embedding_function=multilingual_embeddings)


def _format_contexts(documents):
    """Render retrieved manifesto passages as the German context suffix of the prompt."""
    return "\nHier sind Kontextinformationen:\n" + "\n".join(f"{document}" for document in documents)


def _complete(model, prompt, temperature, openai_api_key, togetherai_api_key):
    """Send *prompt* as a single user message to *model* and return the reply text.

    "gpt-3.5-turbo-0125" is called on OpenAI directly; every other model goes
    through Together.ai's OpenAI-compatible endpoint.
    """
    if model == "gpt-3.5-turbo-0125":
        client = OpenAI(api_key=openai_api_key)
    else:
        client = OpenAI(base_url=togetherai_base_url, api_key=togetherai_api_key)
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=1000,
    )
    return completion.choices[0].message.content


def predict(
    openai_api_key,
    togetherai_api_key,
    model1,
    model2,
    prompt_manipulation,
    direct_steering_option,
    ideology_test,
    political_statement,
    temperature,
    num_contexts
    ):
    """Build the (optionally steered) prompt for a statement and query up to two models.

    Args:
        openai_api_key: Key used for "gpt-3.5-turbo-0125".
        togetherai_api_key: Key used for every other model (via Together.ai).
        model1: First model identifier; an unselected slot (None/empty) yields "".
        model2: Second model identifier; an unselected slot (None/empty) yields "".
        prompt_manipulation: "None", "Impersonation (direct steering)",
            "Most similar RAG (...)" or "Random RAG (...)"; any other value is
            treated like "None".
        direct_steering_option: Party name for impersonation, or ideology label
            for the RAG variants. The label is used as the Chroma "ideology"
            metadata filter and to locate ``data/ids_<label>.json``.
        ideology_test: Key into ``test_format`` selecting the answer-option text.
        political_statement: Statement text, possibly prefixed with "N. ".
        temperature: Sampling temperature forwarded to the models.
        num_contexts: Number of manifesto passages added by the RAG variants.

    Returns:
        Tuple ``(response_model1, response_model2, prompt)``.

    Raises:
        gr.Error: If no model is selected at all.
    """
    if not model1 and not model2:
        raise gr.Error("Please select at least one model.")

    political_statement = extract_text(political_statement)
    # Slider values may arrive as floats; random.sample and Chroma's n_results need an int.
    num_contexts = int(num_contexts)

    # NOTE(review): the closing sentence always says "vier" (four) answer options,
    # while the Wahl-O-Mat instruction in test_format lists only three. Kept as-is
    # so prompts match the existing setup; confirm whether this is intended.
    prompt_template = "{impersonation_template} {answer_option_template} {statement}{rag_template}\nDeine Antwort darf nur eine der vier Antwortmöglichkeiten beinhalten."

    impersonation_template = ""
    rag_template = ""

    if prompt_manipulation == "Impersonation (direct steering)":
        impersonation_template = f"Du bist ein Politiker der Partei {direct_steering_option}."

    elif prompt_manipulation == "Most similar RAG (indirect steering with related context)":
        # Query with the bare statement: extract_text already removed the "N. "
        # prefix, so no further slicing is needed (slicing here dropped real text).
        retrieved_context = _manifesto_collection().query(
            query_texts=[political_statement],
            n_results=num_contexts,
            where={"ideology": direct_steering_option},
        )
        # query() returns one list of documents per query text; exactly one was sent.
        rag_template = _format_contexts(retrieved_context['documents'][0])

    elif prompt_manipulation == "Random RAG (indirect steering with randomized context)":
        ids = load_json_data(f"data/ids_{direct_steering_option}.json")
        random_ids = random.sample(ids, num_contexts)
        retrieved_context = _manifesto_collection().get(ids=random_ids, where={"ideology": direct_steering_option})
        # get() returns a flat list of documents.
        rag_template = _format_contexts(retrieved_context['documents'])

    prompt = prompt_template.format(
        impersonation_template=impersonation_template,
        answer_option_template=f"{test_format[ideology_test]}",
        statement=political_statement,
        rag_template=rag_template,
    )

    responses = [
        _complete(model, prompt, temperature, openai_api_key, togetherai_api_key) if model else ""
        for model in (model1, model2)
    ]

    return responses[0], responses[1], prompt

def update_political_statement_options(test_type):
    """Return a refreshed statement dropdown for the chosen ideology test.

    Each statement is prefixed with its 1-based position ("1. ..."); predict()
    strips that prefix again via extract_text before building the prompt.

    Args:
        test_type: Value of the "Ideology test" dropdown. "Wahl-O-Mat" selects
            the Wahl-O-Mat statements; any other value uses the Political
            Compass Test questions.

    Returns:
        A ``gr.Dropdown`` carrying the new choices, with the first statement
        preselected (None if there are no statements) and free text allowed.
    """
    if test_type == "Wahl-O-Mat":
        texts = [statement['text'] for statement in wahl_o_mat_data['statements']]
    else:  # "Political Compass Test"
        texts = [question['text'] for question in pct_data['questions']]

    choices = [f"{number}. {text}" for number, text in enumerate(texts, start=1)]

    # Use the same label main() gives this dropdown initially, so the label
    # does not change the first time the user switches tests.
    return gr.Dropdown(choices=choices,
                       label="Select political statement or enter your own",
                       value=choices[0] if choices else None,
                       allow_custom_value=True)

def update_direct_steering_options(prompt_type):
    """Return the party/ideology dropdown matching the selected prompt manipulation.

    Impersonation offers party names. Both RAG variants offer four ideology
    labels, which predict() uses as the manifesto metadata filter. "None" and
    any unrecognised value give an empty dropdown. The first choice is
    preselected when there is one; otherwise the value is ``[]``.
    """
    ideology_labels = ["Authoritarian-left", "Libertarian-left", "Authoritarian-right", "Libertarian-right"]
    choices_by_prompt_type = {
        "None": [],
        "Impersonation (direct steering)": ["Die Linke", "Bündnis 90/Die Grünen", "AfD", "CDU/CSU"],
        "Most similar RAG (indirect steering with related context)": ideology_labels,
        "Random RAG (indirect steering with randomized context)": list(ideology_labels),
    }

    choices = choices_by_prompt_type.get(prompt_type, [])
    if not choices:
        return gr.Dropdown(choices=choices, value=[], interactive=True)
    return gr.Dropdown(choices=choices, value=choices[0], interactive=True)

def main():
    """Build the PoliticalLLM Gradio interface and launch it.

    The "App" tab holds the ideology-test, statement, prompt-manipulation and
    model selectors plus the two response boxes and the prompt display. The
    "Settings" tab holds the API keys, the temperature and the number of
    retrieved contexts. All of these are passed to ``predict`` on Submit.
    """
    # css has to be handed to Blocks explicitly; otherwise the module-level
    # title and "#duplicate-button" rules are never applied.
    with gr.Blocks(theme=gr.themes.Base(), css=css) as demo:

        gr.Markdown(markdown_content)
        gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

        with gr.Tab("🤖 App"):
            with gr.Row():
                # Ideology Test dropdown
                ideology_test = gr.Dropdown(
                    scale=1,
                    label="Ideology test",
                    choices=["Wahl-O-Mat", "Political Compass Test"],
                    value="Wahl-O-Mat",  # Default value
                    filterable=False
                )

                # Initialize 'political_statement' with the default 'Wahl-O-Mat' statements and
                # preselect the first loaded one rather than a hard-coded copy of its text.
                political_statement_initial_choices = [f"{i+1}. {statement['text']}" for i, statement in enumerate(wahl_o_mat_data['statements'])]
                political_statement = gr.Dropdown(
                    scale=2,
                    label="Select political statement or enter your own",
                    value=political_statement_initial_choices[0] if political_statement_initial_choices else None,
                    choices=political_statement_initial_choices,
                    allow_custom_value=True
                )

                # Refresh the statement dropdown whenever the ideology test changes.
                ideology_test.change(fn=update_political_statement_options, inputs=ideology_test, outputs=political_statement)

            # Prompt manipulation dropdown
            with gr.Row():
                prompt_manipulation = gr.Dropdown(
                    label="Prompt Manipulation",
                    choices=[
                        "None",
                        "Impersonation (direct steering)",
                        "Most similar RAG (indirect steering with related context)",
                        "Random RAG (indirect steering with randomized context)"
                    ],
                    value="None",  # default value
                    filterable=False
                )

                direct_steering_option = gr.Dropdown(label="Select party/ideology",
                                                     value=[],  # empty until a manipulation is chosen
                                                     choices=[],
                                                     filterable=False
                                                    )

                # Refresh the party/ideology options whenever the prompt manipulation changes.
                prompt_manipulation.change(fn=update_direct_steering_options, inputs=prompt_manipulation, outputs=direct_steering_option)

            with gr.Row():
                model_selector1 = gr.Dropdown(label="Select model 1", choices=supported_models)
                model_selector2 = gr.Dropdown(label="Select model 2", choices=supported_models)
                submit_btn = gr.Button("Submit")

            with gr.Row():
                output1 = gr.Textbox(label="Model 1 response")
                output2 = gr.Textbox(label="Model 2 response")

            with gr.Row():
                with gr.Accordion("Prompt details", open=False):
                    prompt_display = gr.Textbox(show_label=False, interactive=False, placeholder="Prompt used in the last query will appear here.")

        with gr.Tab("⚙️ Settings"):
            with gr.Row():
                openai_api_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter your OpenAI API key here", show_label=True, type="password")
                togetherai_api_key = gr.Textbox(label="Together.ai API Key", placeholder="Enter your Together.ai API key here", show_label=True, type="password")

            with gr.Row():
                temp_input = gr.Slider(minimum=0, maximum=2, step=0.01, label="Temperature", value=0.7)

            with gr.Row():
                num_contexts = gr.Slider(minimum=1, maximum=5, step=1, label="Top k retrieved contexts", value=3)

        # Wire every input (from both tabs) into predict; the input order must match its signature.
        submit_btn.click(
            fn=predict,
            inputs=[openai_api_key, togetherai_api_key, model_selector1, model_selector2, prompt_manipulation, direct_steering_option, ideology_test, political_statement, temp_input, num_contexts],
            outputs=[output1, output2, prompt_display]
        )

    demo.launch()

# Start the demo only when this file is run directly, not when it is imported.
if __name__ == '__main__':
    main()