import os
import sys
import logging
from datetime import datetime
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image

# Configure logging to both stdout and app.log.
# The file handler is opened as UTF-8 explicitly: logged messages can contain
# Chinese text (e.g. the FileNotFoundError message raised in
# ChartAnalyzer.process_image), which the platform-default encoding on some
# systems (e.g. cp1252 on Windows) cannot encode, causing the record to be
# dropped with a "Logging error" report instead of being written.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log', encoding='utf-8'),
    ]
)
logger = logging.getLogger(__name__)

def print_section(title, char='='):
    """Print *title* centred between two 50-character rules made of *char*.

    The output is: a blank line, the top rule, the centred title, the
    bottom rule, and a trailing blank line.
    """
    width = 50
    rule = char * width
    print(f"\n{rule}")
    print(title.center(width))
    print(f"{rule}\n")

def print_table(data):
    """格式化打印表格数据"""
    if not data:
        print("No data available")
        return
        
    # 计算每列的最大宽度
    col_widths = []
    for i in range(len(data[0])):
        col_width = max(len(str(row[i])) for row in data)
        col_widths.append(col_width)
    
    # 打印表头
    header = data[0]
    header_str = " | ".join(str(header[i]).ljust(col_widths[i]) for i in range(len(header)))
    print(header_str)
    print("-" * len(header_str))
    
    # 打印数据行
    for row in data[1:]:
        row_str = " | ".join(str(row[i]).ljust(col_widths[i]) for i in range(len(row)))
        print(row_str)

class ChartAnalyzer:
    """Extract the underlying data table from a chart image with DePlot.

    Wraps a Pix2Struct model and its processor, both loaded from the
    Hugging Face checkpoint named by ``model_name``.
    """

    # The decoded model text contains the literal token "<0x0A>" (0x0A is the
    # newline byte) where table rows break.
    _NEWLINE_TOKEN = "<0x0A>"

    def __init__(self, model_name="google/deplot"):
        """Load the model and processor.

        Args:
            model_name: Hugging Face model id or local path of a
                DePlot/Pix2Struct checkpoint. Defaults to "google/deplot".

        Raises:
            Exception: Whatever ``from_pretrained`` raises. It is logged
                with its traceback and then re-raised.
        """
        try:
            print_section("初始化模型")
            print("正在加载模型和处理器...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            print("✓ 模型加载完成")
        except Exception as e:
            print("✗ 模型加载失败")
            # logger.exception keeps the traceback in app.log, which is where
            # main() directs the user for details.
            logger.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image_path, prompt=None):
        """Run the model on one chart image and return its data table.

        The parsed table is also written, one " | "-joined row per line, to
        ``results_<YYYYmmdd_HHMMSS>.log`` in the current working directory.

        Args:
            image_path: Path to the chart image file.
            prompt: Text prompt passed to the processor. Defaults to
                "Generate underlying data table of the figure below:".

        Returns:
            A list of rows. Each row is a list of whitespace-stripped cell
            strings. Blank lines are omitted, and rows may differ in length.

        Raises:
            FileNotFoundError: If ``image_path`` is not an existing file.
            Exception: Any error from image loading, preprocessing,
                generation, or writing the result file. It is logged with
                its traceback and then re-raised.
        """
        try:
            print_section("图片处理", char='-')

            # isfile rather than exists: a directory path should fail here
            # with the intended message, not later inside PIL.
            if not os.path.isfile(image_path):
                raise FileNotFoundError(f"找不到图片文件: {image_path}")

            print(f"正在处理图片: {image_path}")

            if prompt is None:
                prompt = "Generate underlying data table of the figure below:"

            # Image.open keeps the file handle open until closed. The context
            # manager releases it once the processor has converted the image
            # to tensors.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt"
                )

            print("\n正在生成数据分析...")
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)
            output_file = self._save_results(result_array)

            print(f"\n✓ 结果已保存至: {output_file}")
            return result_array

        except Exception as e:
            print("\n✗ 处理失败")
            logger.exception("Error processing image: %s", e)
            raise

    @staticmethod
    def _parse_table(raw_output):
        """Split decoded model text into rows of stripped "|"-separated cells."""
        # Treat real newlines like the "<0x0A>" token so parsing still works
        # if a tokenizer version decodes line breaks to "\n".
        lines = raw_output.replace(ChartAnalyzer._NEWLINE_TOKEN, "\n").split("\n")
        return [
            [cell.strip() for cell in line.split("|")]
            for line in lines
            if line.strip()  # skip blank lines
        ]

    @staticmethod
    def _save_results(result_array):
        """Write rows to a timestamped results_*.log file and return its name."""
        # NOTE: second-resolution timestamp; two calls within the same second
        # overwrite the same file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f'results_{timestamp}.log'
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in result_array)
        return output_file

def main(image_path=None):
    """Run the chart-extraction pipeline end to end and print the table.

    Args:
        image_path: Chart image to analyse. If omitted, the first
            command-line argument is used when present; otherwise the
            built-in sample file name '05e57f1c9acff69f1eb6fa72d4805d0.jpg'.

    Raises:
        Exception: Any failure is logged with its traceback, a short
            user-facing message is printed, and the exception is re-raised.
    """
    try:
        print_section("图表数据提取系统", char='*')

        # Resolve the input before the (slow) model load.
        # Supports `python script.py chart.png`.
        if image_path is None:
            if len(sys.argv) > 1:
                image_path = sys.argv[1]
            else:
                image_path = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'

        analyzer = ChartAnalyzer()

        results = analyzer.process_image(image_path)

        print_section("分析结果")
        print_table(results)

        print_section("处理完成", char='*')

    except Exception as e:
        # The user is told to consult the log, so record the full traceback.
        logger.exception("Application error: %s", e)
        print("\n✗ 程序执行出错,请查看日志获取详细信息")
        raise

# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()