Nathan Slaughter committed on
Commit
2f264ab
·
1 Parent(s): 4d17caa

add pipeline method

Browse files
app/interface.py CHANGED
@@ -1,10 +1,10 @@
1
  import gradio as gr
2
- from .models import LanguageModel
3
  from .processing import process_file, process_text_input
4
 
5
  def create_interface():
6
  # Initialize the language model
7
- language_model = LanguageModel()
8
 
9
  # Define the Output Format Selector
10
  output_format_selector = gr.Radio(
@@ -18,18 +18,18 @@ def create_interface():
18
  flashcard_output_file = gr.Textbox(
19
  label="Flashcards",
20
  lines=20,
21
- placeholder="Extracted flashcards will appear here..."
22
  )
23
  flashcard_output_text = gr.Textbox(
24
  label="Flashcards",
25
  lines=20,
26
- placeholder="Extracted flashcards will appear here..."
27
  )
28
 
29
  # Define the Gradio interface function for File Upload
30
  def handle_file_upload(file_obj, output_format):
31
  try:
32
- flashcards = process_file(file_obj, output_format, language_model)
33
  return flashcards
34
  except ValueError as ve:
35
  return str(ve)
@@ -37,16 +37,16 @@ def create_interface():
37
  # Define the Gradio interface function for Text Input
38
  def handle_text_input(input_text, output_format):
39
  try:
40
- flashcards = process_text_input(input_text, output_format, language_model)
41
  return flashcards
42
  except ValueError as ve:
43
  return str(ve)
44
 
45
  # Create the Gradio Tabs
46
  with gr.Blocks() as interface:
47
- gr.Markdown("# Flashcard Extraction Tool")
48
  gr.Markdown(
49
- "Extract flashcards from uploaded files or directly input text. Choose your preferred output format."
50
  )
51
  with gr.Tab("Upload File"):
52
  with gr.Row():
 
1
  import gradio as gr
2
+ from .pipeline import Pipeline
3
  from .processing import process_file, process_text_input
4
 
5
  def create_interface():
6
  # Initialize the language model
7
+ language_model = Pipeline()
8
 
9
  # Define the Output Format Selector
10
  output_format_selector = gr.Radio(
 
18
  flashcard_output_file = gr.Textbox(
19
  label="Flashcards",
20
  lines=20,
21
+ placeholder="Your flashcards will appear here..."
22
  )
23
  flashcard_output_text = gr.Textbox(
24
  label="Flashcards",
25
  lines=20,
26
+ placeholder="Your flashcards will appear here..."
27
  )
28
 
29
  # Define the Gradio interface function for File Upload
30
  def handle_file_upload(file_obj, output_format):
31
  try:
32
+ flashcards = process_file(file_obj, output_format, Pipeline())
33
  return flashcards
34
  except ValueError as ve:
35
  return str(ve)
 
37
  # Define the Gradio interface function for Text Input
38
  def handle_text_input(input_text, output_format):
39
  try:
40
+ flashcards = process_text_input(input_text, output_format, Pipeline())
41
  return flashcards
42
  except ValueError as ve:
43
  return str(ve)
44
 
45
  # Create the Gradio Tabs
46
  with gr.Blocks() as interface:
47
+ gr.Markdown("# Flashcard Studio")
48
  gr.Markdown(
49
+ "Make flashcards from uploaded files or directly input text. Choose your preferred output format."
50
  )
51
  with gr.Tab("Upload File"):
52
  with gr.Row():
app/models.py DELETED
@@ -1,31 +0,0 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
-
4
- class LanguageModel:
5
- def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
6
- self.device = self._determine_device()
7
- self.model = AutoModelForCausalLM.from_pretrained(
8
- model_name,
9
- torch_dtype="auto",
10
- device_map="auto"
11
- )
12
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
13
-
14
- def _determine_device(self):
15
- if torch.cuda.is_available():
16
- return torch.device("cuda")
17
- elif torch.backends.mps.is_available():
18
- return torch.device("mps")
19
- else:
20
- return torch.device("cpu")
21
-
22
- def generate_flashcards(self, prompt: str, max_new_tokens: int = 1024) -> str:
23
- inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
24
- with torch.no_grad():
25
- output_ids = self.model.generate(
26
- inputs.input_ids,
27
- max_new_tokens=max_new_tokens,
28
- do_sample=True
29
- )
30
- response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
31
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/pipeline.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import StringIO
2
+ import csv
3
+ import json
4
+ import logging
5
+
6
+ import torch
7
+ from transformers import pipeline
8
+ from pydantic import BaseModel, ValidationError, validator
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class Card(BaseModel):
    """A single flashcard: one question paired with its answer."""

    # Text shown on the front of the card.
    question: str
    # Text shown on the back of the card.
    answer: str
15
+
16
class Message(BaseModel):
    """A chat message whose content is a list of flashcards.

    ``content`` may be supplied as a list or as a JSON string; a string is
    decoded by ``parse_content`` before field validation runs.
    """

    role: str
    content: list[Card]

    @validator('content', pre=True)
    def parse_content(cls, v):
        """Decode ``content`` from JSON when it arrives as a string.

        Non-string values are handed back unchanged.

        Raises:
            ValueError: If the string is not valid JSON.
        """
        if not isinstance(v, str):
            return v
        try:
            return json.loads(v)
        except json.JSONDecodeError as e:
            raise ValueError(f"Error decoding 'content' JSON: {e}") from e

    def content_to_json(self) -> str:
        """Return the cards as a JSON array string indented by two spaces."""
        cards_as_dicts = [card.dict() for card in self.content]
        return json.dumps(cards_as_dicts, indent=2)

    def content_to_csv(self) -> str:
        """Return the cards as CSV text preceded by a Question/Answer header row."""
        buffer = StringIO()
        rows = [['Question', 'Answer']]  # CSV header
        rows.extend([card.question, card.answer] for card in self.content)
        csv.writer(buffer).writerows(rows)
        return buffer.getvalue()
40
+
41
class PydanticEncoder(json.JSONEncoder):
    """JSON encoder that serializes pydantic models through their ``dict()`` form."""

    def default(self, obj):
        """Return ``obj.dict()`` for models; defer any other type to the base encoder."""
        if not isinstance(obj, BaseModel):
            return super().default(obj)
        return obj.dict()
46
+
47
class Pipeline:
    """Wrap a Hugging Face text-generation pipeline that turns text into flashcards.

    The instance holds the loaded model pipeline and a fixed system prompt.
    Each request is sent as a fresh conversation made of that system prompt
    plus the caller's text, so one request never sees another's content.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
        """Load the chat model and set up the system prompt.

        Args:
            model_name: Hugging Face model identifier to load.
        """
        self.torch_pipe = pipeline(
            "text-generation",
            # Honor the argument; the model id used to be hard-coded here,
            # which made ``model_name`` a silent no-op.
            model=model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        # Kept for callers that inspect it; placement itself is delegated to
        # ``device_map="auto"`` above.
        self.device = self._determine_device()
        self.messages = [
            {"role": "system", "content": """You are an expert flashcard creator. You always include a single knowledge item per flashcard.
- You ALWAYS include a single knowledge item per flashcard.
- You ALWAYS respond in valid JSON format.

Format responses like the example below.

EXAMPLE:
[
{"question": "What is AI?", "answer": "Artificial Intelligence."},
{"question": "What is ML?", "answer": "Machine Learning."}
]
"""},
        ]

    def extract_flashcards(self, content: str = "", max_new_tokens: int = 1024) -> dict:
        """Run the model on ``content`` and return its reply message.

        Args:
            content: Source text to turn into flashcards.
            max_new_tokens: Generation budget forwarded to the pipeline.

        Returns:
            The last entry of ``generated_text`` from the first pipeline
            result (presumably the assistant message dict with ``role`` and
            ``content`` keys — confirm against the transformers version used).
        """
        # Build the conversation per call instead of appending to
        # ``self.messages``: appending made every later request carry all
        # earlier user texts and grew the context without bound.
        conversation = [*self.messages, {"role": "user", "content": content}]
        results = self.torch_pipe(conversation, max_new_tokens=max_new_tokens)
        return results[0]["generated_text"][-1]

    def format_flashcards(self, output_format: str, response: dict) -> str:
        """Render a model reply as JSON or CSV text.

        Args:
            output_format: ``"json"`` or ``"csv"`` (case-insensitive).
            response: Reply message as returned by ``extract_flashcards``.

        Returns:
            The flashcards serialized in the requested format.

        Raises:
            ValueError: If ``output_format`` is not supported, or if the
                reply cannot be parsed into flashcards.
        """
        fmt = output_format.lower()
        # Reject unknown formats up front rather than returning an empty
        # string that looks like a successful, card-less result.
        if fmt not in ("json", "csv"):
            raise ValueError(f"Unsupported output format: {output_format!r}")

        # parse_message already reports every failure as ValueError, so no
        # extra try/except is needed here.
        message = parse_message(response)
        logger.debug("Parsed %d flashcards from model response", len(message.content))

        if fmt == "json":
            return message.content_to_json()
        return message.content_to_csv()

    def generate_flashcards(self, output_format: str, content: str) -> str:
        """Extract flashcards from ``content`` and return them in ``output_format``.

        Raises:
            ValueError: Propagated from ``format_flashcards``.
        """
        response = self.extract_flashcards(content)
        return self.format_flashcards(output_format, response)

    def parse_message(self, input_dict: dict[str, object]) -> "Message":
        """Parse a raw message dict into a ``Message``.

        Thin wrapper around the module-level ``parse_message`` so there is a
        single parsing implementation to maintain.

        Raises:
            ValueError: If the dict cannot be parsed into flashcards.
        """
        return parse_message(input_dict)

    def _determine_device(self):
        """Return the preferred torch device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
128
+
129
def parse_message(input_dict: dict[str, object]) -> "Message":
    """Convert a raw chat message dict into a validated ``Message``.

    Args:
        input_dict: Mapping with a ``role`` entry and a ``content`` entry.
            ``content`` is either a JSON string encoding a list of card
            objects or an already-decoded list of them.

    Returns:
        A ``Message`` whose content is a list of ``Card`` objects.

    Raises:
        ValueError: If a required key is missing, ``content`` is not valid
            JSON, or the decoded content is not a list of card objects.
    """
    # Each failure mode gets its own narrow try block so an error is
    # attributed to the step that caused it.
    try:
        role = input_dict['role']
        content = input_dict['content']
    except KeyError as e:
        raise ValueError(f"Missing required key: {str(e)}") from e

    # If content is a string, try to parse it as JSON
    if isinstance(content, str):
        try:
            content = json.loads(content)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in content: {str(e)}") from e

    try:
        cards = [Card(**item) for item in content]
        return Message(role=role, content=cards)
    except ValidationError as e:
        raise ValueError(f"Validation error: {str(e)}") from e
    except TypeError as e:
        # Valid JSON of the wrong shape (null, a number, a list of strings)
        # fails on iteration or ``**item``; report it as ValueError like
        # every other parse failure instead of leaking TypeError.
        raise ValueError(f"Content must be a list of card objects: {e}") from e
app/processing.py CHANGED
@@ -27,13 +27,12 @@ def format_prompt(output_format: str) -> str:
27
  Formats the prompt based on the output type.
28
  """
29
  if output_format.lower() == "json":
30
- return """You only respond with cards in JSON format. Follow the example below.
31
 
32
  EXAMPLE:
33
  [
34
  {"question": "What is AI?", "answer": "Artificial Intelligence."},
35
  {"question": "What is ML?", "answer": "Machine Learning."}
36
- ...
37
  ]
38
  """
39
  elif output_format.lower() == "csv":
@@ -42,32 +41,29 @@ def format_prompt(output_format: str) -> str:
42
  EXAMPLE:
43
  "What is AI?", "Artificial Intelligence."
44
  "What is ML?", "Machine Learning."
45
- ...
46
  """
47
 
48
- def extract_flashcards(text: str, output_format: str, language_model: str) -> str:
49
- """
50
- Extracts flashcards from the input text using the LLM and formats them in CSV or JSON.
51
- """
52
- prompt = f"""You are an expert flashcard creator. You always include a single knowledge item per flashcard.
53
 
54
- {format_prompt(output_format)}
55
 
56
 
57
- Extract flashcards from the user's text:
58
 
59
- {text}
60
 
61
- Do not include the prompt or any other unnecessary information in the flashcards.
62
- Do not include triple ticks (```) or any other code blocks in the flashcards.
63
- """
64
- # TODO:
65
- # see https://qwen.readthedocs.io/en/latest/inference/chat.html
66
- # e.g. pipeline = pipeline("text-generation", model="Qwen/Qwen2.5-7B-Instruct")
67
- response = language_model.generate_flashcards(prompt)
68
- return response
69
 
70
- def process_file(file_obj, output_format: str, language_model) -> str:
71
  """
72
  Processes the uploaded file based on its type and extracts flashcards.
73
  """
@@ -81,15 +77,15 @@ def process_file(file_obj, output_format: str, language_model) -> str:
81
  else:
82
  raise ValueError("Unsupported file type.")
83
 
84
- flashcards = extract_flashcards(text, output_format, language_model)
85
  return flashcards
86
 
87
- def process_text_input(input_text: str, output_format: str, language_model) -> str:
88
  """
89
  Processes the input text and extracts flashcards.
90
  """
91
  if not input_text.strip():
92
  raise ValueError("No text provided.")
93
 
94
- flashcards = extract_flashcards(input_text, output_format, language_model)
95
  return flashcards
 
27
  Formats the prompt based on the output type.
28
  """
29
  if output_format.lower() == "json":
30
+ return """You only respond in JSON format. Follow the example below.
31
 
32
  EXAMPLE:
33
  [
34
  {"question": "What is AI?", "answer": "Artificial Intelligence."},
35
  {"question": "What is ML?", "answer": "Machine Learning."}
 
36
  ]
37
  """
38
  elif output_format.lower() == "csv":
 
41
  EXAMPLE:
42
  "What is AI?", "Artificial Intelligence."
43
  "What is ML?", "Machine Learning."
 
44
  """
45
 
46
+ # def extract_flashcards(text: str, output_format: str, pipeline: str) -> str:
47
+ # """
48
+ # Extracts flashcards from the input text using the LLM and formats them in CSV or JSON.
49
+ # """
50
+ # prompt = f"""You are an expert flashcard creator. You always include a single knowledge item per flashcard.
51
 
52
+ # {format_prompt(output_format)}
53
 
54
 
55
+ # Extract flashcards from the user's text:
56
 
57
+ # {text}
58
 
59
+ # Do not include the prompt or any other unnecessary information in the flashcards.
60
+ # Do not include triple ticks (```) or any other code blocks in the flashcards.
61
+ # """
62
+ # # TODO:
63
+ # response = pipeline.generate_flashcards("json", prompt)
64
+ # return response
 
 
65
 
66
+ def process_file(file_obj, output_format: str, pipeline) -> str:
67
  """
68
  Processes the uploaded file based on its type and extracts flashcards.
69
  """
 
77
  else:
78
  raise ValueError("Unsupported file type.")
79
 
80
+ flashcards = pipeline.generate_flashcards(output_format, text)
81
  return flashcards
82
 
83
def process_text_input(input_text: str, output_format: str, pipeline) -> str:
    """
    Processes the input text and extracts flashcards.

    Args:
        input_text: Raw text entered by the user.
        output_format: Desired output format, forwarded to the pipeline.
        pipeline: Object exposing ``generate_flashcards(output_format, text)``.

    Returns:
        The flashcards produced by ``pipeline.generate_flashcards``.

    Raises:
        ValueError: If ``input_text`` is empty, whitespace-only, or ``None``.
    """
    # Parameter order and the ``pipeline`` argument match how this function
    # is called: (input_text, output_format, pipeline).
    if not input_text or not input_text.strip():
        raise ValueError("No text provided.")

    return pipeline.generate_flashcards(output_format, input_text)
tests/conftest.py CHANGED
@@ -1,9 +1,9 @@
1
  import pytest
2
  from unittest.mock import Mock
3
- from app.models import LanguageModel
4
 
5
  @pytest.fixture
6
- def language_model():
7
  """
8
  Fixture to provide a mocked LanguageModel instance.
9
  """
 
1
  import pytest
2
  from unittest.mock import Mock
3
+ from app.pipeline import LanguageModel
4
 
5
  @pytest.fixture
6
+ def pipeline():
7
  """
8
  Fixture to provide a mocked LanguageModel instance.
9
  """
tests/{test_models.py → test_pipeline.py} RENAMED
@@ -1,8 +1,6 @@
1
- # tests/test_models.py
2
-
3
  import pytest
4
 
5
- def test_generate_flashcards(language_model, mocker):
6
  """
7
  Test the generate_flashcards method of LanguageModel.
8
  """
@@ -10,11 +8,11 @@ def test_generate_flashcards(language_model, mocker):
10
  expected_response = '{"flashcards": [{"Question": "What is AI?", "Answer": "Artificial Intelligence."}]}'
11
 
12
  # Configure the mock to return a specific response
13
- language_model.generate_flashcards.return_value = expected_response
14
 
15
  # Call the method
16
- response = language_model.generate_flashcards(prompt)
17
 
18
  # Assertions
19
  assert response == expected_response
20
- language_model.generate_flashcards.assert_called_once_with(prompt)
 
 
 
1
  import pytest
2
 
3
+ def test_generate_flashcards(pipeline, mocker):
4
  """
5
  Test the generate_flashcards method of LanguageModel.
6
  """
 
8
  expected_response = '{"flashcards": [{"Question": "What is AI?", "Answer": "Artificial Intelligence."}]}'
9
 
10
  # Configure the mock to return a specific response
11
+ pipeline.generate_flashcards.return_value = expected_response
12
 
13
  # Call the method
14
+ response = pipeline.generate_flashcards(prompt)
15
 
16
  # Assertions
17
  assert response == expected_response
18
+ pipeline.generate_flashcards.assert_called_once_with(prompt)
tests/test_processing.py CHANGED
@@ -1,9 +1,7 @@
1
- # tests/test_processing.py
2
-
3
  import pytest
4
  from app.processing import process_text_input, process_file
5
 
6
- def test_process_text_input_success(language_model):
7
  """
8
  Test processing of valid text input.
9
  """
@@ -11,11 +9,11 @@ def test_process_text_input_success(language_model):
11
  output_format = "JSON"
12
  expected_output = '{"flashcards": []}'
13
 
14
- result = process_text_input(input_text, output_format, language_model)
15
  assert result == expected_output
16
- language_model.generate_flashcards.assert_called_once()
17
 
18
- def test_process_text_input_empty(language_model):
19
  """
20
  Test processing of empty text input.
21
  """
@@ -23,10 +21,10 @@ def test_process_text_input_empty(language_model):
23
  output_format = "JSON"
24
 
25
  with pytest.raises(ValueError) as excinfo:
26
- process_text_input(input_text, output_format, language_model)
27
  assert "No text provided." in str(excinfo.value)
28
 
29
- def test_process_file_unsupported_type(language_model, tmp_path):
30
  """
31
  Test processing of an unsupported file type.
32
  """
@@ -35,10 +33,10 @@ def test_process_file_unsupported_type(language_model, tmp_path):
35
  dummy_file.write_text("Unsupported content")
36
 
37
  with pytest.raises(ValueError) as excinfo:
38
- process_file(dummy_file, "JSON", language_model)
39
  assert "Unsupported file type." in str(excinfo.value)
40
 
41
- def test_process_file_pdf(language_model, tmp_path, mocker):
42
  """
43
  Test processing of a PDF file.
44
  """
@@ -51,11 +49,11 @@ def test_process_file_pdf(language_model, tmp_path, mocker):
51
 
52
  expected_output = '{"flashcards": []}'
53
 
54
- result = process_file(dummy_file, "JSON", language_model)
55
  assert result == expected_output
56
- language_model.generate_flashcards.assert_called_once()
57
 
58
- def test_process_file_txt(language_model, tmp_path, mocker):
59
  """
60
  Test processing of a TXT file.
61
  """
@@ -68,6 +66,6 @@ def test_process_file_txt(language_model, tmp_path, mocker):
68
 
69
  expected_output = '{"flashcards": []}'
70
 
71
- result = process_file(dummy_file, "JSON", language_model)
72
  assert result == expected_output
73
- language_model.generate_flashcards.assert_called_once()
 
 
 
1
  import pytest
2
  from app.processing import process_text_input, process_file
3
 
4
+ def test_process_text_input_success(pipeline):
5
  """
6
  Test processing of valid text input.
7
  """
 
9
  output_format = "JSON"
10
  expected_output = '{"flashcards": []}'
11
 
12
+ result = process_text_input(input_text, output_format, pipeline)
13
  assert result == expected_output
14
+ pipeline.generate_flashcards.assert_called_once()
15
 
16
+ def test_process_text_input_empty(pipeline):
17
  """
18
  Test processing of empty text input.
19
  """
 
21
  output_format = "JSON"
22
 
23
  with pytest.raises(ValueError) as excinfo:
24
+ process_text_input(input_text, output_format, pipeline)
25
  assert "No text provided." in str(excinfo.value)
26
 
27
+ def test_process_file_unsupported_type(pipeline, tmp_path):
28
  """
29
  Test processing of an unsupported file type.
30
  """
 
33
  dummy_file.write_text("Unsupported content")
34
 
35
  with pytest.raises(ValueError) as excinfo:
36
+ process_file(dummy_file, "JSON", pipeline)
37
  assert "Unsupported file type." in str(excinfo.value)
38
 
39
+ def test_process_file_pdf(pipeline, tmp_path, mocker):
40
  """
41
  Test processing of a PDF file.
42
  """
 
49
 
50
  expected_output = '{"flashcards": []}'
51
 
52
+ result = process_file(dummy_file, "JSON", pipeline)
53
  assert result == expected_output
54
+ pipeline.generate_flashcards.assert_called_once()
55
 
56
+ def test_process_file_txt(pipeline, tmp_path, mocker):
57
  """
58
  Test processing of a TXT file.
59
  """
 
66
 
67
  expected_output = '{"flashcards": []}'
68
 
69
+ result = process_file(dummy_file, "JSON", pipeline)
70
  assert result == expected_output
71
+ pipeline.generate_flashcards.assert_called_once()