import re
import traceback
from datetime import datetime

from flask import request, Response, stream_with_context
from requests import get

from server.config import special_instructions
from server.utils import generate_with_references, inject_references_to_messages, check_model, make_prompt
from config import MODEL_REFERENCE_1, MODEL_REFERENCE_2, MODEL_REFERENCE_3

class Backend_Api:
    """Backend exposing the chat conversation endpoint.

    ``self.routes`` maps URL rules to the view callable and allowed HTTP
    methods; presumably the caller registers them on ``bp`` (not shown here —
    verify against the app factory).
    """

    # Seconds to wait for the web-search service before giving up.
    SEARCH_TIMEOUT_SECONDS = 10

    def __init__(self, bp, config: dict) -> None:
        """Store the route target and declare the routes this API serves.

        Args:
            bp: Object the routes are meant to be attached to (presumably a
                Flask Blueprint).
            config: Application configuration; currently unused.
        """
        self.bp = bp
        self.routes = {
            '/backend-api/v2/conversation': {
                'function': self._conversation,
                'methods': ['POST']
            }
        }

    def _conversation(self):
        """Handle POST /backend-api/v2/conversation.

        Validates the request, assembles the message list and returns a
        streaming ``text/event-stream`` response. Errors raised while
        validating or preparing the request are returned as a JSON body with
        HTTP 400.

        Note: ``generate_response`` is a generator, so exceptions raised while
        the models are queried happen during streaming, after this method has
        returned, and are not caught here.
        """
        try:
            payload = request.json
            # Required by the client protocol but not used by this handler;
            # reading them makes requests that omit them fail with a 400.
            _conversation_id = payload['conversation_id']
            _api_key = payload['api_key']
            jailbreak = payload['jailbreak']
            model = payload['model']
            check_model(model)
            messages = self.build_messages(jailbreak)

            response = self.generate_response(model, messages, jailbreak)
            return Response(stream_with_context(response), mimetype='text/event-stream')

        except Exception as e:
            # Top-level request boundary: log the full traceback, report to client.
            print(e)
            traceback.print_exc()
            return {
                '_action': '_ask',
                'success': False,
                "error": f"an error occurred {str(e)}"
            }, 400

    def build_messages(self, jailbreak):
        """Assemble the message list for the current request.

        Order: prior conversation, an optional web-search system message, an
        optional set of jailbreak instructions, then the user's prompt. Only
        the last four messages are returned, to limit token usage.

        Args:
            jailbreak: Jailbreak key from the request ("default" disables it).

        Returns:
            A new list of message dicts; the request JSON is not modified.
        """
        content = request.json['meta']['content']
        # Copy so the request's parsed JSON is not mutated in place.
        conversation = list(content['conversation'])
        internet_access = content['internet_access']
        prompt = content['parts'][0]

        if internet_access:
            current_date = datetime.now().strftime("%Y-%m-%d")
            query = f'Current date: {current_date}. ' + prompt["content"]
            conversation.extend(self.fetch_search_results(query))

        if jailbreak_instructions := self.getJailbreak(jailbreak):
            conversation.extend(jailbreak_instructions)

        conversation.append(prompt)

        # Keep only the last four messages to avoid the API's token limit.
        return conversation[-4:]

    def fetch_search_results(self, query):
        """Query the web-search service and wrap the results as a system message.

        Args:
            query: Search query text.

        Returns:
            A one-element list holding a ``{'role': 'system', ...}`` message
            with numbered snippets and their URLs.

        Raises:
            requests.RequestException: On network errors, timeouts or a
                non-2xx HTTP status.
        """
        search = get('https://ddg-api.herokuapp.com/search',
                     params={
                         'query': query,
                         'limit': 3,
                     },
                     timeout=self.SEARCH_TIMEOUT_SECONDS)
        search.raise_for_status()

        snippets = ''.join(
            f'[{index}] "{result["snippet"]}" URL:{result["link"]}.'
            for index, result in enumerate(search.json(), start=1)
        )
        response = "Here are some updated web searches. Use this to improve user response:" + snippets

        return [{'role': 'system', 'content': response}]

    def generate_response(self, model, messages, jailbreak):
        """Stream the final model's answer, informed by the reference models.

        Each reference model is prompted (via ``make_prompt``) with the user's
        last message. Their outputs are then handed to ``model`` as references
        through ``generate_with_references``. Output chunks are streamed as
        they arrive; when a jailbreak is active they are passed through
        ``_filter_jailbreak``.

        NOTE(review): only ``messages[-1]['content']`` is forwarded. The
        conversation history, web-search results and jailbreak instructions
        assembled by ``build_messages`` never reach the models — confirm
        this is intended.
        """
        reference_models = [MODEL_REFERENCE_1, MODEL_REFERENCE_2, MODEL_REFERENCE_3]
        user_content = messages[-1]['content']
        # One independent instruction list (fresh dict) per reference model.
        instructions = [[{"role": "user", "content": user_content}] for _ in reference_models]
        references = [""] * len(reference_models)

        rounds = 1
        for _ in range(rounds):
            for i, reference_model in enumerate(reference_models):
                prompt = make_prompt(instructions[i][-1]['content'], instructions[i], reference_model)
                references[i] = generate_with_references(reference_model, prompt)

        output = generate_with_references(
            model=model,
            messages=instructions[0],
            references=references,
        )

        if self._jailbreak_active(jailbreak):
            yield from self._filter_jailbreak(output)
        else:
            yield from output

    def _filter_jailbreak(self, chunks):
        """Strip the "GPT:" half of a two-part jailbroken answer from a chunk stream.

        Chunks are buffered until a decision can be made:
        * if the first 4 characters are neither "GPT:" nor "ACT:", the
          jailbreak failed and everything (buffer + rest) is passed through;
        * otherwise, once "ACT:" appears, only the text after its first
          occurrence is emitted, followed by all later chunks.
        If neither happens (e.g. fewer than 4 characters, or a "GPT:" answer
        with no "ACT:" part), nothing is emitted.
        """
        buffered = ''
        passthrough = False
        for chunk in chunks:
            if passthrough:
                yield chunk
                continue
            buffered += chunk
            if self.response_jailbroken_failed(buffered):
                passthrough = True
                yield buffered
                continue
            marker = buffered.find('ACT:')
            if marker != -1:
                passthrough = True
                tail = buffered[marker + len('ACT:'):]
                if tail:
                    yield tail

    def response_jailbroken_success(self, response: str) -> bool:
        """Return True if the "ACT:" marker appears anywhere in ``response``."""
        return 'ACT:' in response

    def response_jailbroken_failed(self, response: str) -> bool:
        """Return True once ``response`` has at least 4 characters and starts
        with neither "GPT:" nor "ACT:"."""
        return len(response) >= 4 and not response.startswith(("GPT:", "ACT:"))

    def _jailbreak_active(self, jailbreak) -> bool:
        """Return True when ``jailbreak`` names a configured, non-default jailbreak."""
        return bool(jailbreak) and jailbreak != "default" and jailbreak in special_instructions

    def getJailbreak(self, jailbreak):
        """Return the instruction messages for ``jailbreak``, or None.

        The first message's content gets ``two_responses_instruction``
        appended. A fresh list and first-message dict are built each call,
        so the shared ``special_instructions`` config is never mutated (it
        previously grew by one copy of the suffix per request).

        Returns:
            A new list of message dicts, or None for "default", unknown keys
            or an empty instruction list.
        """
        if not self._jailbreak_active(jailbreak):
            return None
        instructions = special_instructions[jailbreak]
        if not instructions:
            return None
        first, *rest = instructions
        suffix = special_instructions['two_responses_instruction']
        return [{**first, 'content': first['content'] + suffix}, *rest]