from typing import Annotated, Any, Literal

from langchain_tavily import TavilySearch
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.types import interrupt, Command
from typing_extensions import TypedDict

"""
from langchain_anthropic import ChatAnthropic
from langchain_ollama.llms import OllamaLLM
from langchain_experimental.llms.ollama_functions import OllamaFunctions
llm = OllamaFunctions(model="qwen2.5", format="json")  
llm_with_tools = llm #.bind_tools(tools)
"""

from langchain_groq import ChatGroq

# Shared ChatGroq chat model used by the rest of this module.
llm = ChatGroq(
    model="gemma2-9b-it", #"llama-3.1-8b-instant",
    temperature=0.4,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    # other params...
)

# Single-variable prompt: the caller supplies {question}; the trailing line
# asks the model to reason step by step.
template = """Question: {question}
Answer: Let's think step by step."""

prompt = ChatPromptTemplate.from_template(template)

# model = OllamaLLM(model="deepseek-r1")

# Standalone pipeline that pipes the prompt into llm. Nothing in the visible
# part of this file invokes it apart from the commented-out example below.
chain = prompt | llm


# print(chain.invoke({"question": "Explain like I'm 5 for capacity planning?"}))

@tool
def human_assistance(query: str) -> str:
  """Request assistance from a human."""
  # NOTE: the docstring above is kept verbatim; @tool uses it as the tool
  # description shown to the model.
  # Pause the run with the query as the interrupt payload. The value supplied
  # on resume is expected to be a mapping with a "data" entry.
  return interrupt({"query": query})["data"]


# Web-search tool. Named search_tool (not `tool`) so the `tool` decorator
# imported from langchain_core.tools stays usable further down the module.
search_tool = TavilySearch(max_results=2)
tools = [search_tool, human_assistance]
# Chat model with both tools bound for tool calling.
llm_with_tools = llm.bind_tools(tools)

# llm = OllamaLLM(model="deepseek-r1") #ChatAnthropic(model="claude-3-5-sonnet-20240620")

class State(TypedDict):
  """Graph state shared by the email-campaign nodes."""

  # Message log; updates to this key are combined through add_messages.
  messages: Annotated[list, add_messages]
  # Customer description that write_email turns into an email.
  persona: str
  # Draft email produced by write_email and printed by delivery.
  email: str
  # Reviewer decision; these are the exact values human_approval writes.
  release: Literal['approved', 'rejected']


# Builder for the workflow graph over State; nodes and edges are added below.
graph_builder = StateGraph(State)


def write_email(state: State):
  """Draft the promotional email for the persona held in the graph state.

  Args:
    state: Graph state; only ``state["persona"]`` is read.

  Returns:
    A ``Command`` whose update stores the generated text under ``"email"``.
  """
  request = f"""Write an promotional personalized email for this persona and offer financial education and setup a meeting for financial advisor, Only the email nothing else:
    {state["persona"]}
    """
  # Use the plain model rather than llm_with_tools: this graph has no node
  # that executes tool calls, so a tool-call reply (which typically carries no
  # text content) would send an empty email to review.
  draft = llm.invoke(request)
  return Command(update={"email": draft.content})


# Register write_email under the node name "write_email".
graph_builder.add_node("write_email", write_email)


def delivery(state: State):
  """Print the email held in the state and log that it was delivered.

  Reads ``state["email"]`` and returns a ``Command`` whose update adds one
  entry to ``messages``.
  """
  outgoing = state["email"]
  print(f"Delivering: {outgoing}")

  confirmation = "Email delivered to customer"
  return Command(update={"messages": [confirmation]})


# Register delivery under the node name "delivery".
graph_builder.add_node("delivery", delivery)


def human_approval(state: State) -> Command[Literal["delivery", END]]:
  """Pause for a reviewer's decision and route the run accordingly.

  The value supplied on resume is compared with "approved" after trimming
  surrounding whitespace and ignoring case, so "Approved " is accepted. Any
  other value is treated as a rejection.

  Args:
    state: Graph state; not read by this node.

  Returns:
    A ``Command`` that records ``release`` and goes to ``"delivery"`` on
    approval, or to ``END`` on rejection.
  """
  decision = interrupt(
    "Approval for release the promotional email to customer? (type: approved or rejected):"
  )

  # The prompt asks the reviewer to type the answer, so normalise it first:
  # casing or stray whitespace must not turn an approval into a rejection.
  if str(decision).strip().lower() == "approved":
    return Command(goto="delivery", update={"release": "approved"})
  return Command(goto=END, update={"release": "rejected"})


