Diptaraj Sen committed
Commit d87e8d0 · 0 Parent(s)

First Commit
.gitignore ADDED
@@ -0,0 +1,5 @@
+ venv/
+ __pycache__/
+ outputs/
+ logs/
+ *.pyc
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
+ [server]
+ runOnSave = true
app/__init__.py ADDED
File without changes
app/captioning.py ADDED
@@ -0,0 +1,36 @@
+ from app.logger import get_logger
+ logger = get_logger(__name__)
+
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
+ from PIL import Image
+ import torch
+
+ # Load processor and model (ViT encoder + GPT-2 decoder)
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ # Move model to GPU if available
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {device}")
+ model.to(device)
+
+ def generate_caption(image_path: str) -> str:
+     logger.info("Generating caption...")
+     try:
+         # Open and convert image to RGB
+         image = Image.open(image_path).convert("RGB")
+
+         # Preprocess image and prepare inputs
+         inputs = processor(images=image, return_tensors="pt").to(device)
+
+         # Generate caption (greedy decoding for now)
+         output = model.generate(**inputs, max_length=16, num_beams=1)
+
+         # Decode output to text
+         caption = tokenizer.decode(output[0], skip_special_tokens=True)
+
+         logger.info(f"Caption generated: {caption}")
+         return caption
+     except Exception:
+         logger.exception("Failed to generate caption")
+         raise
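The generate call above uses greedy decoding (num_beams=1), as the inline comment notes. A minimal sketch of a drop-in beam-search variant of that same call, not part of this commit (the beam width of 4 is an illustrative choice):

# Sketch only: replaces the greedy model.generate call inside generate_caption.
# Beam search usually trades a little speed for slightly better captions.
output = model.generate(**inputs, max_length=16, num_beams=4, early_stopping=True)
caption = tokenizer.decode(output[0], skip_special_tokens=True)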
app/logger.py ADDED
@@ -0,0 +1,31 @@
+ import logging
+ import os
+
+ def get_logger(name: str):
+     logs_dir = "logs"
+     os.makedirs(logs_dir, exist_ok=True)
+
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.DEBUG)
+
+     # Avoid duplicate handlers (and reopening the log file) on repeated calls
+     if logger.handlers:
+         return logger
+
+     # File handler
+     file_handler = logging.FileHandler(os.path.join(logs_dir, "pipeline.log"))
+     file_handler.setLevel(logging.DEBUG)
+
+     # Console handler
+     console_handler = logging.StreamHandler()
+     console_handler.setLevel(logging.INFO)
+
+     # Formatter
+     formatter = logging.Formatter(
+         "[%(asctime)s] [%(levelname)s] - %(name)s - %(message)s", "%Y-%m-%d %H:%M:%S"
+     )
+     file_handler.setFormatter(formatter)
+     console_handler.setFormatter(formatter)
+     logger.addHandler(file_handler)
+     logger.addHandler(console_handler)
+     return logger
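For reference, how the two handler levels split in practice (a minimal sketch; the messages are illustrative):

from app.logger import get_logger

logger = get_logger(__name__)
logger.info("Shown on the console and written to logs/pipeline.log")
logger.debug("Written only to logs/pipeline.log (the console handler is INFO+)")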
app/storytelling.py ADDED
@@ -0,0 +1,53 @@
+ from app.logger import get_logger
+ logger = get_logger(__name__)
+
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import torch
+
+ model_name = "google/flan-t5-small"
+ # Load tokenizer and model
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ model.eval()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ def generate_story(caption: str, max_length: int = 256) -> str:
+     logger.info("Generating story...")
+     try:
+         # Turn the caption into a story prompt
+         prompt = f"""
+ You are a creative storyteller who writes engaging short stories.
+
+ - Length: the story should be around 200-300 words.
+ - Your job is to take the image caption and expand it into a vivid short story.
+ - Start with an engaging hook, build a little conflict, and wrap up with a satisfying ending.
+ - Use descriptive language and maintain a consistent tone.
+
+ Caption: "{caption}"
+
+ Write the story below:
+ """.strip()
+
+         # Tokenize and run through the model
+         inputs = tokenizer(prompt, return_tensors="pt").to(device)
+         outputs = model.generate(
+             **inputs,
+             max_length=max_length,
+             do_sample=True,
+             top_k=50,
+             top_p=0.95,
+             temperature=0.7,
+             num_return_sequences=1,
+             pad_token_id=tokenizer.pad_token_id,
+             repetition_penalty=1.2,
+             length_penalty=1.0,
+         )
+
+         # Decode the generated text; Flan-T5 is an encoder-decoder model,
+         # so the output does not echo the prompt back
+         story = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return story.strip()
+     except Exception:
+         logger.exception("Failed to generate story")
+         raise
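Because generate_story samples (do_sample=True), its output is nondeterministic from run to run. A small sketch of seeding for repeatable output (the seed value is arbitrary; the caption is borrowed from tests/test_story.py):

import torch
from app.storytelling import generate_story

torch.manual_seed(42)  # arbitrary seed; makes sampling repeatable on a given device
print(generate_story("a group of people standing in a pool"))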
app/tts.py ADDED
@@ -0,0 +1,38 @@
+ from app.logger import get_logger
+ logger = get_logger(__name__)
+
+ from gtts import gTTS
+ from playsound import playsound
+ import os
+ import uuid
+
+ def speak_story(story: str, lang: str = "en") -> str:
+     """
+     Converts the story text to speech and saves it as an audio file.
+     Optionally plays the audio.
+
+     Returns the path to the audio file.
+     """
+     logger.info("Generating audio...")
+     try:
+         # Generate a unique filename
+         filename = f"story_{uuid.uuid4().hex}.mp3"
+         filepath = os.path.join("outputs", filename)
+
+         # Ensure the outputs/ directory exists
+         os.makedirs("outputs", exist_ok=True)
+
+         # Generate TTS from text
+         tts = gTTS(text=story, lang=lang)
+         tts.save(filepath)
+
+         # Play the audio (optional; playback can fail on headless machines)
+         try:
+             playsound(filepath)
+         except Exception:
+             logger.exception("Couldn't play audio")
+
+         return filepath
+     except Exception:
+         logger.exception("Failed to generate audio")
+         raise
run_pipeline.py ADDED
@@ -0,0 +1,26 @@
+ import os
+ import argparse
+ from app.captioning import generate_caption
+ from app.storytelling import generate_story
+ from app.tts import speak_story
+
+ def main(file_name):
+     image_path = os.path.join(os.path.dirname(__file__), "assets", file_name)
+     print("🔍 Generating caption from image...")
+     caption = generate_caption(image_path)
+     print(f"\n🖼️ Caption: {caption}")
+
+     print("\n✍️ Generating story from caption...")
+     story = generate_story(caption)
+     print(f"\n📖 Story:\n{story}")
+
+     print("\n🔊 Converting story to speech...")
+     audio_path = speak_story(story)
+     print(f"\n✅ Audio saved at: {audio_path}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Run image → caption → story → speech pipeline")
+     parser.add_argument("file_name", type=str, help="Name of an image file inside the assets/ directory")
+
+     args = parser.parse_args()
+     main(args.file_name)
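Invocation sketch, assuming the image has been placed under assets/ (the file name is borrowed from tests/test_captioning.py):

python run_pipeline.py IMG_3736.jpg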
streamlit_app.py ADDED
@@ -0,0 +1,42 @@
+ import streamlit as st
+ from app.captioning import generate_caption
+ from app.storytelling import generate_story
+ from app.tts import speak_story
+ import tempfile
+ from PIL import Image
+
+ st.set_page_config(page_title="GenAI Storyteller", layout="centered")
+
+ st.title("📸🧠 GenAI Storyteller")
+ st.markdown("Upload an image, get a caption, a story, and hear it spoken aloud!")
+
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file:
+     # Show uploaded image
+     image = Image.open(uploaded_file)
+     st.image(image, caption="Uploaded Image", use_container_width=True)
+
+     # Run the pipeline on button click
+     if st.button("Generate Story"):
+         with st.spinner("Generating caption..."):
+             # Save the upload to a temp file, converting to RGB so alpha-channel PNGs can be written as JPEG
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
+                 image.convert("RGB").save(tmp.name)
+             caption = generate_caption(tmp.name)
+
+         st.success("Caption Generated!")
+         st.write(f"**Caption**: {caption}")
+
+         with st.spinner("Generating story..."):
+             story = generate_story(caption)
+
+         st.success("Story Generated!")
+         st.text_area("📖 Story", story, height=250)
+
+         with st.spinner("Generating audio..."):
+             audio_path = speak_story(story)
+
+         st.success("Done! Here's the story in audio:")
+         with open(audio_path, "rb") as audio_file:
+             st.audio(audio_file.read(), format="audio/mp3")
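To launch the app locally (with runOnSave = true from .streamlit/config.toml, it reloads whenever a source file is saved):

streamlit run streamlit_app.py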
tests/__init__.py ADDED
File without changes
tests/test_captioning.py ADDED
@@ -0,0 +1,9 @@
+ import os
+ from app.captioning import generate_caption
+
+ # Build the image path relative to this file (IMG_3736.jpg must sit next to this test)
+ file_name = "IMG_3736.jpg"
+ image_path = os.path.join(os.path.dirname(__file__), file_name)
+
+ caption = generate_caption(image_path)
+ print("Caption:", caption)
tests/test_story.py ADDED
@@ -0,0 +1,7 @@
+ # test_story.py
+
+ from app.storytelling import generate_story
+
+ caption = "a group of people standing in a pool"
+ story = generate_story(caption)
+ print("Generated Story:\n", story)
tests/test_tts.py ADDED
@@ -0,0 +1,6 @@
+ from app.tts import speak_story
+
+ story = """Once upon a time in a quiet village, a curious cat named Whiskers loved to watch the birds from his favorite spot by the window..."""
+ audio_path = speak_story(story)
+
+ print("Audio saved to:", audio_path)