import requests
import json
import re

def extract_between_tags(text, start_tag, end_tag):
    """Return the text enclosed by the first ``start_tag`` and the next ``end_tag``.

    Args:
        text: String to search.
        start_tag: Opening delimiter; its first occurrence is used.
        end_tag: Closing delimiter; searched for only after the opening tag.

    Returns:
        The enclosed substring (excluding both tags), or ``""`` when either
        tag is missing.
    """
    start_index = text.find(start_tag)
    if start_index == -1:
        return ""
    content_start = start_index + len(start_tag)
    # Search after the opening tag so identical or overlapping tags are not
    # matched against the opening tag itself.
    end_index = text.find(end_tag, content_start)
    if end_index == -1:
        return ""
    return text[content_start:end_index]

class VectaraQuery:
    """Minimal client that queries a Vectara corpus and formats the answer.

    Retrieval goes through the v1 ``stream-query`` endpoint. The text of the
    returned passages is then posted to a second endpoint together with a
    chat-style prompt asking for a "Reason / Alternative / Notes" answer.
    """

    # Base URL shared by every request this class makes.
    _API_BASE = "https://api.vectara.io/v1"
    # Seconds ``requests`` waits on connect/read; without it a call can block
    # forever. Chosen to match the 60S grpc-timeout header sent to Vectara.
    request_timeout = 60
    # Returned to callers on any HTTP or transport failure.
    _ERROR_MESSAGE = "Sorry, something went wrong. Please try again later."
    # Returned to callers when no usable text came back.
    _NO_RESULT_MESSAGE = "No relevant information found."
    # System message for the formatting step (shared so it is defined once).
    _SYSTEM_PROMPT = (
        "You are an assistant that provides information about drink names based on a given corpus. "
        "Format the response in the following way:\n"
        "Reason: <reason why the name cannot be used>\n"
        "Alternative: <alternative name>\n"
        "Notes: <additional notes>\n\n"
        "Example:\n"
        "Reason: The name 'Vodka Sunrise' cannot be used because it is trademarked.\n"
        "Alternative: Use 'Morning Delight' instead.\n"
        "Notes: Ensure the drink contains vodka to match the alternative name."
    )

    def __init__(self, api_key: str, customer_id: str, corpus_id: str, prompt_name: str = None):
        """Store credentials and the target corpus.

        Args:
            api_key: Vectara API key, sent as ``x-api-key``.
            customer_id: Vectara customer id, sent as a header and in the body.
            corpus_id: Corpus to query.
            prompt_name: Summarizer prompt name; a falsy value selects the
                default ``vectara-experimental-summary-ext-2023-12-11-large``.
        """
        self.customer_id = customer_id
        self.corpus_id = corpus_id
        self.api_key = api_key
        self.prompt_name = prompt_name if prompt_name else "vectara-experimental-summary-ext-2023-12-11-large"
        # Not read by any method of this class.
        self.conv_id = None

    def get_body(self, user_response: str):
        """Build the JSON-serialisable body for a ``stream-query`` request.

        Args:
            user_response: Query text, used verbatim. ``json.dumps`` in
                ``submit_query`` performs all escaping, so it must not be
                pre-escaped here (doing so sent literal backslashes).

        Returns:
            A dict with one query against this instance's corpus, requesting
            the first 10 results with 2 sentences of context before and after
            each snippet, delimited by ``%START_SNIPPET%`` / ``%END_SNIPPET%``.
        """
        corpora_key_list = [{
            'customer_id': self.customer_id,
            'corpus_id': self.corpus_id,
            # Hybrid-search keyword weight; presumably a small value keeps
            # ranking mostly semantic — confirm against Vectara docs.
            'lexical_interpolation_config': {'lambda': 0.025},
        }]

        return {
            'query': [
                {
                    'query': user_response,
                    'start': 0,
                    'numResults': 10,
                    'corpusKey': corpora_key_list,
                    'context_config': {
                        'sentences_before': 2,
                        'sentences_after': 2,
                        'start_tag': "%START_SNIPPET%",
                        'end_tag': "%END_SNIPPET%",
                    },
                },
            ],
        }

    def get_headers(self):
        """Return the HTTP headers (JSON content type, auth, gRPC timeout) for Vectara calls."""
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "customer-id": self.customer_id,
            "x-api-key": self.api_key,
            "grpc-timeout": "60S",
        }

    def submit_query(self, query_str: str):
        """Query the corpus with ``query_str`` and return a formatted answer.

        Returns:
            The result of ``format_response_using_vectara`` on the concatenated
            passage texts; the generic error message on an HTTP or transport
            failure; or the "No relevant information found." message when no
            passage text came back.
        """
        endpoint = f"{self._API_BASE}/stream-query"
        body = self.get_body(query_str)
        try:
            # ``with`` releases the streamed connection even on early return.
            with requests.post(endpoint, data=json.dumps(body), verify=True,
                               headers=self.get_headers(), stream=True,
                               timeout=self.request_timeout) as response:
                if response.status_code != 200:
                    print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
                    return self._ERROR_MESSAGE
                texts = self._collect_response_texts(response.iter_lines())
        except requests.RequestException as err:
            print(f"Query failed: {err}")
            return self._ERROR_MESSAGE

        if texts:
            # Same separator layout as before: each passage followed by a space.
            return self.format_response_using_vectara(" ".join(texts) + " ")
        return self._NO_RESULT_MESSAGE

    @staticmethod
    def _collect_response_texts(lines):
        """Return the ``text`` of every passage found in the streamed JSON lines."""
        texts = []
        for line in lines:
            if not line:  # filter out keep-alive new lines
                continue
            try:
                data = json.loads(line.decode('utf-8'))
            except (UnicodeDecodeError, json.JSONDecodeError) as err:
                print(f"Skipping undecodable chunk: {err}")
                continue
            print(f"Received data chunk: {json.dumps(data, indent=2)}")  # Debugging line

            if not isinstance(data, dict) or 'result' not in data:
                print("No 'result' in data")
                continue

            res = data['result']
            if not isinstance(res, dict) or 'responseSet' not in res:
                print("No 'responseSet' in result")
                continue

            response_set = res['responseSet']
            if not response_set:
                continue

            for result in response_set.get('response', []):
                if 'text' not in result:
                    print("No 'text' in result")
                    continue
                text = result['text']
                print(f"Processing text: {text}")  # Debugging line
                texts.append(text)
        return texts

    def format_response_using_vectara(self, text):
        """Ask Vectara to rewrite ``text`` in the Reason/Alternative/Notes format.

        Args:
            text: Retrieved passage text to be formatted.

        Returns:
            The ``summary.text`` field of the response; the generic error
            message on an HTTP, transport or JSON-decoding failure; otherwise
            "No relevant information found.".
        """
        # NOTE(review): ``/v1/stream-summary`` is not an endpoint I can confirm in
        # the Vectara v1 REST reference, and a streaming endpoint may not return
        # a single JSON document for ``response.json()`` — verify.
        endpoint = f"{self._API_BASE}/stream-summary"
        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": text},
        ]
        body = {
            'text': text,
            'summary': {
                'responseLang': 'eng',
                'maxSummarizedResults': 1,
                'summarizerPromptName': self.prompt_name,
                # json.dumps escapes quotes/backslashes/newlines in ``text``; the
                # previous hand-built string produced invalid JSON.
                'promptText': json.dumps(messages),
            },
        }
        try:
            response = requests.post(endpoint, data=json.dumps(body), headers=self.get_headers(),
                                     timeout=self.request_timeout)
        except requests.RequestException as err:
            print(f"Summary query failed: {err}")
            return self._ERROR_MESSAGE

        if response.status_code != 200:
            print(f"Summary query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
            return self._ERROR_MESSAGE

        try:
            data = response.json()
        except ValueError:
            print(f"Summary response was not valid JSON: {response.text}")
            return self._ERROR_MESSAGE

        summary = data.get('summary') if isinstance(data, dict) else None
        if isinstance(summary, dict) and 'text' in summary:
            return summary['text']
        return self._NO_RESULT_MESSAGE