import torch
import gradio as gr
import logging
import sys
from urllib.parse import urlparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset

# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

def parse_hf_dataset_url(url: str) -> tuple[str, str | None]:
    """Parse Hugging Face dataset URL into (dataset_name, config)"""
    parsed = urlparse(url)
    path_parts = parsed.path.split('/')
    
    try:
        # Find 'datasets' in path
        datasets_idx = path_parts.index('datasets')
    except ValueError:
        raise ValueError("Invalid Hugging Face dataset URL")
    
    dataset_parts = [p for p in path_parts[datasets_idx+1:] if p]
    # The name is everything before an optional 'viewer' segment, which
    # handles both 'owner/name' and unnamespaced dataset URLs
    name_end = dataset_parts.index('viewer') if 'viewer' in dataset_parts else min(2, len(dataset_parts))
    dataset_name = "/".join(dataset_parts[:name_end])
    if not dataset_name:
        raise ValueError("Invalid Hugging Face dataset URL")
    
    # Try to find config (common pattern for datasets with viewer)
    try:
        viewer_idx = dataset_parts.index('viewer')
        config = dataset_parts[viewer_idx+1] if viewer_idx+1 < len(dataset_parts) else None
    except ValueError:
        config = None
    
    return dataset_name, config
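
# Illustrative examples of what the parser extracts (example URLs only):
#   parse_hf_dataset_url("https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-2-raw-v1")
#     -> ("Salesforce/wikitext", "wikitext-2-raw-v1")
#   parse_hf_dataset_url("https://huggingface.co/datasets/roneneldan/TinyStories")
#     -> ("roneneldan/TinyStories", None)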

def train(dataset_url: str):
    try:
        # Parse dataset URL
        dataset_name, dataset_config = parse_hf_dataset_url(dataset_url)
        logging.info(f"Loading dataset: {dataset_name} (config: {dataset_config})")

        # Load model and tokenizer
        model_name = "microsoft/phi-2"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="cpu",
            torch_dtype=torch.float32,  # full precision; fp16 is unreliable on CPU
            trust_remote_code=True
        )

        # Add padding token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load dataset from Hugging Face Hub
        dataset = load_dataset(
            dataset_name,
            dataset_config,
            trust_remote_code=True
        )

        # Handle dataset splits
        if "train" not in dataset:
            raise ValueError("Dataset must have a 'train' split")
        
        train_dataset = dataset["train"]
        eval_dataset = dataset.get("validation", dataset.get("test", None))

        # Split if no validation set
        if eval_dataset is None:
            split = train_dataset.train_test_split(test_size=0.1, seed=42)
            train_dataset = split["train"]
            eval_dataset = split["test"]

        # Tokenization function (assumes the dataset has a "text" column;
        # adjust the column name for other datasets)
        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=256,
                # No return_tensors here: Dataset.map stores plain lists and
                # the data collator converts batches to tensors
            )
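        # Design note: fixed-length padding keeps batches uniform; dropping
        # 'padding' here and letting the collator pad each batch to its longest
        # sequence would also work and waste fewer tokens.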

        # Tokenize datasets
        tokenized_train = train_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=train_dataset.column_names
        )
        tokenized_eval = eval_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=eval_dataset.column_names
        )

        # Data collator
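        # With mlm=False, DataCollatorForLanguageModeling builds causal-LM
        # labels: a copy of input_ids with padded positions set to -100 so
        # the loss ignores them.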
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # Training arguments
        training_args = TrainingArguments(
            output_dir="./phi2-results",
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            num_train_epochs=3,
            eval_strategy="epoch",  # 'evaluation_strategy' on transformers < 4.41
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,  # CPU training: half precision is not supported
            report_to="none",  # avoid prompting for experiment trackers
        )

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_train,
            eval_dataset=tokenized_eval,
            data_collator=data_collator,
        )

        # Start training
        logging.info("Training started...")
        trainer.train()
        trainer.save_model("./phi2-trained-model")
        tokenizer.save_pretrained("./phi2-trained-model")  # keep tokenizer with the model for easy reload
        logging.info("Training completed!")

        return "✅ Training succeeded! Model saved."

    except Exception as e:
        logging.error(f"Training failed: {str(e)}")
        return f"❌ Training failed: {str(e)}"

# Gradio interface
with gr.Blocks(title="Phi-2 Training") as demo:
    gr.Markdown("# 🚀 Train Phi-2 with HF Hub Data")
    
    with gr.Row():
        dataset_url = gr.Textbox(
            label="Dataset URL",
            value="https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0"
        )
    
    start_btn = gr.Button("Start Training", variant="primary")
    status_output = gr.Textbox(label="Status", interactive=False)
    
    start_btn.click(
        fn=train,  # Gradio runs handlers off the main thread, so the UI stays responsive
        inputs=[dataset_url],
        outputs=status_output
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )