|
import gradio as gr |
|
import pandas as pd |
|
import os |
|
import random |
|
from datetime import datetime |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
from PIL import Image |
|
|
|
import config |
|
import utils |
|
|
|
|
|
|
|
# Authenticate with the Hugging Face Hub before the dataset reads below.
utils.login_hugging_face()


def _read_local_preferences():
    """Read ``config.RESULTS_CSV_FILE`` into a DataFrame.

    Any read failure (including an empty file) is logged and replaced by an
    empty table with ``config.CSV_HEADERS`` columns, so startup does not fail here.
    """
    csv_path = config.RESULTS_CSV_FILE
    try:
        return pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        print(f"Local preferences file {csv_path} is empty. Starting fresh.")
    except Exception as e:
        print(f"Error loading local {csv_path}: {e}. Starting fresh.")
    return pd.DataFrame(columns=config.CSV_HEADERS)


# Results table shared by all sessions: try the Hub dataset first, then the
# local CSV, and finally start from an empty table.
preferences_df = utils.load_preferences_from_hf_hub(config.HF_DATASET_REPO_ID, config.RESULTS_CSV_FILE)
if preferences_df is None and not os.path.exists(config.RESULTS_CSV_FILE):
    print("No existing preferences found on Hub or locally. Starting with an empty table.")
    preferences_df = pd.DataFrame(columns=config.CSV_HEADERS)
elif preferences_df is None:
    print(f"Loading preferences from local file: {config.RESULTS_CSV_FILE}")
    preferences_df = _read_local_preferences()

# Samples discovered under config.DATA_FOLDER; start_new_session draws from this.
ALL_SAMPLES_BY_DOMAIN = utils.scan_data_directory(config.DATA_FOLDER)
if not ALL_SAMPLES_BY_DOMAIN:
    print(f"CRITICAL: No data found in {config.DATA_FOLDER}. The app might not function correctly.")
|
|
|
|
|
|
|
def scheduled_upload_job():
    """Upload the shared preferences table to the Hugging Face Hub dataset repo.

    Run periodically by the background scheduler. When the table is ``None`` or
    has no rows, it only logs and uploads nothing.
    """
    print(f"Running scheduled job: Saving and uploading preferences at {datetime.now()}")
    # Bind the module-level table once so the check and the upload use the same object.
    current_df = preferences_df
    if current_df is None or current_df.empty:
        print("Scheduled job: Preferences DataFrame is empty. Nothing to upload.")
        return
    utils.save_preferences_to_hf_hub(
        current_df,
        config.HF_DATASET_REPO_ID,
        config.RESULTS_CSV_FILE,
        commit_message="Periodic background update",
    )
|
|
|
# Background scheduler that runs scheduled_upload_job on a fixed interval.
scheduler = BackgroundScheduler()
scheduler.add_job(
    scheduled_upload_job,
    trigger="interval",
    hours=config.PUSH_INTERVAL_HOURS,
)
scheduler.start()
print(f"Scheduler started. Will attempt to upload preferences every {config.PUSH_INTERVAL_HOURS} hour(s).")
|
|
|
|
|
|
|
def start_new_session():
    """Initialize a new user session and load its first displayable sample.

    Returns a 10-tuple:
    ``(session_id, sample_queue, current_sample_index, markdown, background,
    foreground, sample_data, option_images, displayed_models_info,
    end_of_session)``.

    Samples that ``load_and_display_sample`` cannot load (it returns sample
    data ``None`` with the end flag unset) are skipped. The returned index is
    therefore the one actually being displayed, and the user is never left on
    an error screen with every vote button disabled.
    """
    session_id = utils.generate_session_id()
    sample_queue = utils.prepare_session_samples(ALL_SAMPLES_BY_DOMAIN, config.SAMPLES_PER_DOMAIN)

    if not sample_queue:
        no_samples_msg = "# 😥 No Samples Available!\n\n### Please check the data folder configuration or try again later."
        return session_id, sample_queue, 0, no_samples_msg, None, None, None, [], [], True

    print(f"New session started: {session_id}, with {len(sample_queue)} samples.")

    current_sample_index = 0
    while True:
        domain_prompt_md, bg, fg, s_data, out_imgs, disp_info, end_flag = load_and_display_sample(sample_queue, current_sample_index)
        # Stop at the first sample that loaded, or once the queue is exhausted
        # (load_and_display_sample then reports end_flag=True).
        if s_data is not None or end_flag:
            break
        current_sample_index += 1

    return session_id, sample_queue, current_sample_index, domain_prompt_md, bg, fg, s_data, out_imgs, disp_info, end_flag
