Nathan Slaughter committed on
Commit
82915e5
·
1 Parent(s): 8428312

move parse message

Browse files
Files changed (4) hide show
  1. app/models.py +25 -1
  2. app/pipeline.py +12 -35
  3. app/processing.py +5 -7
  4. tests/test_pipeline.py +2 -18
app/models.py CHANGED
@@ -13,7 +13,7 @@ class Message(BaseModel):
13
  content: list[Card]
14
 
15
  @validator('content', pre=True)
16
- def parse_content(cls, v):
17
  if isinstance(v, str):
18
  try:
19
  content_list = json.loads(v)
@@ -44,3 +44,27 @@ class PydanticEncoder(json.JSONEncoder):
44
  if isinstance(obj, BaseModel):
45
  return obj.dict()
46
  return super().default(obj)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  content: list[Card]
14
 
15
  @validator('content', pre=True)
16
+ def parse_content(cls, v: str) -> 'Message':
17
  if isinstance(v, str):
18
  try:
19
  content_list = json.loads(v)
 
44
  if isinstance(obj, BaseModel):
45
  return obj.dict()
46
  return super().default(obj)
47
+
48
+ def parse_message(input_dict: dict[str, any]) -> Message:
49
+ try:
50
+ # Extract the role
51
+ role: str = input_dict['role']
52
+
53
+ # Parse the content
54
+ content: str = input_dict['content']
55
+
56
+ # If content is a string, try to parse it as JSON
57
+ if isinstance(content, str):
58
+ content = json.loads(content)
59
+
60
+ # Create Card objects from the content
61
+ cards = [Card(**item) for item in content]
62
+
63
+ # Create and return the Message object
64
+ return Message(role=role, content=cards)
65
+ except json.JSONDecodeError as e:
66
+ raise ValueError(f"Invalid JSON in content: {str(e)}")
67
+ except ValidationError as e:
68
+ raise ValueError(f"Validation error: {str(e)}")
69
+ except KeyError as e:
70
+ raise ValueError(f"Missing required key: {str(e)}")
app/pipeline.py CHANGED
@@ -5,7 +5,7 @@ import logging
5
  import torch
6
  from transformers import pipeline
7
 
8
- from .models import Card, Message, ValidationError
9
 
10
  logger = logging.getLogger(__name__)
