Spaces:
Paused
Paused
Upload 5 files
Browse files- uno/utils/__init__.py +0 -0
- uno/utils/convert_yaml_to_args_file.py +34 -0
- uno/utils/image_describer.py +57 -0
- uno/utils/prompt_enhancer.py +150 -0
- uno/utils/prompt_router.py +48 -0
uno/utils/__init__.py
ADDED
File without changes
|
uno/utils/convert_yaml_to_args_file.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import yaml


def convert_yaml_to_args(yaml_path: str, arg_path: str) -> None:
    """Flatten a YAML mapping into a single-line ``--key value`` args file.

    List values are joined with spaces; keys whose value is ``None`` are
    omitted entirely (so blank YAML entries don't emit ``--key None``).
    """
    with open(yaml_path, "r") as f:
        data = yaml.safe_load(f)

    with open(arg_path, "w") as f:
        for k, v in data.items():
            if v is None:
                # Skip empty entries rather than writing a bogus value.
                continue
            if isinstance(v, list):
                v = " ".join(map(str, v))
            print(f"--{k} {v}", end=" ", file=f)


def main() -> None:
    """CLI entry point: --yaml <input.yaml> --arg <output args file>."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml", type=str, required=True)
    parser.add_argument("--arg", type=str, required=True)
    args = parser.parse_args()
    convert_yaml_to_args(args.yaml, args.arg)


# Guard so importing this module does not parse sys.argv or touch files.
if __name__ == "__main__":
    main()
|
uno/utils/image_describer.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
import re
|
5 |
+
|
6 |
+
# BLIP-2 checkpoint used for image captioning.
BLIP2_MODEL_NAME = "Salesforce/blip2-flan-t5-xl"
BLIP_DEVICE = "cpu"  # inference device; model is loaded in float32 below
MAX_LENGTH = 120  # max_new_tokens budget for caption generation

# Lazily initialized singletons — populated on first use by lazy_load_blip2().
processor = None
model = None
|
12 |
+
|
13 |
+
def lazy_load_blip2():
    """Load the BLIP-2 processor and model on first call (CPU, float32).

    Subsequent calls are no-ops once both globals are populated.
    """
    global processor, model
    if processor is not None and model is not None:
        return
    print("\U0001F680 [BLIP2] Loading BLIP-2 model and processor on CPU...")
    processor = Blip2Processor.from_pretrained(BLIP2_MODEL_NAME)
    blip2 = Blip2ForConditionalGeneration.from_pretrained(
        BLIP2_MODEL_NAME,
        torch_dtype=torch.float32
    )
    model = blip2.to(BLIP_DEVICE).eval()
|
22 |
+
|
23 |
+
def clean_caption(text: str) -> str:
    """Normalize a raw caption: collapse whitespace, trim quotes/newlines,
    and capitalize the first character. Empty input is returned as-is."""
    collapsed = re.sub(r"\s+", " ", text.strip())
    trimmed = collapsed.strip(' "\n')
    if not trimmed:
        return trimmed
    return trimmed[0].upper() + trimmed[1:]
|
28 |
+
|
29 |
+
def describe_uploaded_images(images: list[Image.Image]) -> dict:
    """Caption each uploaded image with BLIP-2 and return a joined summary.

    Returns a dict whose ``style_description`` and ``full_caption`` entries
    both hold the deduplicated captions joined by "; " (empty strings when
    no images are supplied). Failures on individual images are logged and
    skipped rather than aborting the batch.
    """
    if not images:
        return {"style_description": "", "full_caption": ""}

    lazy_load_blip2()

    instruction = (
        "Describe this image in detail. Focus on the art medium, visual style, mood or tone, lighting or rendering cues, "
        "and describe how people interact with objects if applicable."
    )

    captions = []
    for image in images:
        try:
            batch = processor(images=image, text=instruction, return_tensors="pt").to(BLIP_DEVICE)
            token_ids = model.generate(**batch, max_new_tokens=MAX_LENGTH)
            raw = processor.tokenizer.decode(token_ids[0], skip_special_tokens=True)
            tidy = clean_caption(raw)
            # Deduplicate identical captions across images.
            if tidy and tidy not in captions:
                captions.append(tidy)
        except Exception as e:
            print(f"β [BLIP-2 ERROR] Failed to describe image: {e}")
            continue

    summary = "; ".join(captions)
    return {
        "style_description": summary,
        "full_caption": summary
    }
|
uno/utils/prompt_enhancer.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from typing import List, Optional
|
4 |
+
import openai
|
5 |
+
|
6 |
+
from uno.utils.prompt_router import classify_prompt_intent
|
7 |
+
from uno.utils.image_describer import describe_uploaded_images
|
8 |
+
|
9 |
+
# === OpenAI Client ===
# Shared client; the API key is read from the environment at import time.
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# === Constants ===
# Every valid enhanced prompt must start with this prefix (checked by the
# filter in enhance_prompt_with_chatgpt) and end with this suffix.
PROMPT_PREFIX = "It's very important that"
PROMPT_SUFFIX = (
    "and take all the time you needed as it is very important to achieve the best possible result"
)
# Sentence skeleton the model is instructed to follow for each prompt.
TEMPLATE_FORMAT = "prefix [art medium] [main object or objective] [attribute] [expression] [key light] [detailing] suffix"
DEFAULT_NUM_PROMPTS = 5
MIN_WORD_COUNT = 8  # prompts with this many words or fewer are discarded as malformed
MAX_PROMPT_LENGTH = 120  # word budget communicated to the model
# Elements every generated prompt is required to mention.
REQUIRED_ELEMENTS = [
    "art medium",
    "main object",
    "attribute",
    "expression",
    "key light",
    "detailing"
]

# Per-intent focus line injected into the system message; intents not in
# this table get a dynamically built fallback in build_system_message.
INTENT_SYSTEM_INSTRUCTIONS = {
    "product_ad": "Focus: visually advertise a product using professional commercial photography.",
    "service_promotion": "Focus: promote a service by illustrating its usage, setting, or emotional impact.",
    "public_awareness": "Focus: support a cause, campaign, or message with narrative visual storytelling.",
    "brand_storytelling": "Focus: express a brand's tone or identity using a visual lifestyle story.",
    "creative_social_post": "Focus: generate stylistic or creative content suitable for social media that maintains core subject clarity.",
    "fallback": "Fallback: default to showcasing the main product in a visually compelling commercial format."
}
|
38 |
+
|
39 |
+
def build_system_message(intent: str, num_prompts: int, style_hint: Optional[str] = "") -> str:
    """Compose the system message that steers GPT-4's prompt enhancement.

    Args:
        intent: Intent label from classify_prompt_intent. The router's own
            instructions produce hyphenated labels (e.g. 'product-ad'), while
            INTENT_SYSTEM_INSTRUCTIONS is keyed with underscores, so the label
            is normalized before lookup; otherwise the table would never match.
        num_prompts: Exact number of prompts the model must return.
        style_hint: Optional style description derived from reference images.

    Returns:
        The fully formatted, stripped system message string.
    """
    style_clause = (
        f"The visual tone, lighting, and environment must match this style: {style_hint}.\n"
        f"Only override this if the user explicitly requests a different visual style."
        if style_hint else ""
    )

    # Normalize 'product-ad' -> 'product_ad' so router labels hit the table.
    normalized_intent = intent.replace("-", "_")
    if normalized_intent in INTENT_SYSTEM_INSTRUCTIONS:
        instruction = INTENT_SYSTEM_INSTRUCTIONS[normalized_intent]
    else:
        print(f"π§ [DEBUG] Unrecognized intent '{intent}', using dynamic fallback...")
        instruction = (
            f"Focus: Generate prompts suitable for a '{intent}' scenario using descriptive, high-quality visual storytelling. "
            "The core subject must remain clear and central."
        )

    return f"""
