"""Generate question/answer flashcards from free text with a chat LLM."""

import csv
import json
import logging
from io import StringIO
from typing import Any

import torch
from pydantic import BaseModel, ValidationError, validator
from transformers import pipeline

logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

# System prompt sent ahead of every user request.
SYSTEM_PROMPT = """You are an expert flashcard creator. You always include a single knowledge item per flashcard.
- You ALWAYS include a single knowledge item per flashcard.
- You ALWAYS respond in valid JSON format. Format responses like the example below.
EXAMPLE:
[
    {"question": "What is AI?", "answer": "Artificial Intelligence."},
    {"question": "What is ML?", "answer": "Machine Learning."}
]
"""


class Card(BaseModel):
    """A single flashcard: one question and its answer."""

    question: str
    answer: str


class Message(BaseModel):
    """A chat message whose content is a list of flashcards."""

    role: str
    content: list[Card]

    @validator('content', pre=True)
    def parse_content(cls, v):
        """Decode ``content`` from a JSON string before field validation.

        Non-string values are passed through unchanged.

        Raises:
            ValueError: If ``v`` is a string that is not valid JSON.
        """
        if isinstance(v, str):
            try:
                return json.loads(v)
            except json.JSONDecodeError as e:
                raise ValueError(f"Error decoding 'content' JSON: {e}") from e
        return v

    def content_to_json(self) -> str:
        """Return the cards as an indented JSON array of objects."""
        return json.dumps([card.dict() for card in self.content], indent=2)

    def content_to_csv(self) -> str:
        """Return the cards as CSV text with a ``Question,Answer`` header."""
        output = StringIO()
        writer = csv.writer(output)
        writer.writerow(['Question', 'Answer'])  # CSV Header
        for card in self.content:
            writer.writerow([card.question, card.answer])
        return output.getvalue()


class PydanticEncoder(json.JSONEncoder):
    """JSON encoder that serializes pydantic models via ``.dict()``."""

    def default(self, obj):
        """Return a dict for ``BaseModel`` instances; defer otherwise."""
        if isinstance(obj, BaseModel):
            return obj.dict()
        return super().default(obj)


def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence, if any.

    Model replies may wrap JSON in a fenced block (three backticks,
    optionally followed by a language tag), which ``json.loads`` rejects.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line, which may carry a language tag.
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def parse_message(input_dict: dict[str, Any]) -> Message:
    """Build a ``Message`` from a chat-message dict.

    Args:
        input_dict: Mapping with a ``'role'`` key and a ``'content'`` key.
            ``content`` is either a list of card dicts or a JSON string
            encoding such a list (optionally inside a Markdown code fence).

    Returns:
        The validated ``Message``.

    Raises:
        ValueError: If a required key is missing, the content is not valid
            JSON, the content is not a list of objects, or validation of the
            cards or the message fails. The underlying error is chained as
            ``__cause__``.
    """
    try:
        role = input_dict['role']
        content = input_dict['content']
    except KeyError as e:
        raise ValueError(f"Missing required key: {str(e)}") from e

    # If content is a string, try to parse it as JSON.
    if isinstance(content, str):
        try:
            content = json.loads(_strip_code_fence(content))
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in content: {str(e)}") from e

    # A JSON object or scalar would otherwise fail later with a TypeError or
    # a misleading validation error.
    if not isinstance(content, list):
        raise ValueError(
            "Content must be a list of flashcards, got "
            f"{type(content).__name__}"
        )

    try:
        cards = [Card(**item) for item in content]
        return Message(role=role, content=cards)
    except ValidationError as e:
        raise ValueError(f"Validation error: {str(e)}") from e
    except TypeError as e:
        # Raised by ``Card(**item)`` when an item is not a mapping.
        raise ValueError(f"Each flashcard must be a JSON object: {e}") from e


class Pipeline:
    """Wrap a text-generation pipeline that turns text into flashcards."""

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the text-generation pipeline and set up the system prompt.

        Args:
            model_name: Model identifier passed to ``transformers.pipeline``.
        """
        self.model_name = model_name
        self.torch_pipe = pipeline(
            "text-generation",
            model=model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.device = self._determine_device()
        # Base conversation; only the system prompt. Per-request user
        # messages are added to a copy so calls do not affect each other.
        self.messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    def extract_flashcards(
        self, content: str = "", max_new_tokens: int = 1024
    ) -> dict[str, Any]:
        """Ask the model for flashcards covering *content*.

        Args:
            content: Source text to create flashcards from.
            max_new_tokens: Forwarded to the pipeline call.

        Returns:
            The last entry of the first result's ``"generated_text"`` list;
            ``format_flashcards`` expects it to be a dict with ``'role'``
            and ``'content'`` keys.
        """
        conversation = [*self.messages, {"role": "user", "content": content}]
        generated = self.torch_pipe(
            conversation, max_new_tokens=max_new_tokens
        )
        return generated[0]["generated_text"][-1]

    def format_flashcards(
        self, output_format: str, response: dict[str, Any]
    ) -> str:
        """Render a model response as JSON or CSV text.

        Args:
            output_format: ``"json"`` or ``"csv"`` (case-insensitive).
            response: Message dict as returned by ``extract_flashcards``.

        Returns:
            The formatted flashcards, or ``""`` for an unsupported format
            (a warning is logged in that case).

        Raises:
            ValueError: If *response* cannot be parsed into flashcards.
        """
        message = parse_message(response)
        logger.debug(
            "Parsed %d flashcards from model response", len(message.content)
        )

        fmt = output_format.lower()
        if fmt == "json":
            return message.content_to_json()
        if fmt == "csv":
            return message.content_to_csv()

        logger.warning(
            "Unsupported output format %r; returning empty output",
            output_format,
        )
        return ""

    def generate_flashcards(self, output_format: str, content: str) -> str:
        """Generate flashcards for *content* and render them.

        Combines ``extract_flashcards`` and ``format_flashcards``.
        """
        response = self.extract_flashcards(content)
        return self.format_flashcards(output_format, response)

    def parse_message(self, input_dict: dict[str, Any]) -> Message:
        """Parse a message dict; delegates to the module-level function."""
        return parse_message(input_dict)

    def _determine_device(self) -> torch.device:
        """Return the CUDA device if available, else MPS, else CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")