flashcard-studio / app /pipeline.py
Nathan Slaughter
add pipeline method
2f264ab
raw
history blame
5.17 kB
import csv
import json
import logging
import re
from io import StringIO
from typing import Any, Union

import torch
from pydantic import BaseModel, ValidationError, validator
from transformers import pipeline
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Card(BaseModel):
    """A single flashcard: one question paired with its answer."""

    # Prompt side of the card (written to the "Question" CSV column).
    question: str
    # Expected response (written to the "Answer" CSV column).
    answer: str
class Message(BaseModel):
    """A chat message whose content is a list of flashcards.

    ``content`` may be given either as a list (of ``Card`` objects or card
    mappings) or as a JSON string encoding such a list. The string form is
    decoded by ``parse_content`` before pydantic validates the field.
    """

    role: str
    content: list[Card]

    @validator('content', pre=True)
    def parse_content(cls, v):
        """Decode a JSON string into Python data; pass any other value through.

        Raises:
            ValueError: if ``v`` is a string that is not valid JSON.
        """
        if not isinstance(v, str):
            return v
        try:
            return json.loads(v)
        except json.JSONDecodeError as err:
            raise ValueError(f"Error decoding 'content' JSON: {err}") from err

    def content_to_json(self) -> str:
        """Serialize the cards as a JSON array of objects, indented by 2."""
        payload = [flashcard.dict() for flashcard in self.content]
        return json.dumps(payload, indent=2)

    def content_to_csv(self) -> str:
        """Render the cards as CSV text with a ``Question,Answer`` header row."""
        rows = [['Question', 'Answer']]
        rows.extend([flashcard.question, flashcard.answer] for flashcard in self.content)
        buffer = StringIO()
        csv.writer(buffer).writerows(rows)
        return buffer.getvalue()
class PydanticEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that serializes pydantic models via ``.dict()``.

    Use as ``json.dumps(obj, cls=PydanticEncoder)``. Objects that are not
    pydantic models are handed to the base encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, BaseModel):
            return super().default(obj)
        return obj.dict()
class Pipeline:
    """Turn free text into flashcards with a chat-tuned text-generation model.

    The model is instructed (see ``SYSTEM_PROMPT``) to reply with a JSON array
    of ``{"question": ..., "answer": ...}`` objects. The reply is parsed into a
    ``Message`` and rendered as JSON or CSV.
    """

    # System prompt sent ahead of every user request.
    SYSTEM_PROMPT = (
        "You are an expert flashcard creator. You always include a single knowledge item per flashcard.\n"
        "- You ALWAYS include a single knowledge item per flashcard.\n"
        "- You ALWAYS respond in valid JSON format.\n"
        "Format responses like the example below.\n"
        "EXAMPLE:\n"
        "[\n"
        '{"question": "What is AI?", "answer": "Artificial Intelligence."},\n'
        '{"question": "What is ML?", "answer": "Machine Learning."}\n'
        "]\n"
    )

    # Output formats accepted by format_flashcards / generate_flashcards.
    SUPPORTED_FORMATS = ("json", "csv")

    def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
        """Load *model_name* into a transformers ``text-generation`` pipeline.

        Args:
            model_name: Model identifier passed to ``transformers.pipeline``.
        """
        self.model_name = model_name
        self.torch_pipe = pipeline(
            "text-generation",
            model=model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        # NOTE(review): placement is delegated to device_map="auto"; self.device
        # is not used inside this class and is kept only for external readers.
        self.device = self._determine_device()
        # Base conversation (system prompt only). Treated as read-only: every
        # request builds its own message list from it.
        self.messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]

    @classmethod
    def _normalize_format(cls, output_format: str) -> str:
        """Return *output_format* stripped and lower-cased, or raise ValueError."""
        fmt = output_format.strip().lower()
        if fmt not in cls.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported output format {output_format!r}; "
                f"expected one of: {', '.join(cls.SUPPORTED_FORMATS)}"
            )
        return fmt

    def extract_flashcards(
        self, content: str = "", max_new_tokens: int = 1024
    ) -> dict[str, Any]:
        """Ask the model to turn *content* into flashcards.

        Args:
            content: Source text the flashcards should be drawn from.
            max_new_tokens: Generation budget forwarded to the pipeline.

        Returns:
            The last message of the generated conversation (the model's reply),
            a dict with ``"role"`` and ``"content"`` keys.
        """
        # Build a fresh conversation per call. Appending to self.messages would
        # leak earlier requests into later prompts and grow without bound.
        messages = [*self.messages, {"role": "user", "content": content}]
        outputs = self.torch_pipe(messages, max_new_tokens=max_new_tokens)
        return outputs[0]["generated_text"][-1]

    def format_flashcards(
        self, output_format: str, response: Union[dict[str, Any], str]
    ) -> str:
        """Parse a model reply and render its cards as JSON or CSV.

        Args:
            output_format: ``"json"`` or ``"csv"`` (case-insensitive).
            response: A chat message dict as returned by ``extract_flashcards``,
                or the raw reply text (treated as an assistant message).

        Returns:
            The cards serialized in the requested format.

        Raises:
            ValueError: if the format is unsupported or the reply cannot be
                parsed into valid cards.
        """
        fmt = self._normalize_format(output_format)
        if isinstance(response, str):
            response = {"role": "assistant", "content": response}
        message = self.parse_message(response)
        logger.debug("Parsed %d flashcards", len(message.content))
        if fmt == "json":
            return message.content_to_json()
        return message.content_to_csv()

    def generate_flashcards(self, output_format: str, content: str) -> str:
        """Generate flashcards for *content* and render them in *output_format*.

        The format is validated first, so an unsupported value fails before
        the (expensive) model call.

        Raises:
            ValueError: if the format is unsupported or the reply is invalid.
        """
        self._normalize_format(output_format)
        response = self.extract_flashcards(content)
        return self.format_flashcards(output_format, response)

    def parse_message(self, input_dict: dict[str, Any]) -> "Message":
        """Parse a chat message dict into a ``Message`` of ``Card`` objects.

        Delegates to the module-level ``parse_message`` so there is a single
        implementation of the parsing rules.

        Raises:
            ValueError: on a missing key, invalid JSON, or invalid cards.
        """
        return parse_message(input_dict)

    def _determine_device(self):
        """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")
def parse_message(input_dict: dict[str, any]) -> Message:
try:
# Extract the role
role: str = input_dict['role']
# Parse the content
content: str = input_dict['content']
# If content is a string, try to parse it as JSON
if isinstance(content, str):
content = json.loads(content)
# Create Card objects from the content
cards = [Card(**item) for item in content]
# Create and return the Message object
return Message(role=role, content=cards)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in content: {str(e)}")
except ValidationError as e:
raise ValueError(f"Validation error: {str(e)}")
except KeyError as e:
raise ValueError(f"Missing required key: {str(e)}")