You are a prompt enhancement assistant for Flux Pro.

Your task is to transform a short user input into {num_prompts} full-sentence, professional image generation prompts.

Each prompt must follow this structure:
{TEMPLATE_FORMAT}

Prefix: '{PROMPT_PREFIX}'
Suffix: '{PROMPT_SUFFIX}'

Each prompt must:
- Be under {MAX_PROMPT_LENGTH} words
- Include: {", ".join(REQUIRED_ELEMENTS)}
- Be a single descriptive sentence
- Never use lists, examples, or bullet formatting
- Avoid specific color names unless inferred from uploaded images
- Do not wrap prompts in quotes or number them

All image elements must follow natural physical proportions. Ensure objects intended for interaction appear in realistic size and position relative to the subject and scene. Avoid exaggerated scaling unless the user explicitly asks for surrealism or stylization.

{style_clause}

{instruction}

Do not explain. Only return the prompts.
Generate exactly {num_prompts} unique prompts, one per line.
""".strip()
|
83 |
+
|
84 |
+
|
85 |
+
def enhance_prompt_with_chatgpt(
    user_prompt: str,
    num_prompts: int = DEFAULT_NUM_PROMPTS,
    reference_images: Optional[List] = None
) -> List[str]:
    """Expand a short user prompt into `num_prompts` full enhancement prompts.

    Classifies the prompt's intent, optionally derives a style hint from the
    uploaded reference images, then asks GPT-4 for the enhanced prompts. The
    output is filtered to lines that carry the required prefix and minimum
    length, padded with the raw user prompt when too few survive, and
    truncated when too many do. On any API failure the raw user prompt is
    returned `num_prompts` times.
    """
    intent = classify_prompt_intent(user_prompt)

    image_info = describe_uploaded_images(reference_images) if reference_images else {}
    full_caption = image_info.get("full_caption", "")
    style_hint = image_info.get("style_description", "")

    print(f"\nπ₯ [DEBUG] User prompt: {user_prompt}")
    if full_caption:
        print(f"πΌοΈ [DEBUG] BLIP Caption: {full_caption}")
    if style_hint:
        print(f"π¨ [DEBUG] Style Description from Image: {style_hint}")

    user_msg = f"Original prompt: {user_prompt}"
    if style_hint:
        user_msg = f"{user_msg}\nVisual reference style: {style_hint}"

    try:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": build_system_message(intent, num_prompts, style_hint)},
                {"role": "user", "content": user_msg}
            ],
            temperature=0.7,
            max_tokens=1800
        )

        raw_output = response.choices[0].message.content.strip()
        print("\nπ [DEBUG] Raw GPT Output:")
        print(raw_output)

        lines = [ln.strip() for ln in raw_output.split("\n") if ln.strip()]

        # Keep only lines that carry the required prefix and are long enough.
        # (The original single-line special case reduces to this same filter.)
        prefix_lower = PROMPT_PREFIX.lower()
        kept = [
            ln for ln in lines
            if ln.lower().startswith(prefix_lower) and len(ln.split()) > MIN_WORD_COUNT
        ]

        shortfall = num_prompts - len(kept)
        if shortfall > 0:
            print(f"β οΈ Only {len(kept)} prompts returned. Padding with user prompt...")
            kept.extend([user_prompt] * shortfall)
        else:
            kept = kept[:num_prompts]

        print("\nπ§ [DEBUG] ChatGPT Enhanced Prompts:")
        for position, prompt_text in enumerate(kept, start=1):
            print(f"[{position}] {prompt_text}")
        print("--------------------------------------------------\n")

        return kept

    except Exception as e:
        print(f"β [ERROR] Failed to enhance prompt: {e}")
        return [user_prompt] * num_prompts
|
uno/utils/prompt_router.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import openai
|
3 |
+
|
4 |
+
# === OpenAI Client ===
# Shared client; the API key is read from the environment at import time.
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# === System Message ===
# Instructs GPT-4 to answer with exactly one lowercase intent label and
# nothing else; the label set is open-ended (new labels are allowed).
INTENT_SYSTEM_MESSAGE = """
You are an intent classification engine for Flux Pro's AI prompt enhancement system.

