import torch
import gradio as gr
import multiprocessing
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
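
# Gradio front-end for LoRA fine-tuning of a causal language model on CPU.
# The training job runs in a background process and writes its progress to a
# status file, which the UI reads back via a "Refresh Status" button.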

device = "cpu"
training_process = None
log_file = "training_status.log"

# Write the latest status message (overwrites any previous contents)
def log_status(message):
    with open(log_file, "w") as f:
        f.write(message)

# Read training status
def read_status():
    if os.path.exists(log_file):
        with open(log_file, "r") as f:
            return f.read()
    return "⏳ در انتظار شروع ترینینگ..."

# Function to find the text column dynamically
def find_text_column(dataset):
    sample = dataset["train"][0]  # Get the first row of the training dataset
    for column in sample.keys():
        if isinstance(sample[column], str):  # Find the first text-like column
            return column
    return None  # No valid text column found

# Model training function
def train_model(dataset_url, model_url, epochs):
    try:
        log_status("🚀 در حال بارگیری مدل...")
        
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
        )

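        # LoRA adapter: rank-8 updates on the attention projections. target_modules
        # assumes a LLaMA-style naming scheme (q_proj/v_proj); other architectures
        # may use different module names.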
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"]
        )
        model = get_peft_model(model, lora_config)
        # gradient_checkpointing is enabled below; keep inputs requiring grads so gradients reach the LoRA weights.
        model.enable_input_require_grads()
        model.to(device)

        dataset = load_dataset(dataset_url)

        # Automatically detect the correct text column
        text_column = find_text_column(dataset)
        if not text_column:
            log_status("❌ خطا: ستون متنی در دیتاست یافت نشد!")
            return
        
        def tokenize_function(examples):
            return tokenizer(examples[text_column], truncation=True, padding="max_length", max_length=256)

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]

        # Automatically check for validation dataset
        eval_dataset = tokenized_datasets["validation"] if "validation" in tokenized_datasets else None

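        # CPU-oriented settings: fp16 off, batch size 1, and gradient checkpointing
        # to trade extra compute for lower peak memory.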
        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            evaluation_strategy="epoch" if eval_dataset else "no",  # Enable evaluation if validation data exists
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            num_train_epochs=int(epochs),  # overridden per epoch in the training loop below
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none"
        )

        # Causal-LM collator: copies input_ids into labels so the Trainer can compute a loss.
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,  # None when no validation split exists
            data_collator=data_collator
        )

        log_status("🚀 ترینینگ شروع شد...")

        for epoch in range(int(epochs)):
            log_status(f"🔄 Running epoch {epoch+1}/{epochs}...")
            # Train one epoch per iteration so the status file updates between epochs.
            # Only resume from a checkpoint after the first epoch; train() fails with
            # resume_from_checkpoint=True when no checkpoint exists yet.
            trainer.args.num_train_epochs = epoch + 1
            trainer.train(resume_from_checkpoint=(epoch > 0))
            trainer.save_model(f"./deepseek_lora_finetuned_epoch_{epoch+1}")

        log_status("✅ ترینینگ کامل شد!")

    except Exception as e:
        log_status(f"❌ خطا: {str(e)}")

# Start training in a separate process
def start_training(dataset_url, model_url, epochs):
    global training_process
    if training_process is None or not training_process.is_alive():
        training_process = multiprocessing.Process(target=train_model, args=(dataset_url, model_url, epochs))
        training_process.start()
        return "🚀 ترینینگ شروع شد!"
    else:
        return "⚠ ترینینگ در حال اجرا است!"

# Function to update the status
def update_status():
    return read_status()

# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظه‌ای")

    with gr.Row():
        dataset_input = gr.Textbox(label="📂 Dataset link (Hugging Face)")
        model_input = gr.Textbox(label="🤖 Base model (Hugging Face)")
        epochs_input = gr.Number(label="🔄 Number of epochs", value=3)

    start_button = gr.Button("🚀 Start Training")
    status_output = gr.Textbox(label="📢 Training status", interactive=False)

    start_button.click(start_training, inputs=[dataset_input, model_input, epochs_input], outputs=status_output)
    status_button = gr.Button("🔄 Refresh Status")
    status_button.click(update_status, outputs=status_output)

if __name__ == "__main__":
    # Guard required so the training subprocess (spawn start method on Windows/macOS)
    # does not re-launch the UI when it re-imports this module.
    app.launch()
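
# Example session (hypothetical inputs): run `python app.py`, open the Gradio UI,
# enter a Hugging Face dataset id (e.g. "yelp_review_full") and a small base-model id,
# set the epoch count, click "Start Training", then poll with "Refresh Status".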