Diptaraj Sen committed
Commit bc9b706 · 1 Parent(s): d10976f

story generation model changed

Files changed (1): app/storytelling.py (+28 -33)
app/storytelling.py CHANGED
@@ -1,14 +1,15 @@
 from app.logger import get_logger
 logger = get_logger(__name__)
 
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-model = "google/flan-t5-small"
+model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
 # Load tokenizer and model
-tokenizer =AutoTokenizer.from_pretrained(model)
-model = AutoModelForSeq2SeqLM.from_pretrained(model)
-model.eval()
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id)
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
@@ -16,38 +17,32 @@ def generate_story(caption: str, max_length: int = 256) -> str:
     logger.info("Generating story...")
     try:
         # Turn caption into a story prompt
-        prompt = f"""
-        You are a creative storyteller who writes engaging short stories.
-
-        - Length: The story should have around 200-300 words
-        - Your job is to take the image caption and expand it into a vivid short story.
-        - Start with an engaging hook, build a little conflict, and wrap up with a satisfying ending.
-        - Use descriptive language and maintain a consistent tone.
-
-        Caption: "{caption}"
-
-        Write the story below:
-        """.strip()
+        prompt = (
+            "<|system|>\n"
+            "You are a helpful assistant.</s>\n"
+            "<|user|>\n"
+            f"Write a complete, short story about {caption}. Make sure the story has a clear ending.\n</s>\n"
+            "<|assistant|>\n"
+        )
 
         # Tokenize and run through model
         inputs = tokenizer(prompt, return_tensors="pt").to(device)
         outputs = model.generate(
-            **inputs,
-            max_length=max_length,
-            do_sample=True,
-            top_k=50,
-            top_p=0.95,
-            temperature=0.7,
-            num_return_sequences=1,
-            pad_token_id=tokenizer.pad_token_id,
-            early_stopping=True,
-            repetition_penalty=1.2,
-            length_penalty=1.0)
-
-
-        # Decode generated text
-        story = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return story.replace(prompt, "").strip()
+            **inputs,
+            max_new_tokens=1000,
+            do_sample=True,
+            temperature=0.8,
+            top_p=0.9,
+            top_k=50,
+            eos_token_id=tokenizer.eos_token_id
+        )
+
+
+        # Decode and clean output
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        generated_story = generated_text[len(prompt):]  # Strip prompt part
+
+        return generated_story.replace(prompt, "").strip()
     except Exception as e:
         logger.exception(f"Failed to generate story: {str(e)}")
         raise
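
For context, a minimal usage sketch of the module after this change. The signature generate_story(caption: str, max_length: int = 256) is taken from the hunk header above; the caption here is a hypothetical example. Note that max_length is still accepted but no longer used: output length is now capped by max_new_tokens=1000.

# Minimal usage sketch (hypothetical caption; assumes app.storytelling is importable).
from app.storytelling import generate_story

if __name__ == "__main__":
    caption = "a lighthouse on a cliff at sunset"  # hypothetical example caption
    story = generate_story(caption)
    print(story)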
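
The new prompt hand-rolls the Zephyr-style <|system|>/<|user|>/<|assistant|> markers. As a sketch of an alternative, transformers can build the same kind of prompt from the tokenizer's bundled chat template via tokenizer.apply_chat_template, assuming the TinyLlama chat checkpoint ships one:

# Sketch: build the prompt from the tokenizer's chat template instead of
# concatenating role markers by hand (assumes the checkpoint bundles a template).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
caption = "a dog surfing at the beach"  # hypothetical caption

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user",
     "content": f"Write a complete, short story about {caption}. "
                "Make sure the story has a clear ending."},
]
# tokenize=False returns the formatted prompt string rather than token IDs;
# add_generation_prompt=True appends the assistant turn marker.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)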
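
One caveat in the new decode step: decoding with skip_special_tokens=True drops tokens such as </s> from the output, so generated_text[len(prompt):] can slice at a slightly wrong offset. A common alternative, sketched here against the variable names used in the diff, is to decode only the newly generated token IDs:

# Sketch: decode only the tokens produced after the prompt, so no string
# slicing or .replace() against the prompt text is needed.
prompt_token_count = inputs["input_ids"].shape[-1]
new_token_ids = outputs[0][prompt_token_count:]
story = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()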