import spaces
import re 
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import json

# Lean 4 preamble inserted into every prompt by format_prompt: the Mathlib and
# Aesop imports, `set_option maxHeartbeats 0`, and an `open` of commonly used
# namespaces. The trailing "" element makes the string end with a newline.
LEAN4_DEFAULT_HEADER = "\n".join([
    "import Mathlib",
    "import Aesop",
    "",
    "set_option maxHeartbeats 0",
    "",
    "open BigOperators Real Nat Topology Rat",
    "",
])

# Markdown intro text for the Space. Not referenced by main() below, which
# passes its own title and description strings to gr.Interface.
title = """🙋🏻‍♂️Welcome to🌟Tonic's🔮Goedel Prover📉
You can build with this endpoint using🔮Goedel-Prover-SFT📉 available here : [Goedel-LM/Goedel-Prover-SFT](https://huggingface.co/Goedel-LM/Goedel-Prover-SFT)."""

def format_prompt(formal_statement, informal_prefix=""):
    """Build the model prompt for a Lean 4 completion request.

    The prompt consists of an instruction line, a blank line, and an opened
    ```lean4 fence followed by the default Lean header, the informal prefix and
    the formal statement. The fence is deliberately left unclosed so the model
    continues writing code after the statement.

    Args:
        formal_statement: Lean 4 statement the model should complete.
        informal_prefix: text placed on its own line between the header and
            the statement (empty by default).

    Returns:
        The assembled prompt string.
    """
    instruction = (
        "Complete the following Lean 4 code with explanatory comments "
        "preceding each line of code:"
    )
    opened_fence = f"```lean4\n{LEAN4_DEFAULT_HEADER}\n"
    return f"{instruction}\n\n{opened_fence}{informal_prefix}\n{formal_statement}"

def extract_code(response):
    """Extract the Lean 4 code from the last ```lean4 fenced block in *response*.

    Everything after the last ```lean4 opening fence is taken and truncated at
    the last closing ``` if one exists; an unclosed fence keeps the rest of the
    text. Blank lines before the first kept line are dropped, as is any line
    containing "Complete the following".

    Args:
        response: text to search, typically prompt + model output.

    Returns:
        The extracted code, or ``response.strip()`` when no ```lean4 fence is
        present, or the placeholder "Error processing code" if processing
        raises (e.g. *response* is not a string).
    """
    fence = "```lean4"
    try:
        start_idx = response.rfind(fence)
        if start_idx == -1:
            return response.strip()

        # Skip the entire opening fence. A hard-coded offset of 7 is one short of
        # len("```lean4") == 8 and leaks a stray "4" line into the result.
        content = response[start_idx + len(fence):]

        # Truncate at the last closing fence, if there is one.
        end_idx = content.rfind("```")
        if end_idx != -1:
            content = content[:end_idx]

        cleaned_lines = []
        for line in content.split("\n"):
            # Drop blank lines until the first line has been kept.
            if not cleaned_lines and not line.strip():
                continue
            # Drop instruction lines such as the prompt's "Complete the following ..." line.
            if "Complete the following" in line:
                continue
            cleaned_lines.append(line)

        return "\n".join(cleaned_lines)
    except Exception as e:
        # Best-effort: log and return a placeholder instead of failing the UI request.
        print(f"Error in extract_code: {str(e)}")
        return "Error processing code"
        

# Example proof goals for the Gradio examples table. The interface description
# in main() credits the introspector/unimath dataset for these examples.
unimath1 = """Goal:
  X : UU
  Y : UU
  P : UU
  xp : (X → P) → P
  yp : (Y → P) → P
  X0 : X × Y → P
  x : X
  ============================
   (Y → P)"""

unimath2 = """Goal:
    R : ring  M : module R
  ============================
   (islinear (idfun M))"""

unimath3 = """Goal:
    X : UU  i : nat  b : hProptoType (i < S i)  x : Vector X (S i)  r : i = i
  ============================
   (pr1 lastelement = pr1 (i,, b))"""

unimath4 = """Goal:
    X : dcpo  CX : continuous_dcpo_struct X  x : pr1hSet X  y : pr1hSet X
  ============================
   (x ⊑ y ≃ (∀ i : approximating_family CX x, approximating_family CX x i ⊑ y))"""

# Default informal prefix; format_prompt places it on its own line just
# before the formal statement.
additional_info_prompt = "/-Explain using mathematics-/\n"

# One row per example, in the order of the interface inputs:
# [formal statement, informal prefix, max tokens].
examples = [
    [goal, additional_info_prompt, 2500]
    for goal in (unimath1, unimath2, unimath3, unimath4)
]

