import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel

from app.logger import get_logger

logger = get_logger(__name__)

model_id = "ydshieh/vit-gpt2-coco-en"

# Load the image processor, tokenizer, and model
image_processor = ViTImageProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = VisionEncoderDecoderModel.from_pretrained(model_id)
model.eval()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"DEVICE:--------> {device}")
model.to(device)

def generate_caption(image_path: str) -> str:
    """Generate a natural-language caption for the image at `image_path`."""
    logger.info("Generating caption...")
    try:
        # Open and convert image to RGB
        image = Image.open(image_path).convert('RGB')
        
        # Preprocess the image into model-ready pixel values
        pixel_values = image_processor(images=image, return_tensors="pt").pixel_values.to(device)

        # Generate a caption with beam search decoding
        with torch.no_grad():
            output_ids = model.generate(
                pixel_values,
                max_length=16,
                num_beams=4,
                return_dict_in_generate=True,
            ).sequences

        # Decode the generated token ids and return the caption string
        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0].strip()
    except Exception:
        logger.exception("Failed to generate caption")
        raise
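

# Minimal usage sketch, assuming this module is run directly as a script.
# The sample path below is hypothetical; replace it with a real image file.
if __name__ == "__main__":
    example_path = "samples/example.jpg"  # hypothetical path, not part of the repo
    caption = generate_caption(example_path)
    logger.info(f"Caption: {caption}")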