mbarnig's picture
Update app.py
201e1a6 verified
import gradio as gr
import os, json, random, glob
from typing import List, Dict
from PIL import Image, ImageDraw, ImageFont
# === Standardised labels shown in the UI and used as ground-truth values ===
LABEL_VIRTUEL, LABEL_REEL = "Virtuel", "Réel"

# --- General settings (each can be overridden through an environment variable) ---
# Folder that holds the quiz images (and an optional answer_key.json).
IMAGE_DIR = os.environ.get("IMAGE_DIR", "assets")
# Number of images the quiz expects.
N_IMAGES = int(os.environ.get("N_IMAGES", "24"))
# Lower-case file extensions recognised as images (used with str.endswith).
IMG_EXTS = tuple(".png .jpg .jpeg .webp .bmp .gif".split())
# --- Utilitaires ---
def _centered_multiline(draw: ImageDraw.ImageDraw, xy, text: str, font: ImageFont.ImageFont, img_w: int, fill=(0, 0, 0)):
    """Draw multi-line ``text`` with every line centred horizontally on the image.

    Lines are stacked top-down, the first one starting at ``xy[1]``. ``xy[0]``
    is ignored because the horizontal position is computed from ``img_w``.

    Args:
        draw: drawing context the text is rendered into.
        xy: ``(x, y)`` anchor; only ``y`` (top of the first line) is used.
        text: text to draw, lines separated by ``"\\n"``.
        font: font used both to measure and to render the text.
        img_w: width in pixels of the target image, used for centring.
        fill: text colour; black by default (the previous hard-coded value).
    """
    _, top = xy
    size = getattr(font, "size", None)
    if size is None:
        # Bitmap fonts (e.g. ImageFont.load_default() on Pillow < 10.1 or
        # without FreeType) have no `.size`; measure a reference string instead.
        ref_box = draw.textbbox((0, 0), "Ag", font=font)
        size = ref_box[3] - ref_box[1]
    line_h = int(size * 1.2)  # 20% leading between consecutive lines
    for i, line in enumerate(text.split("\n")):
        left, _, right, _ = draw.textbbox((0, 0), line, font=font)
        draw.text(((img_w - (right - left)) // 2, top + i * line_h), line, fill=fill, font=font)
def generate_demo_assets():
    """Create placeholder quiz images when IMAGE_DIR holds fewer than N_IMAGES images.

    Writes N_IMAGES solid-colour PNGs named ``demo_XX_ai.png`` /
    ``demo_XX_human.png``. Labels alternate, starting with LABEL_VIRTUEL, and
    each image prints its own ground truth. A list-format ``answer_key.json``
    describing these files is written ONLY if no answer key exists yet, so a
    key the user wrote for their own images is never overwritten.
    Does nothing when at least N_IMAGES images are already present.
    """
    os.makedirs(IMAGE_DIR, exist_ok=True)
    files = [p for p in glob.glob(os.path.join(IMAGE_DIR, "*")) if p.lower().endswith(IMG_EXTS)]
    if len(files) >= N_IMAGES:
        return
    print("[setup] Génération d’un jeu de données de démonstration…")
    w, h = 640, 640
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 36)
    except (OSError, ImportError):
        # Font file not found (OSError) or Pillow built without FreeType
        # (ImportError): fall back to Pillow's built-in font.
        font = ImageFont.load_default()
    demo_entries = []
    for i in range(N_IMAGES):
        bg = (random.randint(160, 240), random.randint(160, 240), random.randint(160, 240))
        img = Image.new("RGB", (w, h), bg)
        d = ImageDraw.Draw(img)
        label = LABEL_VIRTUEL if i % 2 == 0 else LABEL_REEL
        text = f"DEMO\nImage {i+1}\nVérité : {label}"
        _centered_multiline(d, (0, h // 2 - 60), text, font, w)
        fname = f"demo_{i+1:02d}_{'ai' if label == LABEL_VIRTUEL else 'human'}.png"
        img.save(os.path.join(IMAGE_DIR, fname))
        demo_entries.append({"file": fname, "label": label})
    key_path = os.path.join(IMAGE_DIR, "answer_key.json")
    if os.path.exists(key_path):
        # Never clobber an existing (possibly hand-written) answer key.
        print("[setup] answer_key.json existant conservé (non écrasé).")
        return
    with open(key_path, "w", encoding="utf-8") as f:
        json.dump(demo_entries, f, ensure_ascii=False, indent=2)
def _normalize_label(raw) -> str:
    """Map a free-form answer-key label to LABEL_VIRTUEL or LABEL_REEL.

    After stripping and lower-casing, a label is virtual when it is one of
    "ia", "ai", "artificial", "generated" or LABEL_VIRTUEL, or when it starts
    with "ia"/"ai" (e.g. "AI-generated"). Anything else counts as real.
    """
    low = str(raw).strip().lower()
    if low in {"ia", "ai", "artificial", "generated", LABEL_VIRTUEL.lower()} or low.startswith(("ia", "ai")):
        return LABEL_VIRTUEL
    return LABEL_REEL


def _load_label_map(answer_key_path: str) -> Dict[str, str]:
    """Read an answer key and return ``{file basename: normalised label}``.

    Two JSON layouts are accepted:
      * a mapping ``{"img.png": "ai", ...}``;
      * a list of objects with a file key ("file"/"name"/"path") and a label
        key ("label"/"truth").
    Both layouts go through the same label normalisation and basename
    handling. Malformed entries are skipped individually. A missing file gives
    ``{}``; an unreadable or invalid file gives ``{}`` plus a printed warning.
    """
    if not os.path.exists(answer_key_path):
        return {}
    try:
        with open(answer_key_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
        print("[warn] Impossible de lire answer_key.json :", e)
        return {}
    if isinstance(data, dict):
        rows = ({"file": k, "label": v} for k, v in data.items())
    elif isinstance(data, list):
        rows = data
    else:
        return {}
    label_map = {}
    for row in rows:
        if not isinstance(row, dict):
            continue  # skip junk rows instead of discarding the whole key
        fname = row.get("file") or row.get("name") or row.get("path")
        lab = row.get("label") or row.get("truth")
        if fname and lab:
            label_map[os.path.basename(str(fname))] = _normalize_label(lab)
    return label_map


def load_items() -> List[Dict]:
    """Collect the first N_IMAGES images of IMAGE_DIR with their ground truth.

    A demo dataset is generated first when fewer than N_IMAGES images exist.
    Images are taken in sorted path order. Each one's truth comes from
    ``answer_key.json`` when it is listed there; otherwise it is guessed from
    marker substrings in the file name.

    Returns:
        A list of ``{"path", "file", "truth"}`` dicts, ``truth`` being
        LABEL_VIRTUEL or LABEL_REEL.

    Raises:
        RuntimeError: if IMAGE_DIR still holds fewer than N_IMAGES images.
    """
    os.makedirs(IMAGE_DIR, exist_ok=True)
    generate_demo_assets()  # creates demo images when fewer than N_IMAGES are present
    files = sorted(p for p in glob.glob(os.path.join(IMAGE_DIR, "*")) if p.lower().endswith(IMG_EXTS))
    if len(files) < N_IMAGES:
        raise RuntimeError(f"Il faut au moins {N_IMAGES} images dans '{IMAGE_DIR}'. Trouvé : {len(files)}.")
    files = files[:N_IMAGES]
    label_map = _load_label_map(os.path.join(IMAGE_DIR, "answer_key.json"))
    # Fallback heuristic for files missing from the answer key.
    # NOTE(review): "gen" and "stable" are broad substrings (e.g. "legend.jpg"
    # matches "gen"); consider tightening them if real photos get misclassified.
    ai_markers = ("_ai", "-ai", " ai ", "midjourney", "sdxl", "stable", "gen", "generated", "dalle", "flux")
    items = []
    for p in files:
        fname = os.path.basename(p)
        truth = label_map.get(fname)
        if truth is None:
            fn_low = fname.lower()
            truth = LABEL_VIRTUEL if any(m in fn_low for m in ai_markers) else LABEL_REEL
        items.append({"path": p, "file": fname, "truth": truth})
    return items
# --- UI construction ---
# Number of columns in the quiz grid (3 columns -> 8 rows for the default 24 images).
COLS = 3

# Load (generating demo data if needed) the quiz items once, at import time.
ITEMS = load_items()
def build_interface(items: List[Dict]):
    """Build (but do not launch) the Gradio quiz app.

    Args:
        items: quiz entries as returned by load_items(), i.e. dicts with
            "path", "file" and "truth" keys; the first N_IMAGES are shown.

    Returns:
        The ``gr.Blocks`` application.
    """
    with gr.Blocks(theme=gr.themes.Soft(), css="""
.quiz-grid .gr-image {max-height: 220px}
.score {font-size: 1.2rem; font-weight: 700}
.center-button {display: flex; justify-content: center; margin-top: 0.5rem;}
.warn-msg {text-align: center; color: #b91c1c; font-weight: 600;}
""") as demo:
        gr.Markdown(f"""
# {LABEL_VIRTUEL} ou {LABEL_REEL} ?
Sélectionnez **{LABEL_VIRTUEL}** ou **{LABEL_REEL}** pour chacune des {N_IMAGES} images, puis cliquez sur **Valider mes {N_IMAGES} choix**.
""")
        # Current display order of the items; replaced on shuffle / replay.
        state_items = gr.State(items)

        # --- Quiz area ---
        with gr.Group(visible=True) as quiz_group:
            with gr.Row():
                btn_shuffle = gr.Button("🔀 Mélanger l’ordre")
                btn_reset = gr.Button("♻️ Réinitialiser les choix")
            image_comps: List[gr.Image] = []
            radio_comps: List[gr.Radio] = []
            with gr.Column(elem_classes=["quiz-grid"]):
                # One gr.Row per COLS images; the last row may be shorter.
                for start in range(0, N_IMAGES, COLS):
                    with gr.Row():
                        for idx in range(start, min(start + COLS, N_IMAGES)):
                            with gr.Column():
                                img = gr.Image(value=items[idx]["path"], label=f"Image {idx+1}", interactive=False)
                                image_comps.append(img)
                                radio = gr.Radio(choices=[LABEL_VIRTUEL, LABEL_REEL], label="Votre choix", value=None)
                                radio_comps.append(radio)
            # Warning message placed right above the submit button.
            warn_md = gr.Markdown("", visible=False, elem_classes=["warn-msg"])
            # Centred submit button.
            with gr.Row(elem_classes=["center-button"]):
                btn_submit = gr.Button(f"✅ Valider mes {N_IMAGES} choix", variant="primary")

        # --- Results area ---
        with gr.Group(visible=False) as result_group:
            gr.Markdown("## Résultats")
            score_md = gr.Markdown(elem_classes=["score"])
            df = gr.Dataframe(
                headers=["#", "Fichier", "Vérité", "Votre réponse", "✓"],
                row_count=(N_IMAGES, "fixed"),
                interactive=False,
            )
            with gr.Row():
                gallery_ok = gr.Gallery(label="Réponses correctes", columns=6, height=180)
                gallery_ko = gr.Gallery(label="Réponses incorrectes", columns=6, height=180)
            with gr.Row():
                btn_again_same = gr.Button("↩️ Rejouer (même ordre)")
                btn_again_shuffle = gr.Button("🔁 Rejouer & mélanger")

        # --- Callbacks ---
        def on_reset():
            """Clear every answer and hide the warning."""
            radio_updates = [gr.update(value=None) for _ in range(N_IMAGES)]
            return [*radio_updates, gr.update(value="", visible=False)]

        btn_reset.click(on_reset, inputs=None, outputs=[*radio_comps, warn_md])

        def _reorder(state, do_shuffle):
            """Return (image updates, cleared-radio updates, new item order).

            Answers are always cleared: they are positional, so keeping them
            after reordering would attach them to different images.
            """
            new_items = list(state)
            if do_shuffle:
                random.shuffle(new_items)
            img_updates = [gr.update(value=new_items[i]["path"], label=f"Image {i+1}") for i in range(N_IMAGES)]
            radio_updates = [gr.update(value=None) for _ in range(N_IMAGES)]
            return img_updates, radio_updates, new_items

        def on_shuffle(state):
            """Shuffle the images in place in the quiz, clearing answers and the warning."""
            img_updates, radio_updates, new_items = _reorder(state, True)
            return [*img_updates, *radio_updates, gr.update(value="", visible=False), new_items]

        # Previously this button had no handler and did nothing.
        btn_shuffle.click(
            on_shuffle,
            inputs=[state_items],
            outputs=[*image_comps, *radio_comps, warn_md, state_items],
        )

        def on_submit(*args):
            """Score the answers, or show a warning if some are still missing.

            ``args`` holds the N_IMAGES radio values followed by the item state.
            """
            state = args[-1]
            answers = list(args[:-1])
            missing = sum(1 for a in answers if a is None)
            if missing:
                msg = f"❗ Merci de répondre aux **{missing}** image(s) restante(s) avant de valider."
                return (
                    gr.update(value="", visible=False),
                    gr.update(value=None),
                    gr.update(value=None),
                    gr.update(value=None),
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(value=msg, visible=True),
                )
            current = list(state)
            table_rows, ok_imgs, ko_imgs = [], [], []
            ok = 0
            for i, choice in enumerate(answers):
                truth = current[i]["truth"]
                is_ok = choice == truth
                ok += int(is_ok)
                table_rows.append([i + 1, current[i]["file"], truth, choice, "✅" if is_ok else "❌"])
                (ok_imgs if is_ok else ko_imgs).append(current[i]["path"])
            # Guard against N_IMAGES == 0 (no answers) to avoid ZeroDivisionError.
            pct = round(100 * ok / len(answers)) if answers else 0
            score_txt = f"**Score : {ok}/{N_IMAGES} ({pct}%)**"
            return (
                gr.update(value=score_txt, visible=True),
                gr.update(value=table_rows),
                gr.update(value=ok_imgs),
                gr.update(value=ko_imgs),
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(value="", visible=False),
            )

        btn_submit.click(
            on_submit,
            inputs=[*radio_comps, state_items],
            outputs=[score_md, df, gallery_ok, gallery_ko, quiz_group, result_group, warn_md],
            scroll_to_output=True,
        )

        def restart(state, do_shuffle: bool):
            """Go back to the quiz with cleared answers, optionally reshuffled."""
            img_updates, radio_updates, new_items = _reorder(state, do_shuffle)
            return [*img_updates, *radio_updates, gr.update(visible=True), gr.update(visible=False), new_items]

        btn_again_same.click(
            lambda state: restart(state, False),
            inputs=[state_items],
            outputs=[*image_comps, *radio_comps, quiz_group, result_group, state_items],
        )
        btn_again_shuffle.click(
            lambda state: restart(state, True),
            inputs=[state_items],
            outputs=[*image_comps, *radio_comps, quiz_group, result_group, state_items],
        )
    return demo
# The app object is built at import time and exposed as module-level `demo`.
demo = build_interface(ITEMS)

if __name__ == "__main__":
    # Start the server only when this file is executed directly.
    demo.launch()