# Hugging Face checkpoint served by this Space; loaded once at import time.
model_name = "Goedel-LM/Goedel-Prover-SFT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# bfloat16 weights; device_map="auto" delegates device placement to transformers.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Set generation config
# Start from the checkpoint's own generation config, then override selected fields.
model.generation_config = GenerationConfig.from_pretrained(model_name)
# NOTE(review): pad_token_id is copied from the checkpoint's eos_token_id BEFORE the
# eos override below. If the checkpoint's eos is not 100001, pad and eos end up
# different. Statement order matters here.
model.generation_config.pad_token_id = model.generation_config.eos_token_id
# Hard-coded special-token ids; presumably these match the tokenizer's BOS/EOS — TODO confirm.
model.generation_config.bos_token_id = 100000
model.generation_config.eos_token_id = 100001
# Sampling defaults (solve_math_problem also passes temperature/top_p explicitly).
model.generation_config.do_sample = True
model.generation_config.temperature = 1.0
model.generation_config.top_p = 0.95

@spaces.GPU
def solve_math_problem(question, informal_prefix, max_tokens):
    """Generate a Lean 4 proof attempt for a formal statement.

    Args:
        question: Lean 4 formal statement entered by the user.
        informal_prefix: optional text inserted before the statement in the prompt.
        max_tokens: maximum number of new tokens to generate. It comes from the
            UI slider and may arrive as a float, so it is coerced to int.

    Returns:
        A tuple ``(json_text, full_code)``. ``json_text`` is a pretty-printed JSON
        object with three keys: ``model_input`` (the prompt), ``model_output``
        (only the newly generated text) and ``full_code`` (the Lean code
        extracted from prompt + output). ``full_code`` is also returned on its own.
    """
    # Format the prompt using Lean4 structure
    prompt = format_prompt(question, informal_prefix)

    # Use the tokenizer's own attention mask rather than synthesizing one.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = encoded.input_ids.shape[1]

    outputs = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        # Same budget as max_length = prompt_length + max_tokens.
        max_new_tokens=int(max_tokens),
        pad_token_id=model.generation_config.pad_token_id,
        temperature=1.0,
        top_p=0.95,
    )

    # generate() returns prompt + continuation for a causal LM. Decode only the
    # continuation so the prompt is not duplicated in model_output or full_code.
    result = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # The prompt opens the ```lean4 fence, so extraction needs prompt + output.
    full_code = extract_code(prompt + result)

    output_data = {
        "model_input": prompt,
        "model_output": result,
        "full_code": full_code,
    }

    return json.dumps(output_data, indent=2), full_code

def main():
    """Build the Gradio interface around solve_math_problem and launch it.

    Inputs (in order): formal statement, informal prefix, max new tokens.
    Outputs: the JSON record returned by solve_math_problem and the extracted code.
    """
    iface = gr.Interface(
        title="🙋🏻‍♂️Welcome to🌟Tonic's🔮Goedel Prover📉",
        description="""You can build with this endpoint using🔮Goedel-Prover-SFT📉 available here : [Goedel-LM/Goedel-Prover-SFT](https://huggingface.co/Goedel-LM/Goedel-Prover-SFT). We're using 🤖[introspector/unimath](https://huggingface.co/datasets/introspector/unimath) for cool examples, check it out below ! The demo is still a work in progress and we're looking forward to build downstream tasks that showcase outstanding mathematical reasoning. Have any ideas ? join us below !
You can also use 🔮Goedel Prover📉 by cloning this space. Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/Math?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [Join us on Discord](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) Math with [introspector](https://huggingface.co/introspector) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [SciTonic](https://github.com/Tonic-AI/scitonic)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
""",
        fn=solve_math_problem,
        outputs=[
            gr.JSON(label="Full Output"),
            # NOTE(review): "python" highlighting is used for Lean code, presumably
            # because gr.Code has no Lean mode — TODO confirm.
            gr.Code(label="Extracted Lean4 Code", language="python")
        ],
        inputs=[
            gr.Textbox(label="🤔Enter your Lean4 formal statement", lines=7),
            gr.Textbox(value=additional_info_prompt, label="🪜Optional informal prefix"),
            # Integer token budget; the upper bound was a 4086 typo for 4096.
            gr.Slider(minimum=150, maximum=4096, value=2500, step=1, label="🪙Max Tokens")
        ],
        examples=examples
    )

    iface.launch()

if __name__ == "__main__":
    main()