Spaces:
Sleeping
Sleeping
import csv
import json
import logging
from collections.abc import Mapping
from io import StringIO
from typing import Any

import torch
from pydantic import BaseModel, ValidationError, validator
from transformers import pipeline
# Module-scoped logger, named after this module so records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class Card(BaseModel):
    """A single flashcard: one question paired with its answer."""

    question: str  # prompt side of the card
    answer: str  # answer side of the card
class Message(BaseModel):
    """A chat message whose content is a list of flashcards."""

    role: str  # chat role string taken from the message dict's "role" key
    content: list[Card]

    # Registered as a pre-validator so it runs before ``list[Card]`` coercion.
    # Without this decorator pydantic never invoked the method, and a
    # JSON-string ``content`` was rejected instead of being decoded.
    @validator("content", pre=True)
    def parse_content(cls, v):
        """Decode ``content`` from a JSON string before field validation.

        Args:
            v: Raw ``content`` value; a ``str`` is JSON-decoded, anything else
                is passed through unchanged for normal field validation.

        Returns:
            The decoded value (or ``v`` unchanged when it is not a string).

        Raises:
            ValueError: If ``v`` is a string that is not valid JSON; pydantic
                reports it as a validation error.
        """
        if isinstance(v, str):
            try:
                content_list = json.loads(v)
                return content_list
            except json.JSONDecodeError as e:
                raise ValueError(f"Error decoding 'content' JSON: {e}") from e
        return v

    def content_to_json(self) -> str:
        """Serialise the cards as a pretty-printed (indent=2) JSON array of objects."""
        return json.dumps([card.dict() for card in self.content], indent=2)

    def content_to_csv(self) -> str:
        """Serialise the cards as CSV with a ``Question,Answer`` header row."""
        output = StringIO()
        writer = csv.writer(output)
        writer.writerow(['Question', 'Answer'])  # CSV Header
        for card in self.content:
            writer.writerow([card.question, card.answer])
        return output.getvalue()
class PydanticEncoder(json.JSONEncoder):
    """JSON encoder that serialises pydantic models through their ``dict()`` form.

    Any other non-serialisable object is handed to the base encoder, which
    raises ``TypeError`` as usual.
    """

    def default(self, obj):
        if not isinstance(obj, BaseModel):
            return super().default(obj)
        return obj.dict()
class Pipeline:
    """Turn free text into flashcards using a Hugging Face text-generation model.

    The model is prompted with a fixed system message asking for a JSON list of
    ``{"question": ..., "answer": ...}`` objects; the reply is parsed into a
    ``Message`` and rendered as JSON or CSV.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
        """Load the text-generation pipeline.

        Args:
            model_name: Model id handed to ``transformers.pipeline``.
        """
        # Use the caller's model id; it was previously ignored in favour of a
        # hard-coded default.
        self.torch_pipe = pipeline(
            "text-generation",
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        # NOTE(review): recorded for reference only; model placement is decided
        # by ``device_map="auto"`` above, not by this value.
        self.device = self._determine_device()
        # Base conversation: the system prompt only. It is not mutated per
        # request; each call to ``extract_flashcards`` sends a copy of it plus
        # exactly one user turn.
        self.messages = [
            {"role": "system", "content": """You are an expert flashcard creator. You always include a single knowledge item per flashcard.
- You ALWAYS include a single knowledge item per flashcard.
- You ALWAYS respond in valid JSON format.
Format responses like the example below.
EXAMPLE:
[
{"question": "What is AI?", "answer": "Artificial Intelligence."},
{"question": "What is ML?", "answer": "Machine Learning."}
]
"""},
        ]

    def extract_flashcards(self, content: str = "", max_new_tokens: int = 1024) -> dict[str, Any]:
        """Ask the model to turn *content* into flashcards.

        Args:
            content: Source text to extract flashcards from.
            max_new_tokens: Generation budget passed to the pipeline.

        Returns:
            The last entry of the pipeline's ``generated_text``. For chat-style
            input this is presumably the assistant reply as a
            ``{"role": ..., "content": ...}`` dict (TODO confirm against the
            installed transformers version).
        """
        # Build a fresh conversation per call. Appending to ``self.messages``
        # made every later request re-send all previous user texts as
        # consecutive user turns, leaking earlier inputs and growing the prompt.
        conversation = [*self.messages, {"role": "user", "content": content}]
        outputs = self.torch_pipe(conversation, max_new_tokens=max_new_tokens)
        return outputs[0]["generated_text"][-1]

    def format_flashcards(self, output_format: str, response: dict[str, Any]) -> str:
        """Render a model reply as JSON or CSV flashcards.

        Args:
            output_format: ``"json"`` or ``"csv"`` (case-insensitive).
            response: Mapping with ``"role"`` and ``"content"`` keys, where
                ``content`` is a JSON string (or list) of question/answer objects.

        Returns:
            The formatted flashcards, or ``""`` when *output_format* is not
            recognised (a warning is logged).

        Raises:
            ValueError: If *response* cannot be parsed into a ``Message``.
        """
        message = self.parse_message(response)
        fmt = output_format.lower()
        if fmt == "json":
            return message.content_to_json()
        if fmt == "csv":
            return message.content_to_csv()
        logger.warning("Unsupported output format %r; returning empty output", output_format)
        return ""

    def generate_flashcards(self, output_format: str, content: str) -> str:
        """Extract flashcards from *content* and render them in *output_format*."""
        response = self.extract_flashcards(content)
        return self.format_flashcards(output_format, response)

    def parse_message(self, input_dict: dict[str, Any]) -> Message:
        """Parse a chat message mapping into a ``Message``.

        Delegates to the module-level ``parse_message`` so the two entry points
        share one implementation instead of two diverging copies.

        Raises:
            ValueError: On missing keys, invalid JSON, or content that does not
                validate as a list of cards.
        """
        # Inside a method body this name resolves to the module-level function.
        return parse_message(input_dict)

    def _determine_device(self) -> torch.device:
        """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")
def parse_message(input_dict: dict[str, Any]) -> Message:
    """Build a ``Message`` from a chat-style message mapping.

    Args:
        input_dict: Mapping with ``"role"`` and ``"content"`` keys. ``content``
            may be a JSON string encoding a list of ``{"question", "answer"}``
            objects, or that list (or tuple) already decoded.

    Returns:
        A validated ``Message`` whose ``content`` is a list of ``Card``.

    Raises:
        ValueError: If a key is missing, the JSON is invalid, the content is not
            a list of objects, or pydantic validation fails. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        role = input_dict['role']
        content = input_dict['content']
    except KeyError as e:
        raise ValueError(f"Missing required key: {str(e)}") from e

    # If content is a string, try to parse it as JSON (model output often has
    # surrounding whitespace/newlines).
    if isinstance(content, str):
        try:
            content = json.loads(content.strip())
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in content: {str(e)}") from e

    # A JSON object would iterate as its keys and ``Card(**"key")`` raises
    # TypeError, escaping the ValueError contract; reject bad shapes up front.
    if not isinstance(content, (list, tuple)) or not all(
        isinstance(item, Mapping) for item in content
    ):
        raise ValueError("Validation error: content must be a list of question/answer objects")

    try:
        cards = [Card(**item) for item in content]
        return Message(role=role, content=cards)
    except ValidationError as e:
        raise ValueError(f"Validation error: {str(e)}") from e