# Register the approval gate. It has no outgoing edges here: human_approval
# picks its successor itself through Command(goto=...).
graph_builder.add_node("human_approval", human_approval)

# Fixed path into the gate: START -> write_email -> human_approval.
graph_builder.add_edge(START, "write_email")
graph_builder.add_edge("write_email", "human_approval")

# Once delivery has run, the graph finishes.
graph_builder.add_edge("delivery", END)

# In-memory checkpointer handed to the compiled graph.
# NOTE(review): presumably this is what lets interrupt()/Command(resume=...)
# pause and resume a run per thread_id - confirm against the LangGraph docs.
checkpointer = MemorySaver()
graph = graph_builder.compile(checkpointer=checkpointer)


def email(persona, campaign, history):
  """Draft an email for ``persona`` and run the graph until it stops.

  Args:
    persona: Customer description stored in the graph state as ``persona``.
    campaign: Identifier used as the checkpoint ``thread_id``; pass the same
      value to ``feedback`` to continue this run.
    history: Unused; kept so existing callers do not break.

  Returns:
    A 4-tuple ``("Assistant: ", update, "Value: ", state_values)``.
    ``update`` is the last node update streamed (``None`` if there was none)
    and ``state_values`` is the thread's state once the stream has ended.
  """
  thread_config = {"configurable": {"thread_id": campaign}}
  last_update = None
  # Drain the stream instead of returning on the first event: the graph only
  # advances while the stream is being consumed.
  for event in graph.stream({"persona": persona}, config=thread_config):
    for node_name, update in event.items():
      # Interrupt notifications are not node output; keep the node update.
      if node_name != "__interrupt__":
        last_update = update
  return "Assistant: ", last_update, "Value: ", graph.get_state(thread_config).values


def feedback(deliver, campaign, history):
  """Resume a paused campaign run with the reviewer's decision.

  Args:
    deliver: Value handed to the paused run as ``Command(resume=deliver)``.
    campaign: Identifier used as the checkpoint ``thread_id``; must match the
      one the run was started with.
    history: Unused; kept so existing callers do not break.

  Returns:
    A 4-tuple ``("Assistant: ", update, "Value: ", state_values)``.
    ``update`` is the last node update streamed (``None`` if there was none)
    and ``state_values`` is the thread's state once the stream has ended.
  """
  thread_config = {"configurable": {"thread_id": campaign}}
  last_update = None
  # Drain the stream instead of returning on the first event. Returning early
  # would abandon the run right after the approval step, so the delivery node
  # would never execute for an approved email.
  for event in graph.stream(Command(resume=deliver), config=thread_config):
    for node_name, update in event.items():
      # Interrupt notifications are not node output; keep the node update.
      if node_name != "__interrupt__":
        last_update = update
  return "Assistant: ", last_update, "Value: ", graph.get_state(thread_config).values


'''
from IPython.display import Image, display

try:
  display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
  # This requires some extra dependencies and is optional
  pass
'''

def campaign(user_input: Any, id: str) -> None:
  """Stream ``user_input`` through the graph and print every update.

  Args:
    user_input: Passed unchanged to ``graph.stream``.
    id: Used as the ``thread_id`` of the run.
  """
  run_config = {"configurable": {"thread_id": id}}
  for update in graph.stream(user_input, config=run_config):
    for payload in update.values():
      # Fetch the thread state again for each printed update.
      snapshot = graph.get_state(run_config)
      print("Assistant:", payload, "Value: ", snapshot.values)

"""
campaign({"persona": "My mortgage rate is 9%, I cannot afford it anymore, I need to refinance and I'm unemploy right now."}, "MOR")

campaign({"persona": "my credit card limit is too low, I need a card with bigger limit and low fee"}, "CARD")

campaign(Command(resume="approved"), "MOR")
"""

# Interactive driver, currently disabled: with `while False` the body below
# never executes.
while False:
  try:
    user_input = input("User: ")
    wants_exit = user_input.lower() in ("quit", "exit", "q")
    if wants_exit:
      print("Goodbye!")
      break
    campaign(user_input, "MORT")
  except Exception:
    # Fallback for any exception raised above (the original note cites
    # input() being unavailable): run one canned question, then stop.
    user_input = "What do you know about LangGraph?"
    print("User: " + user_input)
    campaign(user_input, "MORT")
    break