|
|
|
|
|
def _load_image_copy(path):
    """Open the image at *path* and return a fully loaded in-memory copy.

    The context manager closes the file immediately rather than leaving a
    lazily opened handle alive until the image is later consumed.
    """
    with Image.open(path) as img:
        return img.copy()


def load_and_display_sample(sample_queue, current_sample_index):
    """Load the sample at ``current_sample_index`` and prepare it for display.

    Returns a 7-tuple ``(markdown, background_image, foreground_image,
    sample_data, option_images, displayed_models_info, end_of_session)``:

    * past the end of the queue, or with an empty/``None`` queue: an
      end-of-session message, ``None`` for images/data, empty lists and ``True``;
    * when the sample data or its input images cannot be loaded: an error
      message, ``None`` for images/data, empty lists and ``False``;
    * otherwise: the domain/prompt markdown, both input images, the sample
      dict, and exactly four option images with the matching
      ``(model_key, image_path)`` list. The model order is shuffled. Slots that
      could not be filled get a grey placeholder and ``("BLANK_SLOT", "N/A")``.
    """
    if not sample_queue or current_sample_index >= len(sample_queue):
        end_session_msg = "# 🎉 All Rated! 🎉\n\n### All samples for this session have been rated. Thank you!"
        return end_session_msg, None, None, None, [], [], True

    domain, sample_id = sample_queue[current_sample_index]
    error_msg = f"## ⚠️ Error Loading Sample\n\nCould not load data for {domain}/{sample_id}. Skipping to the next one."

    sample_data = utils.load_sample_data(domain, sample_id)
    if sample_data is None:
        print(f"Error loading sample {domain}/{sample_id}. Skipping.")
        return error_msg, None, None, None, [], [], False

    prompt_text = sample_data["prompt"]

    # A missing or unreadable input image makes the sample unusable. Report it
    # the same way as missing sample data instead of raising out of the handler.
    try:
        bg_image_to_display = _load_image_copy(sample_data["background_img_path"])
        fg_image_to_display = _load_image_copy(sample_data["foreground_img_path"])
    except Exception as e:
        print(f"Error loading input images for sample {domain}/{sample_id}: {e}. Skipping.")
        return error_msg, None, None, None, [], [], False

    num_slots = 4
    # Options are shown as squares whose side is the first IMAGE_DISPLAY_SIZE entry.
    square_size = (config.IMAGE_DISPLAY_SIZE[0], config.IMAGE_DISPLAY_SIZE[0])

    # Shuffle so the on-screen position does not reveal which model produced an image.
    output_model_keys = list(sample_data["output_image_paths"].keys())
    random.shuffle(output_model_keys)

    displayed_models_info = []
    output_images_for_display = []
    for model_key in output_model_keys:
        img_path = sample_data["output_image_paths"][model_key]
        try:
            with Image.open(img_path) as img:
                resized = img.resize(square_size)
        except FileNotFoundError:
            print(f"Image not found: {img_path} for model {model_key}. Skipping this option.")
            continue
        except Exception as e:
            print(f"Error loading or resizing image {img_path}: {e}. Skipping this option.")
            continue
        output_images_for_display.append(resized)
        displayed_models_info.append((model_key, img_path))

    # Pad to a fixed number of slots so the UI always receives one entry per option.
    blank_image = Image.new('RGB', square_size, (200, 200, 200))
    while len(output_images_for_display) < num_slots:
        output_images_for_display.append(blank_image)
        displayed_models_info.append(("BLANK_SLOT", "N/A"))

    domain_prompt_markdown = f"""
<div style="padding:15px 20px 20px 20px;border-left:3px black;background-color:#4B5966;border-radius: 10px;color:black;">

### Domain: {domain}

</div>
<br>
<div style="padding:15px 20px 20px 20px;border-left:3px black;background-color:#4B5966;border-radius: 10px;color:black;">

## Prompt

### _"{prompt_text}"_

</div>
"""

    return (
        domain_prompt_markdown,
        bg_image_to_display,
        fg_image_to_display,
        sample_data,
        output_images_for_display[:num_slots],
        displayed_models_info[:num_slots],
        False,
    )
