|
import os |
|
import re |
|
import shutil |
|
|
|
from config import MODEL_OUTPUT_IMAGE_NAMES |
|
|
|
# Source tree of raw generations, laid out as <ROOT_DIR>/<domain>/<sample_dir>/<files>.
ROOT_DIR: str = "benchmark_images_generations"
# Destination tree, laid out as <OUTPUT_DIR>/<domain>/sample_<i>/<files>.
OUTPUT_DIR: str = "data/"
|
|
|
# Raw-input file names inside a sample directory, matched against the WHOLE name
# (so e.g. "bg1.jpg.bak" is rejected). Backgrounds look like "bg1.jpg";
# foregrounds like "fg.png", "fg2.jpg" or "fg_cat.png". Foreground masks are
# filtered out separately by substring.
_INPUT_BG_PATTERN = re.compile(r"bg\d+\.(jpg|png)")
_INPUT_FG_PATTERN = re.compile(r"fg\d*(_\w+)?\.(jpg|png)")


def _first_match(file_names, pattern, exclude=None):
    """Return the first of *file_names* (sorted order) that fully matches *pattern*.

    Names containing the substring *exclude* (when given) are ignored. Returns
    ``None`` if nothing matches. Sorting makes the pick deterministic, since
    ``os.listdir`` returns entries in arbitrary order.
    """
    for name in sorted(file_names):
        if exclude is not None and exclude in name:
            continue
        if pattern.fullmatch(name):
            return name
    return None


def _skip_sample(message, output_sample_dir):
    """Print *message* and remove any existing output directory for a skipped sample."""
    print(message)
    # A directory may linger from a previous run; a skipped sample must not
    # leave (stale or partial) data under its index.
    if os.path.isdir(output_sample_dir):
        shutil.rmtree(output_sample_dir)


def _export_sample(domain, i, sample_dir, sample_dir_path):
    """Validate one sample directory and copy it to ``<OUTPUT_DIR>/<domain>/sample_<i>/``.

    Every required input is checked before anything is written, so a skipped
    sample never produces a partial output directory.
    """
    output_sample_dir = os.path.join(OUTPUT_DIR, domain, f"sample_{i}")
    file_names = os.listdir(sample_dir_path)

    input_bg = _first_match(file_names, _INPUT_BG_PATTERN)
    if input_bg is None:
        _skip_sample(
            f"Warning: No input background found in {sample_dir_path}. Skipping sample {i}...",
            output_sample_dir,
        )
        return

    input_fg = _first_match(file_names, _INPUT_FG_PATTERN, exclude="mask")
    if input_fg is None:
        _skip_sample(
            f"Warning: No input foreground found in {sample_dir_path}. Skipping sample {i}...",
            output_sample_dir,
        )
        return

    output_names = list(MODEL_OUTPUT_IMAGE_NAMES.values())
    if not all(os.path.exists(os.path.join(sample_dir_path, name)) for name in output_names):
        _skip_sample(
            f"Warning: Not all output images found in {sample_dir_path}. Skipping sample {i}...",
            output_sample_dir,
        )
        return

    os.makedirs(output_sample_dir, exist_ok=True)

    # NOTE: both inputs are always saved with a ".jpg" name, even when the
    # source is a PNG; the bytes are copied unchanged (no re-encoding).
    shutil.copy(os.path.join(sample_dir_path, input_bg), os.path.join(output_sample_dir, "input_bg.jpg"))
    shutil.copy(os.path.join(sample_dir_path, input_fg), os.path.join(output_sample_dir, "input_fg.jpg"))

    # Existence of every model output was verified above, so copy unconditionally.
    for image_name in output_names:
        shutil.copy(os.path.join(sample_dir_path, image_name), os.path.join(output_sample_dir, image_name))

    # The prompt is the sample directory name minus its first four characters
    # (presumably a numeric prefix such as "001_" — confirm against the data).
    prompt = sample_dir[4:].strip()
    with open(os.path.join(output_sample_dir, "prompt.txt"), "w", encoding="utf-8") as prompt_file:
        prompt_file.write(prompt)


def main():
    """Reorganise the benchmark generations under ROOT_DIR into OUTPUT_DIR.

    For every ``<ROOT_DIR>/<domain>/<sample_dir>`` directory this writes
    ``<OUTPUT_DIR>/<domain>/sample_<i>/`` containing:

    * ``input_bg.jpg`` - the first (sorted) ``bg<N>.(jpg|png)`` file,
    * ``input_fg.jpg`` - the first (sorted) ``fg[<N>][_<tag>].(jpg|png)`` file
      whose name does not contain "mask",
    * each file named in ``MODEL_OUTPUT_IMAGE_NAMES``, under the same name,
    * ``prompt.txt`` - the sample directory name without its first 4 characters.

    ``i`` is the position of the sample directory among the domain's
    sub-directories in sorted order. Non-directory entries (e.g. ".DS_Store")
    are ignored at both levels. Samples missing any required file are skipped
    with a printed warning, keep their index (leaving a gap), and leave no
    output directory behind.
    """
    for domain in sorted(os.listdir(ROOT_DIR)):
        domain_dir = os.path.join(ROOT_DIR, domain)
        # Stray files next to the domain folders (e.g. ".DS_Store") cannot be listed.
        if not os.path.isdir(domain_dir):
            continue

        sample_dirs = [
            name
            for name in sorted(os.listdir(domain_dir))
            if os.path.isdir(os.path.join(domain_dir, name))
        ]
        for i, sample_dir in enumerate(sample_dirs):
            _export_sample(domain, i, sample_dir, os.path.join(domain_dir, sample_dir))
|
|
|
# Run the export only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()