Given a short user prompt, your job is to classify the *underlying marketing or creative intent* behind it using a single lowercase label that reflects the intent category.

You must:
- Respond with only one concise intent label (a single word or hyphenated phrase)
- Avoid using punctuation, explanations, or examples
- Infer intent intelligently using your understanding of marketing and creative goals
- You may return previously unseen or new intent labels if appropriate

Some valid label types might include (but are not limited to): product-ad, service-promotion, public-awareness, brand-storytelling, social-trend, artistic-expression, educational-content, campaign-launch, or experimental-style.

Only return the label. Do not echo the prompt or add any commentary.
""".strip()
|
23 |
+
|
24 |
+
def classify_prompt_intent(user_prompt: str) -> str:
    """Classify the marketing or creative intent behind *user_prompt*.

    Asks GPT-4 for a single lowercase intent label (e.g. 'product-ad',
    'artistic-expression'); returns 'unknown' if the API call fails.
    """
    chat_messages = [
        {"role": "system", "content": INTENT_SYSTEM_MESSAGE},
        {"role": "user", "content": f"Prompt: {user_prompt}\nReturn only the label."}
    ]

    try:
        completion = client.chat.completions.create(
            model="gpt-4",
            messages=chat_messages,
            temperature=0,
            max_tokens=10,
        )
        # Any failure here (network, auth, malformed response) maps to the
        # 'unknown' fallback below.
        label = completion.choices[0].message.content.strip().lower()
    except Exception as e:
        print(f"β [ERROR] Failed to classify prompt intent: {e}")
        return "unknown"

    print(f"π [DEBUG] Inferred intent label: {label}")
    return label
|