|
|
|
def _load_first_loadable_sample(sample_queue, start_index):
    """Load the first sample at or after *start_index* that loads successfully.

    Returns ``(index, display)``, where ``display`` is the 7-tuple from
    ``load_and_display_sample``. Unloadable samples, which that function reports
    as sample data ``None`` with the end flag unset, are skipped. Once the queue
    is exhausted, the end-of-session tuple is returned.
    """
    index = start_index
    while True:
        display = load_and_display_sample(sample_queue, index)
        sample_data, session_over = display[3], display[6]
        if sample_data is not None or session_over:
            return index, display
        index += 1


def process_vote(choice_index, session_id, sample_queue, current_sample_index, current_sample_data, displayed_models_info_for_sample):
    """Record a vote for the option at ``choice_index`` and advance the session.

    Args:
        choice_index: Index of the clicked option slot.
        session_id: Identifier of the voting session (logged and stored with the vote).
        sample_queue: ``(domain, sample_id)`` pairs for this session.
        current_sample_index: Index into ``sample_queue`` of the sample being rated.
        current_sample_data: Sample dict from ``utils.load_sample_data`` for that
            sample, or ``None`` when nothing is loaded.
        displayed_models_info_for_sample: ``(model_key, image_path)`` per option
            slot, in display order.

    Returns:
        ``(preferences_df, next_index, markdown, background, foreground,
        sample_data, option_images, displayed_models_info, end_of_session)``.

    Side effects: updates the module-level ``preferences_df``, writes it to the
    local CSV after each recorded vote, and uploads it to the Hub when the
    session's queue is exhausted. Unloadable samples are skipped, so the session
    never stalls on an error screen.
    """
    global preferences_df

    if current_sample_data is None or not displayed_models_info_for_sample or choice_index >= len(displayed_models_info_for_sample):
        print("Error: Invalid data for processing vote. Skipping.")
        current_sample_index += 1
        if current_sample_index >= len(sample_queue):
            error_end_msg = "# ⚠️ Error Processing Vote ⚠️\n\n### An issue occurred. The session has ended."
            return preferences_df, current_sample_index, error_end_msg, None, None, None, [], [], True
        current_sample_index, display = _load_first_loadable_sample(sample_queue, current_sample_index)
        return (preferences_df, current_sample_index, *display)

    domain, sample_id = sample_queue[current_sample_index]
    preferred_model_key, _ = displayed_models_info_for_sample[choice_index]

    if preferred_model_key == "BLANK_SLOT":
        print("User clicked on a blank slot. Vote not recorded. Please select a valid image.")
        # Re-render the current sample without recording anything.
        current_sample_index, display = _load_first_loadable_sample(sample_queue, current_sample_index)
        return (preferences_df, current_sample_index, *display)

    print(f"Session {session_id}: Voted for model '{config.MODEL_DISPLAY_NAMES.get(preferred_model_key, preferred_model_key)}' (key: {preferred_model_key}) for sample {domain}/{sample_id}")

    preferences_df = utils.record_preference(
        df=preferences_df,
        session_id=session_id,
        domain=domain,
        sample_id=sample_id,
        prompt=current_sample_data["prompt"],
        bg_path=current_sample_data["background_img_path"],
        fg_path=current_sample_data["foreground_img_path"],
        displayed_models_info=displayed_models_info_for_sample,
        preferred_model_key=preferred_model_key,
    )

    # Persist locally after every vote; a local write failure must not lose the session.
    try:
        preferences_df.to_csv(config.RESULTS_CSV_FILE, index=False)
        print(f"Preferences saved locally to {config.RESULTS_CSV_FILE}")
    except Exception as e:
        print(f"Error saving preferences locally: {e}")

    next_index, display = _load_first_loadable_sample(sample_queue, current_sample_index + 1)
    if next_index >= len(sample_queue):
        utils.save_preferences_to_hf_hub(preferences_df, config.HF_DATASET_REPO_ID, config.RESULTS_CSV_FILE, commit_message="Session end update")
        final_msg = "# 🎉 Session Complete! 🎉\n\n### All samples have been rated. Thank you for your participation!"
        return preferences_df, next_index, final_msg, None, None, None, [], [], True

    return (preferences_df, next_index, *display)