11
  logging.basicConfig(filename="pipeline.log", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S")
@@ -21,7 +21,7 @@ class Pipeline:
21
  self.device = self._determine_device()
22
  logger.info(f"device type: {self.device}")
23
  self.messages = [
24
- {"role": "system", "content": """You are an expert flashcard creator. You always include a single knowledge item per flashcard.
25
  - You ALWAYS include a single knowledge item per flashcard.
26
  - You ALWAYS respond in valid JSON format.
27
 
@@ -38,47 +38,24 @@ class Pipeline:
38
  def extract_flashcards(self, content: str = "", max_new_tokens: int = 1024) -> str:
39
  user_prompt = {"role": "user", "content": content}
40
  self.messages.append(user_prompt)
41
- response_message = self.torch_pipe(
42
- self.messages,
43
- max_new_tokens=max_new_tokens
44
- )[0]["generated_text"][-1]
45
- return response_message
 
 
 
 
46
 
47
  def generate_flashcards(self, output_format: str, content: str) -> str:
48
  response = self.extract_flashcards(content)
49
  return format_flashcards(output_format, response)
50
 
51
- def _determine_device(self):
52
  if torch.cuda.is_available():
53
  return torch.device("cuda")
54
  elif torch.backends.mps.is_available():
55
  return torch.device("mps")
56
  else:
57
  return torch.device("cpu")
58
-
59
- def parse_message(input_dict: dict[str, any]) -> Message:
60
- try:
61
- # Extract the role
62
- role: str = input_dict['role']
63
-
64
- # Parse the content
65
- content: str = input_dict['content']
66
-
67
- # If content is a string, try to parse it as JSON
68
- if isinstance(content, str):
69
- content = json.loads(content)
70
-
71
- # Create Card objects from the content
72
- cards = [Card(**item) for item in content]
73
-
74
- # Create and return the Message object
75
- return Message(role=role, content=cards)
76
- except json.JSONDecodeError as e:
77
- logger.error(f"Invalid JSON in content: {str(e)}")
78
- raise ValueError(f"Invalid JSON in content: {str(e)}")
79
- except ValidationError as e:
80
- logger.error(f"Validation error: {str(e)}")
81
- raise ValueError(f"Validation error: {str(e)}")
82
- except KeyError as e:
83
- logger.error(f"Missing required key: {str(e)}")
84
- raise ValueError(f"Missing required key: {str(e)}")
 
5
  import torch
6
  from transformers import pipeline
7
 
8
+ from .processing import format_flashcards
9
 
10
  logger = logging.getLogger(__name__)
11
  logging.basicConfig(filename="pipeline.log", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S")
 
21
  self.device = self._determine_device()
22
  logger.info(f"device type: {self.device}")
23
  self.messages = [
24
+ {"role": "system", "content": """You are an expert flashcard creator.
25
  - You ALWAYS include a single knowledge item per flashcard.
26
  - You ALWAYS respond in valid JSON format.
27
 
 
38
  def extract_flashcards(self, content: str = "", max_new_tokens: int = 1024) -> str:
39
  user_prompt = {"role": "user", "content": content}
40
  self.messages.append(user_prompt)
41
+ try:
42
+ response_message = self.torch_pipe(
43
+ self.messages,
44
+ max_new_tokens=max_new_tokens
45
+ )[0]["generated_text"][-1]
46
+ return response_message
47
+ except Exception as e:
48
+ logger.error(f"Error extracting flashcards: {str(e)}")
49
+ raise ValueError(f"Error extraction flashcards: {str(e)}")
50
 
51
  def generate_flashcards(self, output_format: str, content: str) -> str:
52
  response = self.extract_flashcards(content)
53
  return format_flashcards(output_format, response)
54
 
55
+ def _determine_device(self) -> torch.device:
56
  if torch.cuda.is_available():
57
  return torch.device("cuda")
58
  elif torch.backends.mps.is_available():
59
  return torch.device("mps")
60
  else:
61
  return torch.device("cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/processing.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import pymupdf4llm
3
 
 
 
4
  def process_pdf(pdf_path: str) -> str:
5
  """
6
  Extracts text from a PDF file using pymupdf4llm.
@@ -28,14 +30,12 @@ def process_file(file_obj, output_format: str, pipeline) -> str:
28
  """
29
  file_path = file_obj.name
30
  file_ext = os.path.splitext(file_path)[1].lower()
31
-
32
  if file_ext == '.pdf':
33
  text = process_pdf(file_path)
34
  elif file_ext in ['.txt', '.md']:
35
  text = read_text_file(file_path)
36
  else:
37
  raise ValueError("Unsupported file type.")
38
-
39
  flashcards = pipeline.generate_flashcards(output_format, text)
40
  return flashcards
41
 
@@ -49,16 +49,14 @@ def process_text_input(output_format: str, input_text: str) -> str:
49
  flashcards = pipeline.generate_flashcards(output_format, input_text)
50
  return flashcards
51
 
52
-
53
- def format_flashcards(self, output_format: str, response: str) -> str:
54
  output = ""
55
  try :
56
  message = parse_message(response)
57
- logger.debug("after parse_obj_as")
58
- except ValidationError as e:
59
  raise e
60
  if output_format.lower() == "json":
61
- output = message.content_to_json()
62
  elif output_format.lower() == "csv":
63
  output = message.content_to_csv()
64
  return output
 
1
  import os
2
  import pymupdf4llm
3
 
4
+ from .models import parse_message
5
+
6
  def process_pdf(pdf_path: str) -> str:
7
  """
8
  Extracts text from a PDF file using pymupdf4llm.
 
30
  """
31
  file_path = file_obj.name
32
  file_ext = os.path.splitext(file_path)[1].lower()
 
33
  if file_ext == '.pdf':
34
  text = process_pdf(file_path)
35
  elif file_ext in ['.txt', '.md']:
36
  text = read_text_file(file_path)
37
  else:
38
  raise ValueError("Unsupported file type.")
 
39
  flashcards = pipeline.generate_flashcards(output_format, text)
40
  return flashcards
41
 
 
49
  flashcards = pipeline.generate_flashcards(output_format, input_text)
50
  return flashcards
51
 
52
+ def format_flashcards(output_format: str, response: str) -> str:
 
53
  output = ""
54
  try :
55
  message = parse_message(response)
56
+ except Exception as e:
 
57
  raise e
58
  if output_format.lower() == "json":
59
+ output:str = message.content_to_json()
60
  elif output_format.lower() == "csv":
61
  output = message.content_to_csv()
62
  return output
tests/test_pipeline.py CHANGED
@@ -3,8 +3,8 @@ from unittest.mock import Mock, patch
3
  import json
4
  from io import StringIO
5
  from pydantic import ValidationError
6
- from app.pipeline import Pipeline, Message, Card, parse_message
7
- from app.models import PydanticEncoder
8
 
9
  # Tests for Pipeline class
10
  @pytest.fixture
@@ -13,22 +13,6 @@ def mock_pipeline():
13
  mock_pipe.return_value = Mock()
14
  yield Pipeline("mock_model")
15
 
16
- # def test_extract_flashcards(mock_pipeline):
17
- # mock_pipeline.torch_pipe.return_value = [{"generated_text": [{"role": "assistant", "content": '[{"question": "Q", "answer": "A"}]'}]}]
18
- # response = mock_pipeline.extract_flashcards("Test content")
19
- # assert isinstance(response, dict)
20
- # assert "content" in response
21
-
22
- # def test_format_flashcards_csv(mock_pipeline):
23
- # response = {"role": "assistant", "content": '[{"question": "Q", "answer": "A"}]'}
24
- # formatted = mock_pipeline.format_flashcards("csv", response)
25
- # assert formatted.strip() == "Question,Answer\nQ,A"
26
-
27
- # def test_generate_flashcards(mock_pipeline):
28
- # mock_pipeline.extract_flashcards.return_value = {"role": "assistant", "content": '[{"question": "Q", "answer": "A"}]'}
29
- # result = mock_pipeline.generate_flashcards("json", "Test content")
30
- # assert json.loads(result) == [{"question": "Q", "answer": "A"}]
31
-
32
  # Tests for parse_message function
33
  def test_parse_message_valid_input():
34
  input_dict = {
 
3
  import json
4
  from io import StringIO
5
  from pydantic import ValidationError
6
+ from app.pipeline import Pipeline
7
+ from app.models import PydanticEncoder, Message, Card, parse_message
8
 
9
  # Tests for Pipeline class
10
  @pytest.fixture
 
13
  mock_pipe.return_value = Mock()
14
  yield Pipeline("mock_model")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Tests for parse_message function
17
  def test_parse_message_valid_input():
18
  input_dict = {