# src/utils/conversation_summarizer.py
from typing import List, Dict
from transformers import pipeline
import numpy as np
from datetime import datetime
from config.config import settings


class ConversationSummarizer:
    """Summarize chat conversations and derive simple statistics from them.

    Wraps a ``transformers`` summarization pipeline. Conversations are lists
    of message dicts; the keys this class reads are ``role``, ``content``,
    ``timestamp`` and ``sources``, and every one of them is optional.
    """

    def __init__(
        self,
        model_name: str = None,
        max_length: int = None,
        min_length: int = None
    ):
        """
        Initialize the summarizer

        Args:
            model_name (str, optional): Override default model from config
            max_length (int, optional): Override default max_length from config
            min_length (int, optional): Override default min_length from config
        """
        config = settings.SUMMARIZER_CONFIG

        # Use provided values or fall back to config values. The lengths are
        # compared against None (not truthiness) so that an explicit 0, e.g.
        # min_length=0, is honoured instead of being replaced by the default.
        self.model_name = model_name or config['model_name']
        self.max_length = (
            max_length if max_length is not None else config['max_length']
        )
        self.min_length = (
            min_length if min_length is not None else config['min_length']
        )

        # Initialize the summarizer with config settings
        self.summarizer = pipeline(
            "summarization",
            model=self.model_name,
            device=config['device'],
            model_kwargs=config['model_kwargs']
        )

    async def summarize_conversation(
        self,
        messages: List[Dict],
        include_metadata: bool = True
    ) -> Dict:
        """
        Summarize a conversation and provide key insights

        Args:
            messages (List[Dict]): Conversation messages, oldest first.
            include_metadata (bool): When False, ``metadata`` is returned as
                an empty dict.

        Returns:
            Dict: ``summary`` (str), ``key_insights`` (dict) and ``metadata``
            (dict). For an empty conversation the summary is ``''``, the
            insights are all zero/empty and the model is not invoked.
        """
        if not messages:
            # Nothing to feed the model; return a well-formed empty result
            # rather than summarizing an empty string.
            return {
                'summary': '',
                'key_insights': self._extract_insights([]),
                'metadata': {}
            }

        # Format conversation for summarization
        formatted_convo = self._format_conversation(messages)

        # NOTE: the pipeline call is synchronous; it runs on the calling
        # thread even though this method is a coroutine.
        summary = self._run_summarizer(
            formatted_convo, self.max_length, self.min_length
        )

        # Extract key insights
        insights = self._extract_insights(messages)

        # Generate metadata if requested
        metadata = self._generate_metadata(messages) if include_metadata else {}

        return {
            'summary': summary,
            'key_insights': insights,
            'metadata': metadata
        }

    def _run_summarizer(self, text: str, max_length: int, min_length: int) -> str:
        """Run the pipeline on ``text`` and return the generated summary text."""
        result = self.summarizer(
            text,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            # Ask the tokenizer to cut over-long input instead of letting the
            # model fail on conversations beyond its input limit.
            truncation=True
        )
        return result[0]['summary_text']

    @staticmethod
    def _message_text(message: Dict) -> str:
        """Return the message content as a string ('' when missing or None)."""
        content = message.get('content')
        return '' if content is None else str(content)

    def _format_conversation(self, messages: List[Dict]) -> str:
        """Format conversation for summarization

        Produces one ``"<role>: <content>"`` line per message; a missing role
        is rendered as ``unknown`` and missing/None content as an empty string.
        """
        return "\n".join(
            f"{msg.get('role', 'unknown')}: {self._message_text(msg)}"
            for msg in messages
        )

    def _extract_insights(self, messages: List[Dict]) -> Dict:
        """Extract key insights from conversation

        Returns:
            Dict: ``message_distribution`` (user/assistant counts),
            ``average_message_length`` (characters, truncated to int),
            ``main_topics`` (list of str) and ``total_messages``.
        """
        # Count message types
        message_counts = {
            'user': sum(1 for m in messages if m.get('role') == 'user'),
            'assistant': sum(1 for m in messages if m.get('role') == 'assistant')
        }

        # Calculate average message length. np.mean of an empty list is NaN,
        # which int() cannot convert, so an empty conversation averages to 0.
        lengths = [len(self._message_text(m)) for m in messages]
        avg_length = int(np.mean(lengths)) if lengths else 0

        # Extract main topics (simplified)
        topics = self._extract_topics(messages)

        return {
            'message_distribution': message_counts,
            'average_message_length': avg_length,
            'main_topics': topics,
            'total_messages': len(messages)
        }

    def _extract_topics(self, messages: List[Dict]) -> List[str]:
        """Extract main topics from conversation

        Summarizes the concatenated message contents and splits the result
        on ``'. '``. Returns an empty list when there is no text to summarize.
        """
        # Combine all messages
        full_text = " ".join(self._message_text(m) for m in messages)
        if not full_text.strip():
            return []

        # Use the summarizer to extract main points
        short_summary = self._run_summarizer(full_text, 50, 10)

        # Drop the empty / whitespace-only fragments that splitting leaves
        # behind (e.g. after a trailing ". ").
        return [
            part.strip() for part in short_summary.split('. ') if part.strip()
        ]

    def _generate_metadata(self, messages: List[Dict]) -> Dict:
        """Generate conversation metadata

        Returns:
            Dict: ``start_time`` / ``end_time`` (the raw ``timestamp`` values
            of the first and last message, or None), ``duration_minutes`` and
            ``sources_used``. Empty dict for an empty conversation.
        """
        if not messages:
            return {}

        return {
            'start_time': messages[0].get('timestamp', None),
            'end_time': messages[-1].get('timestamp', None),
            'duration_minutes': self._calculate_duration(messages),
            'sources_used': self._extract_sources(messages)
        }

    @staticmethod
    def _parse_timestamp(value) -> datetime:
        """Convert a timestamp value to ``datetime``.

        Accepts ``datetime`` instances and ISO-8601 strings. A trailing ``Z``
        is rewritten to ``+00:00`` because ``datetime.fromisoformat`` only
        understands it from Python 3.11 on.

        Raises:
            TypeError, ValueError: If the value cannot be parsed.
        """
        if isinstance(value, datetime):
            return value
        if isinstance(value, str) and value.endswith('Z'):
            value = value[:-1] + '+00:00'
        return datetime.fromisoformat(value)

    def _calculate_duration(self, messages: List[Dict]) -> float:
        """Calculate conversation duration in minutes

        Uses the ``timestamp`` of the first and last message. This is a
        best-effort value: 0.0 is returned when the list is empty or either
        timestamp is missing, unparseable, or not comparable with the other
        (e.g. one naive and one timezone-aware).
        """
        try:
            start_time = self._parse_timestamp(messages[0].get('timestamp'))
            end_time = self._parse_timestamp(messages[-1].get('timestamp'))
            return (end_time - start_time).total_seconds() / 60
        except (IndexError, AttributeError, TypeError, ValueError):
            return 0.0

    def _extract_sources(self, messages: List[Dict]) -> List[str]:
        """Extract unique sources used in conversation

        Each entry of a message's ``sources`` may be a dict with a
        ``filename`` key or a plain string. Entries without a usable filename
        are skipped. The result keeps first-seen order.
        """
        # dict keys give de-duplication while preserving insertion order.
        seen = {}
        for message in messages:
            entries = message.get('sources')
            if not entries:
                continue
            # A lone source (not wrapped in a list) is treated as one entry.
            if isinstance(entries, (str, dict)):
                entries = [entries]
            for source in entries:
                if isinstance(source, dict):
                    filename = source.get('filename')
                elif isinstance(source, str):
                    filename = source
                else:
                    filename = None
                if filename:
                    seen[filename] = None
        return list(seen)