|
|
|
|
|
|
|
# Extra CSS passed to gr.Blocks below. It targets the "custom-vote-button"
# class given to each option's vote button: orange by default, darker orange on
# hover, always with white text. "!important" lets these rules override the
# theme's own button styling.
custom_css = """
.custom-vote-button {
    background-color: #FFA500 !important; /* Light Orange for normal state */
    border-color: #FFA500 !important; /* Light Orange for normal state */
    color: white !important;
}
.custom-vote-button:hover {
    background-color: #FF8C00 !important; /* Dark Orange for hover state */
    border-color: #FF8C00 !important; /* Dark Orange for hover state */
    color: white !important;
}
"""
|
|
|
with gr.Blocks(title=config.APP_TITLE, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue), css=custom_css) as demo:
    # Per-session state (Gradio keeps a separate copy of each gr.State per user session).
    session_id_state = gr.State()
    sample_queue_state = gr.State([])
    current_sample_index_state = gr.State(0)
    current_sample_data_state = gr.State()
    displayed_models_info_state = gr.State([])
    # Session-local snapshot of the results table, taken when the UI is built.
    # The vote handler only echoes it back. The shared module-level
    # ``preferences_df`` is the table that is actually updated and persisted.
    preferences_df_state = gr.State(value=preferences_df)

    gr.Markdown(f"# {config.APP_TITLE}")
    gr.Markdown(config.APP_DESCRIPTION)

    with gr.Row():
        start_button = gr.Button("Start New Session / Load First Sample", variant="primary")

    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            domain_prompt_info_display = gr.Markdown(value="### Click 'Start New Session' to begin.")
        with gr.Column(scale=2):
            with gr.Row():
                input_bg_image_display = gr.Image(label="Input Background", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], interactive=False)
                input_fg_image_display = gr.Image(label="Input Foreground", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], interactive=False)

    gr.Markdown("---")
    gr.Markdown("## Choose your preferred composed image:")

    output_image_displays = []
    vote_buttons = []
    with gr.Row():
        for i in range(4):
            with gr.Column():
                img_display = gr.Image(label=f"Option {i+1}", type="pil", height=config.IMAGE_DISPLAY_SIZE[0], width=config.IMAGE_DISPLAY_SIZE[0], interactive=False)
                output_image_displays.append(img_display)
                # Disabled until a session has loaded a sample. Otherwise a click
                # on the empty initial state goes through process_vote's
                # invalid-data branch and shows "Error Processing Vote".
                vote_btn = gr.Button(f"Select Option {i+1}", elem_id=f"vote_btn_{i}", elem_classes=["custom-vote-button"], interactive=False)
                vote_buttons.append(vote_btn)

    end_of_session_msg_display = gr.Markdown("", visible=True)

    def _display_updates(prompt_or_end_msg, bg, fg, sample_data, out_imgs, disp_info, end):
        """Map a load_and_display_sample 7-tuple onto the display components.

        Pads the option images and model info to one entry per vote button. A
        vote button is enabled only when its slot shows a real model output and
        the session has not ended. Shared by the start and vote handlers.
        """
        num_slots = len(vote_buttons)
        out_imgs = list(out_imgs) + [None] * (num_slots - len(out_imgs))
        disp_info = list(disp_info) + [("BLANK_SLOT", "N/A")] * (num_slots - len(disp_info))

        updates = {
            domain_prompt_info_display: "" if end else prompt_or_end_msg,
            input_bg_image_display: bg,
            input_fg_image_display: fg,
            current_sample_data_state: sample_data,
            displayed_models_info_state: disp_info,
            end_of_session_msg_display: prompt_or_end_msg if end else "",
        }
        for image_component, button, image, (model_key, _) in zip(output_image_displays, vote_buttons, out_imgs, disp_info):
            updates[image_component] = image
            has_real_option = sample_data is not None and model_key not in ("BLANK_SLOT", None)
            updates[button] = gr.Button(interactive=not end and has_real_option)
        return updates

    def handle_start_session():
        """Start a fresh session and render its first sample."""
        s_id, s_queue, s_idx, *display = start_new_session()
        updates = _display_updates(*display)
        updates[session_id_state] = s_id
        updates[sample_queue_state] = s_queue
        updates[current_sample_index_state] = s_idx
        return updates

    start_button.click(
        fn=handle_start_session,
        inputs=[],
        outputs=[
            session_id_state, sample_queue_state, current_sample_index_state,
            domain_prompt_info_display,
            input_bg_image_display, input_fg_image_display,
            current_sample_data_state, displayed_models_info_state, end_of_session_msg_display,
            *output_image_displays, *vote_buttons,
        ],
    )

    def make_vote_fn(choice_idx):
        """Build the click handler for option ``choice_idx`` (binds the index eagerly)."""
        def vote_action(s_id, s_queue, s_idx, current_s_data, disp_info_for_sample, prefs_df_val):
            # prefs_df_val is this session's snapshot of the table. It lacks votes
            # recorded by other sessions since it was taken, so it is deliberately
            # NOT written back to the shared ``preferences_df``. Doing so would
            # silently discard other users' votes on the next save/upload.
            new_prefs_df, new_s_idx, *display = process_vote(
                choice_idx, s_id, s_queue, s_idx, current_s_data, disp_info_for_sample
            )
            updates = _display_updates(*display)
            updates[preferences_df_state] = new_prefs_df
            updates[current_sample_index_state] = new_s_idx
            return updates
        return vote_action

    for i, btn in enumerate(vote_buttons):
        btn.click(
            fn=make_vote_fn(i),
            inputs=[
                session_id_state, sample_queue_state, current_sample_index_state,
                current_sample_data_state, displayed_models_info_state, preferences_df_state,
            ],
            outputs=[
                preferences_df_state, current_sample_index_state,
                domain_prompt_info_display,
                input_bg_image_display, input_fg_image_display,
                current_sample_data_state, displayed_models_info_state, end_of_session_msg_display,
                *output_image_displays, *vote_buttons,
            ],
        )

    gr.Markdown(config.FOOTER_MESSAGE)
