import streamlit as st
import os, time
from app.vdr_session import *
from app.vdr_schemas import *
from st_clickable_images import clickable_images
from app.prompt_template import VDR_PROMPT
  
# Session-state keys that cache the outcome of the last retrieval. They are
# dropped whenever the indexed context changes (or a search fails) so that
# stale hits are neither displayed nor passed to "Generate answer".
_SEARCH_STATE_KEYS = ("last_query", "last_search", "result_images")

# Queries whose stripped length is below this are ignored.
_MIN_QUERY_CHARS = 3


def _reset_search_state():
    """Forget cached retrieval results so the next valid query searches again."""
    for key in _SEARCH_STATE_KEYS:
        st.session_state.pop(key, None)


def _render_retrieved_images(images, selected_embed):
    """Show retrieved images as a clickable gallery.

    Clicking a thumbnail opens it enlarged in a modal dialog. The last opened
    index is stored in ``st.session_state["clicked"]`` so the same click does
    not reopen the dialog on later reruns (presumably because the component
    keeps reporting its last value across reruns — confirm against
    st_clickable_images).
    """
    clicked = clickable_images(
        images,
        titles=[f"Image #{i}" for i in range(len(images))],
        div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
        img_style={"margin": "5px", "height": "200px"},
    )
    st.write(f"**Retrieved by: {selected_embed}**")

    @st.dialog(" ", width="large")
    def show_selected_image(rank):
        st.markdown(f"**Similarity rank: {rank}**")
        st.image(images[rank])

    # A negative value means no thumbnail has been clicked.
    if clicked > -1 and clicked != st.session_state.get("clicked", None):
        show_selected_image(clicked)
        st.session_state["clicked"] = clicked


def page_vdr():
    """Render the Visual Document Retrieval page.

    Sidebar: reads the API key from the ``GLOBAL_AIFS_API_KEY`` environment
    variable and hands it to the session. When the session accepts it, the
    sidebar offers:
    - VLM and embedding model pickers;
    - PDF upload and indexing;
    - context removal;
    - a top-k slider and an editable prompt template.

    Main area:
    - a query box, disabled until something is indexed;
    - the top-k retrieved images for the query;
    - a button that streams an answer from the selected VLM, using those
      images as context.
    """
    st.header("Visual Document Retrieval")

    # Keep one VDRSession per browser session so it survives Streamlit reruns.
    if "vdr_session" not in st.session_state:
        st.session_state["vdr_session"] = VDRSession()
    session = st.session_state["vdr_session"]

    with st.sidebar:
        # NOTE(review): os.getenv returns None when the variable is unset;
        # set_api_key is assumed to reject that — confirm.
        api_key = os.getenv("GLOBAL_AIFS_API_KEY")
        if not session.set_api_key(api_key):
            st.warning('Please enter valid credentials!', icon='⚠️')
            return

        st.success('API Key is valid!', icon='✅')
        avai_llms = session.get_available_vlms()
        avai_embeds = session.get_available_image_embeds()
        selected_llm = st.selectbox('Choose VLM models', avai_llms, key='selected_llm')
        selected_embed = st.selectbox('Choose Embedding models', avai_embeds, key='selected_embed')

        uploaded_files = st.file_uploader("Upload PDF files", key="uploaded_files", accept_multiple_files=True)

        if st.button("Add selected context", key="add_context", type="primary"):
            if not uploaded_files:
                st.warning('Please upload files first!', icon='⚠️')
            else:
                indexing_bar = st.progress(0, text="Indexing...")
                try:
                    indexed = session.indexing(uploaded_files, selected_embed, indexing_bar)
                except Exception as e:
                    st.error(f"Error during indexing: {e}")
                else:
                    if indexed:
                        # The index changed, so cached hits are stale; make the
                        # current query (rendered below) search again.
                        _reset_search_state()
                        st.success('Indexing completed!')
                    else:
                        st.warning('Files empty or not supported.', icon='⚠️')
                finally:
                    # Remove the bar on every outcome, not only on success.
                    indexing_bar.empty()

        if st.button("🗑️ Remove all context", key="remove_context"):
            try:
                session.clear_context()
            except Exception as e:
                st.error(f"Error during removing context: {e}")
            else:
                # Results retrieved from the removed context must not survive
                # the rerun; otherwise they are still shown and sent to the VLM.
                _reset_search_state()
                st.success("Context removed")
                # Rerun outside the try so Streamlit's rerun signal is never
                # routed into the error branch.
                st.rerun()

        top_k_sim = st.slider(label="Top k similarity", min_value=1, max_value=10, value=3, step=1, key="top_k_sim")
        chat_prompt = st.text_area("Prompt template", key="chat_prompt", value=VDR_PROMPT, height=300)

    query = st.text_input(
        label="Query",
        key="query",
        placeholder="Enter your query here",
        label_visibility="hidden",
        disabled=not session.indexed_images,
    )
    has_query = len(query.strip()) >= _MIN_QUERY_CHARS

    with st.expander(f"**Top {top_k_sim} retrieved contexts**", expanded=True):
        # Cache on (query, top_k): keying on the query alone kept the old
        # results on screen after the top-k slider moved.
        search_key = (query, top_k_sim)
        if has_query and search_key != st.session_state.get("last_search"):
            try:
                with st.spinner('Searching...'):
                    results = session.search_images(query, top_k_sim)
            except Exception as e:
                # Drop the previous query's hits so they are not answered
                # against the new query; the search is retried on a later rerun.
                _reset_search_state()
                st.error(f"Error during search: {e}")
            else:
                st.session_state["last_query"] = query
                st.session_state["last_search"] = search_key
                st.session_state["result_images"] = results

        images = st.session_state.get("result_images")
        if images:
            try:
                _render_retrieved_images(images, selected_embed)
            except Exception as e:
                st.error(f"Error during search: {e}")

    if images and st.button("Generate answer", key="ask", type="primary"):
        if not has_query:
            st.warning('Please enter query first!', icon='⚠️')
        else:
            try:
                with st.spinner('Generating response...'):
                    stream_response = session.ask(
                        query=query,
                        model=selected_llm,
                        prompt_template=chat_prompt,
                        retrieved_context=images,
                        stream=True,
                    )
                    st.write_stream(stream_response)
                    st.write(f"**Answered by: {selected_llm}**")
            except Exception as e:
                st.error(f"Error during asking: {e}")