import os
import random
import shutil
import uuid
from datetime import datetime

import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, login
from PIL import Image

import config


def login_hugging_face():
    """Logs in to the Hugging Face Hub using a token from config or an environment variable."""
    token = config.HF_TOKEN or os.getenv("HF_HUB_TOKEN")
    if token:
        login(token=token)
        print("Successfully logged into Hugging Face Hub.")
    else:
        print("HF_TOKEN not set in config and HF_HUB_TOKEN not in environment. Proceeding without login. Uploads to private repos will fail.")


def load_preferences_from_hf_hub(repo_id, filename):
    """Downloads the preferences CSV from the Hugging Face Hub dataset repo.

    Returns a pandas DataFrame, or None if the file does not exist or an error occurs.
    """
    try:
        downloaded_file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset",
            local_dir=".",
            local_dir_use_symlinks=False,
        )

        # hf_hub_download may place the file in a subdirectory of local_dir; move it
        # next to the app so later reads and uploads use a predictable path.
        if os.path.dirname(downloaded_file_path) != os.path.abspath("."):
            destination_path = os.path.join(".", os.path.basename(downloaded_file_path))
            shutil.move(downloaded_file_path, destination_path)
            downloaded_file_path = destination_path

        if os.path.exists(downloaded_file_path):
            print(f"Successfully downloaded {filename} from {repo_id}")
            df = pd.read_csv(downloaded_file_path)
            # Normalize the local name so later saves overwrite the same file.
            if downloaded_file_path != filename:
                os.rename(downloaded_file_path, filename)
            return df
        else:
            print(f"Downloaded file {downloaded_file_path} does not exist locally.")
            return None
    except Exception as e:
        print(f"Could not download {filename} from {repo_id}. Error: {e}")
        print("Starting with an empty preferences table, or a local copy if available.")
        if os.path.exists(filename):
            print(f"Loading local copy of {filename}")
            return pd.read_csv(filename)
        return None


def save_preferences_to_hf_hub(df, repo_id, filename, commit_message="Update preferences"):
    """Saves the DataFrame to a local CSV and uploads it to the Hugging Face Hub."""
    if df is None or df.empty:
        print("Preferences DataFrame is empty. Nothing to save or upload.")
        return
    try:
        df.to_csv(filename, index=False)
        print(f"Preferences saved locally to {filename}")

        api = HfApi()
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=filename,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        print(f"Successfully uploaded {filename} to {repo_id}")
    except Exception as e:
        print(f"Error saving or uploading {filename} to Hugging Face Hub: {e}")
        print("Changes are saved locally. Will attempt upload on next scheduled push.")


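# Illustrative round trip for the two Hub helpers above (a sketch only; the repo id and
# CSV name below are placeholders, not values defined in config):
#
#   login_hugging_face()
#   prefs_df = load_preferences_from_hf_hub("your-username/your-dataset-repo", "preferences.csv")
#   ...append rows with record_preference()...
#   save_preferences_to_hf_hub(prefs_df, "your-username/your-dataset-repo", "preferences.csv")

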
def scan_data_directory(data_folder):
    """
    Scans the data directory to find domains and their samples.
    Returns a dictionary: {"domain_name": ["sample_id1", "sample_id2", ...]}
    """
    all_samples_by_domain = {}
    if not os.path.isdir(data_folder):
        print(f"Error: Data folder '{data_folder}' not found.")
        return all_samples_by_domain

    for domain_name in os.listdir(data_folder):
        domain_path = os.path.join(data_folder, domain_name)
        if os.path.isdir(domain_path):
            all_samples_by_domain[domain_name] = []
            for sample_id in os.listdir(domain_path):
                sample_path = os.path.join(domain_path, sample_id)
                # A sample counts as valid only if it has both a prompt file and a background image.
                prompt_file = os.path.join(sample_path, config.PROMPT_FILE_NAME)
                bg_image = os.path.join(sample_path, config.BACKGROUND_IMAGE_NAME)
                if os.path.isdir(sample_path) and os.path.exists(prompt_file) and os.path.exists(bg_image):
                    all_samples_by_domain[domain_name].append(sample_id)
            if not all_samples_by_domain[domain_name]:
                print(f"Warning: No valid samples found in domain '{domain_name}'.")
    if not all_samples_by_domain:
        print(f"Warning: No domains found or no valid samples in any domain in '{data_folder}'.")
    return all_samples_by_domain


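# Directory layout expected by the scan above. The file names come from config; the
# literal names on the right are only examples of what those config values might be:
#
#   <DATA_FOLDER>/
#       <domain>/
#           <sample_id>/
#               <PROMPT_FILE_NAME>          e.g. prompt.txt
#               <BACKGROUND_IMAGE_NAME>     e.g. background.png
#               <FOREGROUND_IMAGE_NAME>     e.g. foreground.png
#               <model output images>       one per entry in MODEL_OUTPUT_IMAGE_NAMES

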
def prepare_session_samples(all_samples_by_domain, samples_per_domain):
    """
    Prepares a list of (domain, sample_id) tuples for a user session.
    Randomly selects up to 'samples_per_domain' samples from each domain.
    The returned list is shuffled.
    """
    session_queue = []
    for domain, samples in all_samples_by_domain.items():
        if samples:
            chosen_samples = random.sample(samples, min(len(samples), samples_per_domain))
            for sample_id in chosen_samples:
                session_queue.append((domain, sample_id))
    random.shuffle(session_queue)
    return session_queue


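# Typical session setup (a sketch; the per-domain count and the domain/sample names in
# the result are illustrative, not taken from the actual dataset):
#
#   all_samples = scan_data_directory(config.DATA_FOLDER)
#   queue = prepare_session_samples(all_samples, samples_per_domain=2)
#   # queue is a shuffled list like [("indoor", "sample_03"), ("outdoor", "sample_11"), ...]

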
def generate_session_id():
    """Generates a unique session ID."""
    return uuid.uuid4().hex[:config.SESSION_ID_LENGTH]


def load_sample_data(domain, sample_id):
    """
    Loads data for a specific sample: prompt, input images, and output image paths.
    Returns a dictionary, or None if the data is incomplete.
    """
    sample_path = os.path.join(config.DATA_FOLDER, domain, sample_id)
    prompt_path = os.path.join(sample_path, config.PROMPT_FILE_NAME)
    bg_image_path = os.path.join(sample_path, config.BACKGROUND_IMAGE_NAME)
    fg_image_path = os.path.join(sample_path, config.FOREGROUND_IMAGE_NAME)

    if not all(os.path.exists(p) for p in [prompt_path, bg_image_path, fg_image_path]):
        print(f"Error: Missing core files for sample {domain}/{sample_id}")
        return None

    try:
        with open(prompt_path, 'r', encoding='utf-8') as f:
            prompt_text = f.read().strip()
    except Exception as e:
        print(f"Error reading prompt for {domain}/{sample_id}: {e}")
        return None

    # Collect whichever model output images exist for this sample.
    output_images = {}
    for model_key, img_name in config.MODEL_OUTPUT_IMAGE_NAMES.items():
        img_path = os.path.join(sample_path, img_name)
        if os.path.exists(img_path):
            output_images[model_key] = img_path
        else:
            print(f"Warning: Missing output image {img_name} for model {model_key} in sample {domain}/{sample_id}")

    if len(output_images) < len(config.MODEL_OUTPUT_IMAGE_NAMES):
        print(f"Warning: Sample {domain}/{sample_id} is missing one or more model outputs. It will have fewer than 4 options.")
    if not output_images:
        return None

    return {
        "prompt": prompt_text,
        "background_img_path": bg_image_path,
        "foreground_img_path": fg_image_path,
        "output_image_paths": output_images,
    }


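# A minimal consumer of the dict returned above (illustrative only; the actual UI code
# lives elsewhere). PIL is used here just to show that the returned values are plain
# file paths; the domain and sample id are made-up examples:
#
#   sample = load_sample_data("indoor", "sample_03")
#   if sample is not None:
#       bg = Image.open(sample["background_img_path"])
#       outputs = {k: Image.open(p) for k, p in sample["output_image_paths"].items()}

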
def record_preference(df, session_id, domain, sample_id, prompt, bg_path, fg_path, displayed_models_info, preferred_model_key):
    """
    Appends a new preference record to the DataFrame.
    displayed_models_info: list of (model_key, image_path) tuples in the order they were displayed.
    preferred_model_key: the key of the model the user selected (e.g., "model_a").
    """
    timestamp = datetime.now().isoformat()

    new_row = {
        "session_id": session_id,
        "timestamp": timestamp,
        "domain": domain,
        "sample_id": sample_id,
        "prompt": prompt,
        "input_background": os.path.basename(bg_path),
        "input_foreground": os.path.basename(fg_path),
        "preferred_model_key": preferred_model_key,
        "preferred_model_filename": config.MODEL_OUTPUT_IMAGE_NAMES.get(preferred_model_key, "N/A"),
    }

    # Record the on-screen display order alongside the choice (up to four slots).
    for i in range(4):
        col_name = f"displayed_order_model_{i+1}"
        if i < len(displayed_models_info):
            new_row[col_name] = displayed_models_info[i][0]
        else:
            new_row[col_name] = None

    new_df_row = pd.DataFrame([new_row], columns=config.CSV_HEADERS)

    if df is None:
        df = new_df_row
    else:
        df = pd.concat([df, new_df_row], ignore_index=True)
    return df


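if __name__ == "__main__":
    # Minimal local smoke test for this module (a sketch: it exercises only the local
    # helpers, does not log in to or upload anything to the Hub, and the per-domain
    # sample count and the "pretend pick" below are arbitrary example choices).
    all_samples = scan_data_directory(config.DATA_FOLDER)
    queue = prepare_session_samples(all_samples, samples_per_domain=1)
    session_id = generate_session_id()
    print(f"Session {session_id}: {len(queue)} sample(s) queued.")

    prefs_df = None
    for domain, sample_id in queue:
        sample = load_sample_data(domain, sample_id)
        if sample is None:
            continue
        displayed = list(sample["output_image_paths"].items())
        random.shuffle(displayed)
        # Pretend the user picked the first displayed option, just to exercise record_preference.
        prefs_df = record_preference(
            prefs_df,
            session_id,
            domain,
            sample_id,
            sample["prompt"],
            sample["background_img_path"],
            sample["foreground_img_path"],
            displayed,
            displayed[0][0],
        )
    if prefs_df is not None:
        print(prefs_df.tail())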