|
|
|
def _create_dummy_data(data_folder):
    """Populate *data_folder* with placeholder samples so the app can run without real data.

    For each of two domains, creates ``config.SAMPLES_PER_DOMAIN + 2`` sample
    directories. Each holds a prompt file, grey background/foreground images and
    one solid-colour image per model in ``config.MODEL_OUTPUT_IMAGE_NAMES``.
    Image-creation errors are logged per sample and do not stop the run.
    """
    os.makedirs(data_folder, exist_ok=True)

    model_keys = list(config.MODEL_OUTPUT_IMAGE_NAMES.keys())
    palette = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (0,255,255)]

    for domain in ["Real-Cartoon", "Real-Painting"]:
        domain_dir = os.path.join(data_folder, domain)
        os.makedirs(domain_dir, exist_ok=True)
        for sample_num in range(config.SAMPLES_PER_DOMAIN + 2):
            sample_id = f"sample_{sample_num:03d}"
            sample_dir = os.path.join(domain_dir, sample_id)
            os.makedirs(sample_dir, exist_ok=True)

            with open(os.path.join(sample_dir, config.PROMPT_FILE_NAME), "w") as prompt_file:
                prompt_file.write(f"This is a dummy prompt for {domain} sample {sample_id}.")

            try:
                background = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color='gray')
                background.save(os.path.join(sample_dir, config.BACKGROUND_IMAGE_NAME))
                foreground = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color='lightgray')
                foreground.save(os.path.join(sample_dir, config.FOREGROUND_IMAGE_NAME))
                # One solid colour per model, cycling through the palette.
                for position, model_key in enumerate(model_keys):
                    swatch = palette[position % len(palette)]
                    model_image = Image.new('RGB', config.IMAGE_DISPLAY_SIZE, color=swatch)
                    model_image.save(os.path.join(sample_dir, config.MODEL_OUTPUT_IMAGE_NAMES[model_key]))
            except Exception as e:
                print(f"Error creating dummy image: {e}")


if __name__ == "__main__":
    # Bootstrap placeholder data on first run, then re-scan so new sessions can see it.
    if not os.path.exists(config.DATA_FOLDER):
        print(f"Creating dummy data folder: {config.DATA_FOLDER}")
        _create_dummy_data(config.DATA_FOLDER)
        print("Dummy data creation complete.")
        ALL_SAMPLES_BY_DOMAIN = utils.scan_data_directory(config.DATA_FOLDER)

    demo.launch()

import atexit

# Stop the background upload scheduler when the interpreter exits.
atexit.register(lambda: scheduler.shutdown() if scheduler.running else None)
|
|