"""Gradio app that turns an uploaded image into a short fictional story.

The image is first captioned by a remote CLIP Interrogator Space; the caption
is then given to a locally loaded Llama 2 chat model, which writes the story.
"""
import os
import re

import spaces
import gradio as gr
from gradio_client import Client, handle_file
from transformers import AutoTokenizer, AutoModelForCausalLM

hf_token = os.environ.get('HF_TOKEN')

model_path = "meta-llama/Llama-2-7b-chat-hf"
# NOTE(review): recent transformers releases document `token=` as the
# replacement for `use_auth_token=`; kept as-is so older pinned versions still
# work -- confirm against the pinned transformers version before switching.
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_path, use_auth_token=hf_token).half().cuda()

# Single client for the remote captioning Space (the original built it twice
# and discarded the first instance).
clipi_client = Client("https://fffiloni-clip-interrogator-2.hf.space/")

# Upper bound on generated tokens passed to `model.generate`.
# NOTE(review): Llama 2's documented context window is 4096 tokens including
# the prompt -- confirm this budget is intended.
MAX_NEW_TOKENS = 4096

_SYSTEM_PROMPT = (
    "You are a storyteller. You'll be given an image description and some keyword about the image. "
    "For that given you'll be asked to generate a story that you think could fit very well with the image provided. "
    "Always answer with a cool story, while being safe as possible. "
    "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature. "
    "If a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. \n"
    "If you don't know the answer to a question, please don't share false information."
)

# Llama 2 chat template: the system prompt must be wrapped in <<SYS>> ... <</SYS>>
# inside the first [INST] block. The `{}` placeholder receives the user request.
_INSTRUCTION_TEMPLATE = "[INST] <<SYS>>\n" + _SYSTEM_PROMPT + "\n<</SYS>>\n\n{} [/INST]"

# Matches the echoed prompt ([INST] ... [/INST]) in the decoded model output.
_INST_BLOCK_RE = re.compile(r'\[INST\].*?\[/INST\]', flags=re.DOTALL)


@spaces.GPU
def llama_gen_story(prompt):
    """Generate a fictional story using the LLaMA 2 model based on a prompt.

    Args:
        prompt: A string prompt containing an image description and story generation instructions.

    Returns:
        The decoded model output with special tokens skipped and the echoed
        ``[INST] ... [/INST]`` prompt block removed.
    """
    full_prompt = _INSTRUCTION_TEMPLATE.format(prompt)
    input_ids = tokenizer(full_prompt, return_tensors='pt').input_ids.cuda()
    generate_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
    # The decoded sequence contains the prompt followed by the completion,
    # so the prompt block is stripped afterwards.
    output_text = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
    return _INST_BLOCK_RE.sub('', output_text)


def get_text_after_colon(input_text):
    """Return the stripped text following the first ":" in *input_text*.

    Args:
        input_text: The string to search.

    Returns:
        Everything after the first colon with surrounding whitespace removed,
        or *input_text* unchanged when it contains no colon.
    """
    _, colon, remainder = input_text.partition(":")
    if not colon:
        return input_text
    return remainder.strip()


def _format_paragraphs(text):
    """Return *text* with exactly one blank line between its non-empty lines."""
    # Dropping empty lines first keeps text that already uses blank lines
    # between paragraphs from ending up with several blank lines in a row.
    paragraphs = [line.strip() for line in text.split('\n') if line.strip()]
    return '\n\n'.join(paragraphs)


def infer(image_input, audience):
    """Generate a fictional story based on an image using CLIP Interrogator and LLaMA2.

    Args:
        image_input: A file path to the input image to analyze.
        audience: A string indicating the target audience, such as 'Children' or 'Adult'.

    Returns:
        A formatted, multi-paragraph fictional story string related to the image content.

    Raises:
        gr.Error: If no image was provided.

    Steps:
        1. Use the CLIP Interrogator model to generate a semantic caption from the image.
        2. Format a prompt asking the LLaMA2 model to write a story based on the caption.
        3. Clean and format the story output for readability.
    """
    if not image_input:
        # Fail with a readable message instead of an error from handle_file(None).
        raise gr.Error("Please provide an image first.")

    gr.Info('Calling CLIP Interrogator ...')
    clipi_result = clipi_client.predict(
        image=handle_file(image_input),
        mode="best",
        best_max_flavors=4,
        api_name="/clipi2",
    )
    print(clipi_result)

    llama_q = (
        f" I'll give you a simple image caption, please provide a fictional story for a {audience} audience "
        "that would fit well with the image. "
        "Please be creative, do not worry and only generate a cool fictional story. \n"
        f"Here's the image description: '{clipi_result}' "
    )

    gr.Info('Calling Llama2 ...')
    result = llama_gen_story(llama_q)
    print(f"Llama2 result: {result}")

    # Drop whatever precedes the first colon in the model's answer.
    result = get_text_after_colon(result)
    return _format_paragraphs(result)


css = (
    " #col-container {max-width: 910px; margin-left: auto; margin-right: auto;}"
    " div#story textarea { font-size: 1.5em; line-height: 1.4em; } "
)

# Header text shown at the top of the page.
_INTRO_MARKDOWN = "\n\nImage to Story\n\nUpload an image, get a story made by Llama2 !\n\n"

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(_INTRO_MARKDOWN)
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Image input", type="filepath", elem_id="image-in")
                audience = gr.Radio(label="Target Audience", choices=["Children", "Adult"], value="Children")
                submit_btn = gr.Button('Tell me a story')
            with gr.Column():
                story = gr.Textbox(label="generated Story", elem_id="story")

        gr.Examples(
            examples=[["./examples/crabby.png", "Children"], ["./examples/hopper.jpeg", "Adult"]],
            fn=infer,
            inputs=[image_in, audience],
            outputs=[story],
            cache_examples=False,
        )

    submit_btn.click(fn=infer, inputs=[image_in, audience], outputs=[story])

demo.queue(max_size=12).launch(ssr_mode=False, mcp_server=True)