JimSmith007 committed on
Commit
f0a19a1
1 Parent(s): c003bec

Add files for the SegFormer API

Files changed (5)
  1. Dockerfile +19 -0
  2. app.py +70 -0
  3. fonctions.py +1822 -0
  4. requirements.txt +7 -0
  5. segformer_b5.pth +3 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
+ FROM python:3.12.10
+
+ # Create a non-root user
+ RUN useradd -m user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # Working directory
+ WORKDIR /home/user/app
+
+ # Install the dependencies
+ COPY --chown=user requirements.txt requirements.txt
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy all project files
+ COPY --chown=user . .
+
+ # Launch the FastAPI app on port 7860
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
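+
+ # Example local build & run (image name is illustrative):
+ #   docker build -t segformer-api .
+ #   docker run -p 7860:7860 segformer-api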
app.py ADDED
@@ -0,0 +1,70 @@
+ # app.py
+ from fastapi import FastAPI, UploadFile, File
+ from fastapi.responses import StreamingResponse
+ import torch
+ from fonctions import charger_segformer
+ from PIL import Image
+ import io
+ import numpy as np
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ import torch.nn.functional as F
+
+ app = FastAPI()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the SegFormer model
+ model = charger_segformer(num_classes=8)
+ model.load_state_dict(torch.load("segformer_b5.pth", map_location=device))
+ model.to(device)
+ model.eval()
+
+ # Albumentations preprocessing
+ def preprocess(image: Image.Image) -> torch.Tensor:
+     transform = A.Compose([
+         A.Resize(256, 256),
+         A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+         ToTensorV2()
+     ])
+     image_np = np.array(image.convert("RGB"))
+     transformed = transform(image=image_np)
+     return transformed['image'].unsqueeze(0).to(device)
+
+ # Color palette
+ PALETTE = {
+     0: (0, 0, 0), 1: (50, 50, 150), 2: (102, 0, 204), 3: (255, 85, 0),
+     4: (255, 255, 0), 5: (0, 255, 255), 6: (255, 0, 255), 7: (255, 255, 255)
+ }
+
+ def decode_mask(mask):
+     h, w = mask.shape
+     mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
+     for class_id, color in PALETTE.items():
+         mask_rgb[mask == class_id] = color
+     return mask_rgb
+
+ @app.get("/")
+ def home():
+     return {"status": "API with 'SegFormer' model up and running"}
+
+ @app.post("/predict")
+ async def predict(image: UploadFile = File(...)):
+     contents = await image.read()
+     img = Image.open(io.BytesIO(contents))
+
+     tensor = preprocess(img)
+
+     with torch.no_grad():
+         logits = model(tensor).logits
+         logits = F.interpolate(logits, size=(256, 256), mode="bilinear", align_corners=False)
+         pred_mask = logits.argmax(dim=1).squeeze().cpu().numpy()
+
+     mask_rgb = decode_mask(pred_mask)
+     mask_img = Image.fromarray(mask_rgb)
+
+     buf = io.BytesIO()
+     mask_img.save(buf, format="PNG")
+     buf.seek(0)
+
+     return StreamingResponse(buf, media_type="image/png")
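For reference, a minimal client sketch for the /predict endpoint above (the host/port match the Dockerfile CMD; the file names are illustrative):

import requests

# Upload an image under the form field "image" (the name of the UploadFile parameter)
with open("street.png", "rb") as f:
    resp = requests.post("http://localhost:7860/predict", files={"image": f})
resp.raise_for_status()

# The endpoint streams back a PNG of the colorized mask
with open("mask.png", "wb") as out:
    out.write(resp.content)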
fonctions.py ADDED
@@ -0,0 +1,1822 @@
+ # fonctions.py
+
+ #from config import DATA_DIR, RESULTS_DIR
+
+ # -------------------- BASE FUNCTIONS FROM PROJECT 8 --------------------
+
+ # Required imports
+ import os
+ import tensorflow as tf
+ from cityscapesscripts.helpers.labels import name2label
+ from cityscapesscripts.preparation.json2labelImg import json2labelImg
+ import json
+ import numpy as np
+ import albumentations as A
+ import cv2
+ from tensorflow.keras.utils import Sequence
+ from tensorflow.keras.preprocessing.image import load_img, img_to_array
+ from albumentations import Compose, HorizontalFlip, Rotate, OneOf, RandomScale, Blur, GaussNoise, Resize
+ import matplotlib.pyplot as plt
+ from typing import List, Tuple
+ from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, Concatenate, Resizing, BatchNormalization, Dropout
+ from tensorflow.keras.models import Model
+ from tqdm import tqdm
+ from tensorflow.keras.applications import VGG16, ResNet50
+ from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ReduceLROnPlateau, ModelCheckpoint
+ from cityscapesscripts.helpers.labels import trainId2label
+ import time
+ import segmentation_models as sm
+ import pandas as pd
+ from pathlib import Path
+ from datetime import datetime
+ from tensorflow.keras.optimizers import Adam
+ import glob
+ import torch
+ from torchvision import transforms
+ import torch.nn.functional as F
+
+
+ # Definition of the useful classes
+ CLASSES_UTILES = {
+     "void": 0, "flat": 1, "construction": 2, "object": 3,
+     "nature": 4, "sky": 5, "human": 6, "vehicle": 7
+ }
+
+ # Path fix for Project 9
+ root_path = Path(".")  # Project 9 root
+ data_path = root_path / "data"
+ cityscapes_scripts_path = root_path / "notebook/cityscapesScripts/cityscapesscripts"
+ images_path = data_path / "leftImg8bit"
+ masks_path = data_path / "gtFine"
+
+ class CityscapesDataset(torch.utils.data.Dataset):
+     def __init__(self, root, split="train", mode="fine", target_type="semantic", image_size=(512, 512)):
+         from torchvision.datasets import Cityscapes
+         from torchvision import transforms
+         self.dataset = Cityscapes(root=root, split=split, mode="fine", target_type="semantic")
+         self.image_size = image_size
+         self.transforms = transforms
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, index):
+         image, mask = self.dataset[index]
+         image = image.resize(self.image_size)
+         mask = mask.resize(self.image_size)
+
+         # Convert the image to a tensor
+         image = self.transforms.ToTensor()(image)
+
+         # Convert the mask to a numpy array, then apply the remapping
+         mask_np = np.array(mask).astype(np.uint8)
+         mask_remap = remap_classes(mask_np)
+
+         mask_tensor = torch.from_numpy(mask_remap).long()
+         return image, mask_tensor
+
+ def remap_classes(mask: np.ndarray) -> np.ndarray:
+     """
+     Converts the original Cityscapes classes (0-33) to the 8 main categories defined above.
+     Returns a mask containing only values from 0 to 7.
+     """
+
+     # Clean up unexpected values (e.g. 34, 35)
+     mask = np.where(mask > 33, 0, mask)  # any value > 33 becomes void (class 0)
+
+     # Exact mapping based on the original Cityscapes "labelIds"
+     labelIds_to_main_classes = {
+         0: 0,   # unlabeled → void
+         1: 0,   # ego vehicle → void
+         2: 0,   # rectification border → void
+         3: 0,   # out of roi → void
+         4: 0,   # static → void
+         5: 0,   # dynamic → void
+         6: 0,   # ground → void
+         7: 1,   # road → flat
+         8: 1,   # sidewalk → flat
+         9: 0,   # parking → void
+         10: 0,  # rail track → void
+         11: 2,  # building → construction
+         12: 2,  # wall → construction
+         13: 2,  # fence → construction
+         14: 0,  # guard rail → void
+         15: 0,  # bridge → void
+         16: 0,  # tunnel → void
+         17: 3,  # pole → object
+         18: 3,  # polegroup → object
+         19: 3,  # traffic light → object
+         20: 3,  # traffic sign → object
+         21: 4,  # vegetation → nature
+         22: 4,  # terrain → nature
+         23: 5,  # sky → sky
+         24: 6,  # person → human
+         25: 6,  # rider → human
+         26: 7,  # car → vehicle
+         27: 7,  # truck → vehicle
+         28: 7,  # bus → vehicle
+         29: 7,  # caravan → vehicle
+         30: 7,  # trailer → vehicle
+         31: 7,  # train → vehicle
+         32: 7   # motorcycle → vehicle
+         33: 7   # bicycle → vehicle
+     }
+
+     remapped_mask = np.copy(mask)
+     for original_class, new_class in labelIds_to_main_classes.items():
+         remapped_mask[mask == original_class] = new_class
+
+     return remapped_mask.astype(np.uint8)
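+
+ # Quick sanity check (illustrative values):
+ #   >>> remap_classes(np.array([[7, 26, 35]], dtype=np.uint8))
+ #   array([[1, 7, 0]], dtype=uint8)   # road → flat, car → vehicle, 35 (out of range) → void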
+
+
+ def view_folder(dossier):
+     dossier = Path(dossier)
+     if not dossier.exists():
+         print(f"❌ The folder {dossier} does not exist.")
+         return
+     for sous_dossier in dossier.iterdir():
+         if sous_dossier.is_dir():
+             print(f"|-- {sous_dossier.name}")
+             for sous_sous_dossier in sous_dossier.iterdir():
+                 if sous_sous_dossier.is_dir():
+                     print(f"    |-- {sous_sous_dossier.name}")
+
+ def load_image(path: str, target_size: Tuple[int, int]) -> np.ndarray:
+     """Loads an image and normalizes it to [0, 1]."""
+     img = load_img(path, target_size=target_size)
+     return img_to_array(img).astype("float32") / 255.0
+
+ def load_mask(path: str, target_size: Tuple[int, int], mask_mode="labelIds") -> np.ndarray:
+     """
+     Loads, resizes, and remaps a mask.
+     Always applies the remapping to the 8 main classes.
+
+     Args:
+         path (str): Path to the mask.
+         target_size (Tuple[int, int]): Output size (height, width).
+         mask_mode (str): "labelIds" for the original Cityscapes masks, "trainIds" otherwise.
+
+     Returns:
+         np.ndarray: Mask with class values between 0 and 7.
+     """
+     mask = load_img(path, target_size=target_size, color_mode="grayscale")
+     mask = img_to_array(mask).astype("uint8").squeeze()
+
+     # Always apply the remapping to guarantee 8 classes
+     mask = remap_classes(mask)
+
+     return mask
+
+ def one_hot_encode_mask(mask: np.ndarray, num_classes: int) -> np.ndarray:
+     """One-hot encodes a mask."""
+
+     # Check the unique values before encoding
+     unique_values = np.unique(mask)
+     if np.any(unique_values >= num_classes):
+         print(f"Warning: some mask values exceed {num_classes-1}: {unique_values}")
+         mask = np.clip(mask, 0, num_classes - 1)
+
+     return np.eye(num_classes, dtype=np.uint8)[mask]
+
+ def decode_mask(mask: np.ndarray) -> np.ndarray:
+     """Converts a one-hot mask back to index format."""
+     return np.argmax(mask, axis=-1)
+
+ def get_augmentations(image_size: Tuple[int, int]) -> Compose:
+     """Defines the Albumentations transforms for training."""
+     return Compose([
+         HorizontalFlip(p=0.2),
+         Rotate(limit=15, p=0.2),
+         RandomScale(scale_limit=0.1, p=0.2),
+         Resize(*image_size, interpolation=cv2.INTER_NEAREST)
+     ])
+
+ class DataGenerator(Sequence):
+     def __init__(self, image_paths, mask_paths, image_size=(256, 256), batch_size=16, num_classes=8,  # TESTED with 512x512, 1024x1024, 512x1024, 1024x512, 256x512 and 512x256
+                  shuffle=True, augmentation_ratio=1.0, use_cache=False):
+         self.image_paths = image_paths
+         self.mask_paths = mask_paths
+         self.image_size = image_size
+         self.batch_size = batch_size
+         self.num_classes = num_classes
+         self.shuffle = shuffle
+         self.augmentation_ratio = augmentation_ratio
+         self.use_cache = use_cache
+         self.cache = {}  # Cache of transformed masks
+         self.augmentation = get_augmentations(image_size)
+         self.on_epoch_end()
+
+     def __getitem__(self, index):
+         start_time = time.time()
+         start = index * self.batch_size
+         end = start + self.batch_size
+         batch_image_paths = self.image_paths[start:end]
+         batch_mask_paths = self.mask_paths[start:end]
+
+         batch_images, batch_masks = [], []
+
+         for img_path, mask_path in zip(batch_image_paths, batch_mask_paths):
+             img = load_image(img_path, self.image_size)
+
+             if self.use_cache and mask_path in self.cache:
+                 mask = self.cache[mask_path]
+             else:
+                 mask = load_mask(mask_path, self.image_size, mask_mode="trainIds")
+                 if self.use_cache:
+                     self.cache[mask_path] = mask
+
+             if np.random.rand() < self.augmentation_ratio:
+                 augmented = self.augmentation(image=img, mask=mask)
+                 img, mask = augmented["image"], augmented["mask"]
+
+             batch_images.append(img)
+             batch_masks.append(one_hot_encode_mask(mask, self.num_classes))
+
+         elapsed_time = time.time() - start_time
+         # print(f"📊 Batch {index} generated in {elapsed_time:.2f}s")
+
+         return np.stack(batch_images), np.stack(batch_masks)
+
+     def __len__(self):
+         """Returns the total number of batches per epoch."""
+         return int(np.ceil(len(self.image_paths) / self.batch_size))
+
+     def on_epoch_end(self) -> None:
+         """Shuffles the data after each epoch if shuffle is enabled."""
+         if self.shuffle:
+             data = list(zip(self.image_paths, self.mask_paths))
+             np.random.shuffle(data)
+             self.image_paths, self.mask_paths = zip(*data)
+
+     def visualize_batch(self, num_images: int = 5) -> None:
+         """Displays a batch of images and masks."""
+         batch_images, batch_masks = self.__getitem__(0)
+         num_images = min(num_images, len(batch_images))
+         fig, axes = plt.subplots(num_images, 2, figsize=(10, num_images * 5))
+
+         for i in range(num_images):
+             axes[i, 0].imshow(batch_images[i])
+             axes[i, 0].set_title("Image")
+             axes[i, 0].axis("off")
+
+             axes[i, 1].imshow(decode_mask(batch_masks[i]), cmap="inferno")
+             axes[i, 1].set_title("Mask (decoded)")
+             axes[i, 1].axis("off")
+
+         plt.tight_layout()
+         plt.show()
+
+
+ # DataGenerator test
+ # (assumes train_input_img_paths and train_label_ids_img_paths are defined, e.g. in the notebook)
+ if __name__ == "__main__":
+     train_gen = DataGenerator(
+         image_paths=train_input_img_paths,
+         mask_paths=train_label_ids_img_paths,
+         image_size=(256, 256),  # TESTED with 512x512
+         batch_size=16,  # TESTED: 8, 16 or 32
+         num_classes=8,
+         shuffle=True,
+         augmentation_ratio=0.5
+     )
+
+     train_gen.visualize_batch(num_images=3)
+
+
+ def iou_coef(y_true, y_pred, smooth=1e-6):
+     """
+     Computes the Intersection over Union (IoU).
+     Fix: explicit conversion to float32.
+     """
+     y_true = tf.keras.backend.cast(y_true, "float32")
+     y_pred = tf.keras.backend.cast(y_pred, "float32")
+     y_true_f = tf.keras.backend.flatten(y_true)
+     y_pred_f = tf.keras.backend.flatten(y_pred)
+     intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
+     union = tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) - intersection
+     return (intersection + smooth) / (union + smooth)
+
+
+ def get_logger(nom_modele: str):
+     """
+     Creates a CSVLogger that records training metrics in a timestamped file.
+     """
+     from datetime import datetime
+     from tensorflow.keras.callbacks import CSVLogger
+
+     # NOTE: RESULTS_DIR comes from the config import commented out at the top of this file
+     RESULTS_DIR.mkdir(parents=True, exist_ok=True)
+
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     log_filename = RESULTS_DIR / f"{nom_modele}_{timestamp}.csv"
+
+     return CSVLogger(log_filename, separator=",", append=False)
+
+ def charger_metriques(dossier_logs):
+     """
+     Loads every metrics CSV file found in a folder.
+
+     Args:
+         dossier_logs (str): Path to the folder containing the CSV files.
+
+     Returns:
+         dict: Dictionary with the model name as key and a dataframe as value.
+     """
+     fichiers = glob.glob(os.path.join(dossier_logs, "*.csv"))
+     resultats = {}
+
+     for fichier in fichiers:
+         # Recover the full model name (e.g. unet_mini, unet_vgg16)
+         nom_modele = "_".join(os.path.basename(fichier).split("_")[:-2])
+         df = pd.read_csv(fichier)
+         resultats[nom_modele] = df
+
+     return resultats
+
+ def tracer_metriques(resultats):
+     """
+     Plots the metrics of the different models.
+
+     Args:
+         resultats (dict): Dictionary mapping model name to dataframe.
+     """
+
+     # Specific color palette for each model
+     couleurs = {
+         "mini": "blue",
+         "vgg16": "green",
+         "resnet50": "red",
+         "efficientnetb3": "purple"
+     }
+
+     plt.figure(figsize=(18, 18))
+
+     # Loss plot
+     plt.subplot(3, 2, 1)
+     for modele, df in resultats.items():
+         couleur = couleurs.get(modele, "black")
+         plt.plot(df["loss"], label=f"{modele} Train Loss", color=couleur, linestyle="--")
+         plt.plot(df["val_loss"], label=f"{modele} Val Loss", color=couleur, linestyle="-")
+     plt.title("Loss comparison")
+     plt.xlabel("Epochs")
+     plt.ylabel("Loss")
+     plt.grid(True)
+     plt.legend()
+
+     # Mean IoU plot
+     plt.subplot(3, 2, 2)
+     for modele, df in resultats.items():
+         couleur = couleurs.get(modele, "black")
+         if "mean_iou" in df.columns:
+             plt.plot(df["mean_iou"], label=f"{modele} Train Mean IoU", color=couleur, linestyle="--")
+             plt.plot(df["val_mean_iou"], label=f"{modele} Val Mean IoU", color=couleur, linestyle="-")
+         elif "iou_score" in df.columns:
+             plt.plot(df["iou_score"], label=f"{modele} Train IoU Score", color=couleur, linestyle="--")
+             plt.plot(df["val_iou_score"], label=f"{modele} Val IoU Score", color=couleur, linestyle="-")
+     plt.title("Mean IoU / IoU Score comparison")
+     plt.xlabel("Epochs")
+     plt.ylabel("Mean IoU")
+     plt.grid(True)
+     plt.legend()
+
+     # Dice Coefficient plot
+     plt.subplot(3, 2, 3)
+     for modele, df in resultats.items():
+         couleur = couleurs.get(modele, "black")
+         if "dice_coef" in df.columns:
+             plt.plot(df["dice_coef"], label=f"{modele} Train Dice", color=couleur, linestyle="--")
+             plt.plot(df["val_dice_coef"], label=f"{modele} Val Dice", color=couleur, linestyle="-")
+     plt.title("Dice Coefficient comparison")
+     plt.xlabel("Epochs")
+     plt.ylabel("Dice Coefficient")
+     plt.grid(True)
+     plt.legend()
+
+     # Accuracy plot
+     plt.subplot(3, 2, 4)
+     for modele, df in resultats.items():
+         couleur = couleurs.get(modele, "black")
+         if "accuracy" in df.columns:
+             plt.plot(df["accuracy"], label=f"{modele} Train Accuracy", color=couleur, linestyle="--")
+             plt.plot(df["val_accuracy"], label=f"{modele} Val Accuracy", color=couleur, linestyle="-")
+     plt.title("Accuracy comparison")
+     plt.xlabel("Epochs")
+     plt.ylabel("Accuracy")
+     plt.grid(True)
+     plt.legend()
+
+     # Training-time-per-model plot
+     plt.subplot(3, 1, 3)
+     temps_entrainement = {}
+     for modele, df in resultats.items():
+         couleur = couleurs.get(modele, "black")
+         if "temps_total_sec" in df.columns:
+             temps = df["temps_total_sec"].iloc[-1] / 60  # converted to minutes
+             temps_entrainement[modele] = temps
+             plt.bar(modele, temps, color=couleur)
+             plt.text(modele, temps, f"{temps:.2f} min", ha="center", va="bottom")
+
+     plt.title("Total training time comparison (minutes)")
+     plt.ylabel("Time (minutes)")
+     plt.grid(True, axis="y")
+
+     plt.tight_layout()
+     plt.show()
+
+ # -------------------- NEW FUNCTIONS FOR PROJECT 9 --------------------
+
+ def charger_oneformer(num_classes: int = 8):
+     """
+     Loads the OneFormer model, adapted to the Cityscapes dataset.
+     """
+     # OneFormer is exposed as OneFormerForUniversalSegmentation in transformers,
+     # with checkpoints published under shi-labs
+     from transformers import OneFormerForUniversalSegmentation
+     model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_coco_swin_large")
+     model.config.num_labels = num_classes
+     return model
+
+
+ def charger_segnext(num_classes: int = 8):
+     """
+     Loads a SegNeXt-L model (simplified, via timm or another wrapper).
+     """
+     import timm
+     # NOTE: "segnext_l" must be provided by the installed timm version or a custom wrapper
+     model = timm.create_model("segnext_l", pretrained=True, num_classes=num_classes)
+     return model
+
+ def entrainer_model_pytorch(
+     model,
+     train_loader,
+     val_loader,
+     model_name="model",
+     epochs=10,
+     lr=1e-4,
+     num_classes=8
+ ):
+     """
+     Trains a PyTorch segmentation model with:
+     - Mixed precision (torch.cuda.amp)
+     - GradScaler for stability
+     - A 'ReduceLROnPlateau' scheduler
+     - Output handling for SegFormer (SemanticSegmenterOutput)
+       or a plain tensor
+     - Upsampling of the output to match the mask size (H, W)
+     - Computation and logging of metrics (accuracy, Dice, IoU) for train and val
+     - Per-epoch time and peak GPU memory measurement
+     - CSV + .pth saved to '../resultats_modeles/'
+     - A PNG plot of the Dice and Mean IoU curves.
+     """
+
+     import torch
+     import torch.nn as nn
+     import torch.optim as optim
+     import torch.optim.lr_scheduler as lr_sched
+     from torch.cuda.amp import autocast, GradScaler
+     from transformers.modeling_outputs import SemanticSegmenterOutput
+     from tqdm import tqdm
+     import pandas as pd
+     import matplotlib.pyplot as plt
+     import os
+     import time
+     import torch.nn.functional as F
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     # -------- Local definition of the PyTorch metrics (avoids duplication) --------
+     def compute_batch_metrics(pred_logits, target, num_classes):
+         """
+         Computes mean (macro) accuracy, Dice, and IoU for a batch.
+         - pred_logits: (N, C, H, W)
+         - target: (N, H, W) (integer values in [0..num_classes-1])
+         Returns a dict: {"accuracy": float, "dice": float, "iou": float}
+         """
+         # 1) argmax => (N, H, W)
+         pred = torch.argmax(pred_logits, dim=1)
+
+         # 2) Global accuracy (all pixels)
+         correct = (pred == target).sum().item()
+         total = target.numel()  # N*H*W
+         accuracy = correct / total
+
+         # 3) Per-class intersection / union => Dice, IoU
+         dice_list = []
+         iou_list = []
+
+         for c in range(num_classes):
+             pred_c = (pred == c)
+             target_c = (target == c)
+
+             inter = (pred_c & target_c).sum().item()
+             pred_area = pred_c.sum().item()
+             target_area = target_c.sum().item()
+             union = pred_area + target_area - inter
+
+             # IoU
+             if union == 0:
+                 # class absent from both => convention IoU = 1
+                 iou_c = 1.0
+             else:
+                 iou_c = inter / union
+
+             # Dice = 2*inter / (|pred_c| + |target_c|)
+             denom = pred_area + target_area
+             if denom == 0:
+                 dice_c = 1.0
+             else:
+                 dice_c = 2.0 * inter / denom
+
+             dice_list.append(dice_c)
+             iou_list.append(iou_c)
+
+         mean_dice = sum(dice_list) / len(dice_list)
+         mean_iou = sum(iou_list) / len(iou_list)
+
+         return {"accuracy": accuracy, "dice": mean_dice, "iou": mean_iou}
+
+     # -------- Optimizer / Loss / Scheduler / GradScaler setup --------
+     criterion = nn.CrossEntropyLoss()
+     optimizer = optim.Adam(model.parameters(), lr=lr)
+     scheduler = lr_sched.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)
+     scaler = GradScaler()
+
+     os.makedirs("../resultats_modeles", exist_ok=True)
+
+     # -------- Log structure --------
+     log = {
+         "epoch": [],
+         "train_loss": [],
+         "val_loss": [],
+         "train_accuracy": [],
+         "train_dice_coef": [],
+         "train_mean_iou": [],
+         "val_accuracy": [],
+         "val_dice_coef": [],
+         "val_mean_iou": [],
+         "epoch_time_s": [],
+         "peak_gpu_mem_mb": []
+     }
+
+     start_time = time.time()
+
+     # ============================ TRAINING LOOP ============================
+     for epoch in range(epochs):
+         # Track the peak GPU memory over the epoch
+         torch.cuda.reset_peak_memory_stats(device=device)
+         epoch_start = time.time()
+
+         # -------- TRAIN LOOP --------
+         model.train()
+         running_loss = 0.0
+         running_accuracy = 0.0
+         running_dice = 0.0
+         running_iou = 0.0
+
+         for images, masks in tqdm(train_loader, desc=f"[Epoch {epoch+1}/{epochs}] Train"):
+             images, masks = images.to(device), masks.to(device)
+             optimizer.zero_grad()
+
+             with autocast():
+                 outdict = model(images)
+                 # Handle SegFormer / DeepLab / plain tensor outputs
+                 if isinstance(outdict, SemanticSegmenterOutput):
+                     logits = outdict.logits
+                 elif isinstance(outdict, dict):
+                     logits = outdict["out"]
+                 else:
+                     logits = outdict
+
+                 # Upsample -> (N, C, H, W) = size of masks
+                 logits = F.interpolate(
+                     logits,
+                     size=(masks.shape[-2], masks.shape[-1]),
+                     mode='bilinear',
+                     align_corners=False
+                 )
+
+                 loss = criterion(logits, masks)
+
+             scaler.scale(loss).backward()
+             scaler.step(optimizer)
+             scaler.update()
+
+             running_loss += loss.item()
+
+             # Metrics for this batch
+             metrics_batch = compute_batch_metrics(logits, masks, num_classes=num_classes)
+             running_accuracy += metrics_batch["accuracy"]
+             running_dice += metrics_batch["dice"]
+             running_iou += metrics_batch["iou"]
+
+         avg_train_loss = running_loss / len(train_loader)
+         avg_train_accuracy = running_accuracy / len(train_loader)
+         avg_train_dice = running_dice / len(train_loader)
+         avg_train_iou = running_iou / len(train_loader)
+
+         # -------- VALID LOOP --------
+         model.eval()
+         val_running_loss = 0.0
+         val_running_accuracy = 0.0
+         val_running_dice = 0.0
+         val_running_iou = 0.0
+
+         with torch.no_grad():
+             for images, masks in tqdm(val_loader, desc=f"[Epoch {epoch+1}/{epochs}] Val"):
+                 images, masks = images.to(device), masks.to(device)
+                 with autocast():
+                     outdict = model(images)
+                     if isinstance(outdict, SemanticSegmenterOutput):
+                         logits = outdict.logits
+                     elif isinstance(outdict, dict):
+                         logits = outdict["out"]
+                     else:
+                         logits = outdict
+
+                     logits = F.interpolate(
+                         logits,
+                         size=(masks.shape[-2], masks.shape[-1]),
+                         mode='bilinear',
+                         align_corners=False
+                     )
+
+                     loss_val = criterion(logits, masks)
+
+                 val_running_loss += loss_val.item()
+
+                 metrics_batch_val = compute_batch_metrics(logits, masks, num_classes=num_classes)
+                 val_running_accuracy += metrics_batch_val["accuracy"]
+                 val_running_dice += metrics_batch_val["dice"]
+                 val_running_iou += metrics_batch_val["iou"]
+
+         avg_val_loss = val_running_loss / len(val_loader)
+         avg_val_accuracy = val_running_accuracy / len(val_loader)
+         avg_val_dice = val_running_dice / len(val_loader)
+         avg_val_iou = val_running_iou / len(val_loader)
+
+         # -------- Scheduler: ReduceLROnPlateau --------
+         scheduler.step(avg_val_loss)
+
+         # -------- End-of-epoch logging --------
+         epoch_time = time.time() - epoch_start
+         peak_mem = torch.cuda.max_memory_allocated(device=device)
+         peak_mem_mb = peak_mem / (1024 ** 2)
+
+         log["epoch"].append(epoch + 1)
+         log["train_loss"].append(avg_train_loss)
+         log["val_loss"].append(avg_val_loss)
+         log["train_accuracy"].append(avg_train_accuracy)
+         log["train_dice_coef"].append(avg_train_dice)
+         log["train_mean_iou"].append(avg_train_iou)
+         log["val_accuracy"].append(avg_val_accuracy)
+         log["val_dice_coef"].append(avg_val_dice)
+         log["val_mean_iou"].append(avg_val_iou)
+         log["epoch_time_s"].append(epoch_time)
+         log["peak_gpu_mem_mb"].append(peak_mem_mb)
+
+         print(
+             f"📉 Epoch {epoch+1} | "
+             f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | "
+             f"Train Dice: {avg_train_dice:.4f} | Val Dice: {avg_val_dice:.4f} | "
+             f"Train IoU: {avg_train_iou:.4f} | Val IoU: {avg_val_iou:.4f} | "
+             f"Time: {epoch_time:.1f}s | GPU: {peak_mem_mb:.1f} MB"
+         )
+
+     # ============================ END OF TRAINING ============================
+     total_time = time.time() - start_time
+
+     # -------- Save the log as CSV --------
+     df = pd.DataFrame(log)
+     df["temps_total_sec"] = total_time
+     os.makedirs("../resultats_modeles", exist_ok=True)
+     csv_path = f"../resultats_modeles/{model_name}_log.csv"
+     df.to_csv(csv_path, index=False)
+
+     # -------- Save the weights --------
+     torch.save(model.state_dict(), f"../resultats_modeles/{model_name}.pth")
+
+     # -------- Generate and save a Dice/IoU plot --------
+     plt.figure(figsize=(12, 5))
+
+     # Subplot 1: Dice
+     plt.subplot(1, 2, 1)
+     plt.plot(df["epoch"], df["train_dice_coef"], label="Train Dice", color="blue")
+     plt.plot(df["epoch"], df["val_dice_coef"], label="Val Dice", color="orange")
+     plt.title("Dice Coefficient")
+     plt.xlabel("Epoch")
+     plt.ylabel("Dice")
+     plt.legend()
+     plt.grid(True)
+
+     # Subplot 2: IoU
+     plt.subplot(1, 2, 2)
+     plt.plot(df["epoch"], df["train_mean_iou"], label="Train IoU", color="blue")
+     plt.plot(df["epoch"], df["val_mean_iou"], label="Val IoU", color="orange")
+     plt.title("Mean IoU")
+     plt.xlabel("Epoch")
+     plt.ylabel("IoU")
+     plt.legend()
+     plt.grid(True)
+
+     plt.tight_layout()
+     png_path = f"../resultats_modeles/{model_name}_dice_iou.png"
+     plt.savefig(png_path, dpi=100)
+     plt.close()
+
+     print(f"✅ Training {model_name} finished in {total_time:.1f} seconds.")
+     print(f"📁 Logs: {csv_path}")
+     print(f"📁 Model: ../resultats_modeles/{model_name}.pth")
+     print(f"📊 Dice/IoU plot saved: {png_path}")
+
+ def comparer_resultats(dossier='../resultats_modeles'):
+     """
+     Displays the learning curves of every trained model.
+     """
+     import matplotlib.pyplot as plt
+     import pandas as pd
+     import os
+
+     plt.figure(figsize=(10, 6))
+     for file in os.listdir(dossier):
+         if file.endswith("_log.csv"):
+             df = pd.read_csv(os.path.join(dossier, file))
+             nom = file.replace("_log.csv", "")
+             plt.plot(df["epoch"], df["train_loss"], label=f"{nom} train")
+             plt.plot(df["epoch"], df["val_loss"], label=f"{nom} val")
+     plt.title("Learning curves")
+     plt.xlabel("Epoch")
+     plt.ylabel("Loss")
+     plt.legend()
+     plt.grid(True)
+     plt.tight_layout()
+     plt.show()
+
+ # ---------------------- FUNCTIONS REWRITTEN FOR PROJECT 9 --------------------
+
+ def charger_donnees_cityscapes(data_dir: str, batch_size: int = 16, image_size: Tuple[int, int] = (256, 256)):
+     """
+     Loads the Cityscapes data and returns two DataLoaders (train and val).
+     Uses CityscapesDataset with pin_memory=True for faster host-to-GPU
+     transfers (num_workers is kept at 0 for portability).
+     """
+     from torch.utils.data import DataLoader
+
+     train_dataset = CityscapesDataset(root=data_dir, split="train", image_size=image_size)
+     val_dataset = CityscapesDataset(root=data_dir, split="val", image_size=image_size)
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=0,
+         pin_memory=True
+     )
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=0,
+         pin_memory=True
+     )
+
+     return train_loader, val_loader
+
+ import matplotlib.patches as mpatches
+
+ # Soft color palette (8 useful classes)
+ PALETTE = {
+     0: (0, 0, 0),        # void → black
+     1: (50, 50, 150),    # flat → dark blue
+     2: (102, 0, 204),    # construction → purple
+     3: (255, 85, 0),     # object → orange
+     4: (255, 255, 0),    # nature → yellow
+     5: (0, 255, 255),    # sky → cyan
+     6: (255, 0, 255),    # human → magenta
+     7: (255, 255, 255),  # vehicle → white
+ }
+
+ CLASS_NAMES = {
+     0: "void",
+     1: "flat",
+     2: "construction",
+     3: "object",
+     4: "nature",
+     5: "sky",
+     6: "human",
+     7: "vehicle"
+ }
+
+ def decode_cityscapes_mask(mask):
+     """
+     Converts a 2D mask (values 0 to 7) into an RGB image for display.
+     """
+     h, w = mask.shape
+     mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
+     for class_id, color in PALETTE.items():
+         mask_rgb[mask == class_id] = color
+     return mask_rgb
+
+ def afficher_image_et_masque(image_tensor, mask_tensor):
+     import matplotlib.pyplot as plt
+     from matplotlib.colors import ListedColormap
+     import numpy as np
+
+     PALETTE = [
+         (0, 0, 0),        # 0 - void
+         (100, 0, 200),    # 1 - flat
+         (70, 70, 70),     # 2 - construction
+         (250, 170, 30),   # 3 - object
+         (107, 142, 35),   # 4 - nature
+         (70, 130, 180),   # 5 - sky
+         (220, 20, 60),    # 6 - human
+         (0, 0, 142),      # 7 - vehicle
+     ]
+     PALETTE_NP = np.array(PALETTE) / 255.0
+     cmap = ListedColormap(PALETTE_NP)
+
+     image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
+     mask_np = mask_tensor.cpu().numpy()
+
+     plt.figure(figsize=(12, 5))
+
+     plt.subplot(1, 2, 1)
+     plt.imshow(image_np)
+     plt.title("Image")
+     plt.axis("off")
+
+     plt.subplot(1, 2, 2)
+     im = plt.imshow(mask_np, cmap=cmap, vmin=0, vmax=7)
+     cbar = plt.colorbar(im, ticks=range(8))
+     cbar.ax.set_yticklabels(['void', 'flat', 'construction', 'object', 'nature', 'sky', 'human', 'vehicle'])
+     cbar.set_label("Categories", rotation=270, labelpad=15)
+     plt.title("Mask (8 colorized classes)")
+     plt.axis("off")
+
+     plt.tight_layout()
+     plt.show()
+
+ def charger_segformer(num_classes=8):
+     from transformers import SegformerForSemanticSegmentation
+
+     model = SegformerForSemanticSegmentation.from_pretrained(
+         "nvidia/segformer-b5-finetuned-ade-640-640",
+         num_labels=num_classes,
+         ignore_mismatched_sizes=True
+     )
+     model.config.num_labels = num_classes
+     model.config.output_hidden_states = False
+     return model
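+
+ # Example (mirrors app.py): load the fine-tuned weights shipped in this commit
+ #   model = charger_segformer(num_classes=8)
+ #   model.load_state_dict(torch.load("segformer_b5.pth", map_location="cpu"))
+ #   model.eval()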
+
+ def charger_deeplabv3plus(num_classes=8):
+     import torchvision.models.segmentation as models
+     import torch.nn as nn
+
+     model = models.deeplabv3_resnet101(pretrained=True)
+     model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)
+     return model
+
+ class MiniCityscapesDataset(torch.utils.data.Dataset):
+     def __init__(self, image_paths, mask_paths, image_size=(256, 256)):
+         self.image_paths = image_paths
+         self.mask_paths = mask_paths
+         self.image_size = image_size
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         # Image and mask paths
+         image_path = self.image_paths[idx]
+         mask_path = self.mask_paths[idx]
+
+         # Load the image
+         from PIL import Image
+         image = Image.open(image_path).convert("RGB").resize(self.image_size)
+
+         # Load the mask
+         mask = Image.open(mask_path).convert("L").resize(self.image_size)
+
+         # Convert to a PyTorch tensor
+         import torchvision.transforms as T
+         to_tensor = T.ToTensor()
+         image = to_tensor(image)  # shape (3, H, W)
+
+         # Numpy + class remapping
+         import numpy as np
+         mask_np = np.array(mask, dtype=np.uint8)
+
+         # Remap
+         mask_np = remap_classes(mask_np)
+         mask_tensor = torch.from_numpy(mask_np).long()  # shape (H, W)
+
+         return image, mask_tensor
+
+ def show_predictions(model, dataset, num_images=3, num_classes=8):
+     """
+     Displays a few predictions vs ground-truth masks from a PyTorch dataset.
+     Handles upsampling, SegFormer / DeepLab / etc.
+     """
+     import torch
+     import matplotlib.pyplot as plt
+     from transformers.modeling_outputs import SemanticSegmenterOutput
+     import torch.nn.functional as F
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.eval().to(device)
+
+     fig, axes = plt.subplots(num_images, 3, figsize=(12, 4 * num_images))
+
+     for i in range(num_images):
+         # Pick a random index
+         idx = np.random.randint(0, len(dataset))
+         image, mask_gt = dataset[idx]  # (3, H, W), (H, W)
+
+         image_t = image.unsqueeze(0).to(device)  # (1, 3, H, W)
+         mask_gt_np = mask_gt.numpy()  # (H, W)
+
+         with torch.no_grad():
+             outdict = model(image_t)
+             if isinstance(outdict, SemanticSegmenterOutput):
+                 logits = outdict.logits
+             elif isinstance(outdict, dict):
+                 logits = outdict["out"]
+             else:
+                 logits = outdict
+
+             logits = F.interpolate(
+                 logits,
+                 size=mask_gt.shape,
+                 mode='bilinear',
+                 align_corners=False
+             )
+             pred = logits.argmax(dim=1).squeeze(0).cpu().numpy()  # (H, W)
+
+         # DISPLAY
+         axes[i, 0].imshow(image.permute(1, 2, 0).numpy())
+         axes[i, 0].set_title("Image")
+         axes[i, 0].axis("off")
+
+         axes[i, 1].imshow(mask_gt_np, cmap="tab10", vmin=0, vmax=num_classes-1)
+         axes[i, 1].set_title("GT mask")
+         axes[i, 1].axis("off")
+
+         axes[i, 2].imshow(pred, cmap="tab10", vmin=0, vmax=num_classes-1)
+         axes[i, 2].set_title("Predicted mask")
+         axes[i, 2].axis("off")
+
+     plt.tight_layout()
+     plt.show()
+
+ def charger_maskformer(num_classes=8):
+     """
+     Loads a MaskFormer model (Hugging Face Transformers) for segmentation.
+     Relies on a checkpoint pre-trained on ADE20K.
+     """
+     from transformers import MaskFormerForInstanceSegmentation
+
+     # Example: "facebook/maskformer-swin-large-ade" (semantic, ADE20K)
+     # or "facebook/maskformer-swin-base-coco" (panoptic/instance, COCO)
+     # Adapt to your needs.
+     checkpoint = "facebook/maskformer-swin-large-ade"
+
+     model = MaskFormerForInstanceSegmentation.from_pretrained(
+         checkpoint,
+         ignore_mismatched_sizes=True  # sometimes needed when num_labels changes
+     )
+
+     # Adjust the number of classes for Cityscapes (8)
+     model.config.num_labels = num_classes
+     # Optional: disable hidden-state outputs
+     model.config.output_hidden_states = False
+
+     return model
+
+
+ import torch
+ import torch.nn.functional as F
+
+ def maskformer_aggregator(
+     class_queries_logits: torch.Tensor,
+     masks_queries_logits: torch.Tensor
+ ) -> torch.Tensor:
+     """
+     Combines the Mask(2)Former predictions (class_queries_logits, masks_queries_logits)
+     into a tensor of shape (N, C, H, W) for semantic segmentation.
+
+     Assumptions:
+     - class_queries_logits: (N, Q, C) [per-class logits for each query]
+     - masks_queries_logits: (N, Q, H, W) [mask logits (usually read through a sigmoid)]
+
+     Naive approach:
+     1) Turn class_queries_logits into probabilities with a softmax over the class dimension (C).
+     2) Apply a sigmoid to masks_queries_logits to get p(query=1) per pixel.
+     3) Multiply each mask by its class probability, then sum over the query
+        dimension (Q) to obtain a tensor of shape (N, C, H, W).
+     4) Leave this tensor as-is (unnormalized) so that CrossEntropyLoss applies
+        its own softmax. We call it 'aggregated_logits'.
+
+     Result:
+     aggregated_logits.shape == (N, C, H, W),
+     which can be fed to F.cross_entropy(aggregated_logits, targets).
+     """
+     # 1) Softmax over the class dimension => shape (N, Q, C)
+     class_probs = F.softmax(class_queries_logits, dim=2)
+
+     # 2) Sigmoid over the pixel dimension => shape (N, Q, H, W)
+     mask_probs = torch.sigmoid(masks_queries_logits)
+
+     # 3) Multiply then sum, via Einstein summation / broadcasting:
+     # aggregated[b, c, h, w] = sum_q( class_probs[b,q,c] * mask_probs[b,q,h,w] )
+     aggregated = torch.einsum('bqc, bqhw -> bchw', class_probs, mask_probs)
+
+     # Here, aggregated is a per-class, per-pixel "score", unnormalized.
+     # CrossEntropyLoss expects a (N, C, H, W) tensor of logits and applies
+     # log_softmax internally. Since aggregated is positive it could be
+     # squashed a bit, but we leave it as-is.
+     return aggregated
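+
+ # Shape sanity check (illustrative sizes: N=2, Q=100, C=8, 64x64 masks):
+ #   class_q = torch.randn(2, 100, 8)
+ #   mask_q = torch.randn(2, 100, 64, 64)
+ #   maskformer_aggregator(class_q, mask_q).shape  # torch.Size([2, 8, 64, 64])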
1064
+
1065
+ def training_for_maskformer(
1066
+ model,
1067
+ train_loader,
1068
+ val_loader,
1069
+ model_name="maskformer",
1070
+ epochs=10,
1071
+ lr=1e-4,
1072
+ num_classes=8
1073
+ ):
1074
+ import torch
1075
+ import torch.nn as nn
1076
+ import torch.optim as optim
1077
+ import torch.optim.lr_scheduler as lr_sched
1078
+ from torch.cuda.amp import autocast, GradScaler
1079
+ from tqdm import tqdm
1080
+ import pandas as pd
1081
+ import matplotlib.pyplot as plt
1082
+ import os
1083
+ import time
1084
+ import torch.nn.functional as F
1085
+
1086
+ # On importe la fonction aggregator
1087
+ from fonctions import maskformer_aggregator
1088
+
1089
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1090
+ model.to(device)
1091
+
1092
+ # Métriques
1093
+ def compute_batch_metrics(pred_logits, target, nb_classes):
1094
+ pred = torch.argmax(pred_logits, dim=1)
1095
+ correct = (pred == target).sum().item()
1096
+ total = target.numel()
1097
+ accuracy = correct / total
1098
+
1099
+ dice_list = []
1100
+ iou_list = []
1101
+ for c in range(nb_classes):
1102
+ pred_c = (pred == c)
1103
+ target_c = (target == c)
1104
+ inter = (pred_c & target_c).sum().item()
1105
+ pred_area = pred_c.sum().item()
1106
+ target_area = target_c.sum().item()
1107
+ union = pred_area + target_area - inter
1108
+
1109
+ iou_c = 1.0 if union == 0 else inter / union
1110
+ denom = pred_area + target_area
1111
+ dice_c = 1.0 if denom == 0 else (2.0 * inter / denom)
1112
+
1113
+ dice_list.append(dice_c)
1114
+ iou_list.append(iou_c)
1115
+
1116
+ mean_dice = sum(dice_list) / len(dice_list)
1117
+ mean_iou = sum(iou_list) / len(iou_list)
1118
+ return {"accuracy": accuracy, "dice": mean_dice, "iou": mean_iou}
1119
+
1120
+ criterion = nn.CrossEntropyLoss()
1121
+ optimizer = optim.Adam(model.parameters(), lr=lr)
1122
+ scheduler = lr_sched.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)
1123
+ scaler = GradScaler()
1124
+
1125
+ os.makedirs("../resultats_modeles", exist_ok=True)
1126
+
1127
+ log = {
1128
+ "epoch": [],
1129
+ "train_loss": [],
1130
+ "val_loss": [],
1131
+ "train_accuracy": [],
1132
+ "train_dice_coef": [],
1133
+ "train_mean_iou": [],
1134
+ "val_accuracy": [],
1135
+ "val_dice_coef": [],
1136
+ "val_mean_iou": [],
1137
+ "epoch_time_s": [],
1138
+ "peak_gpu_mem_mb": []
1139
+ }
1140
+
1141
+ start_time = time.time()
1142
+
1143
+ for epoch in range(epochs):
1144
+ torch.cuda.reset_peak_memory_stats(device=device)
1145
+ epoch_start = time.time()
1146
+
1147
+ # ---------------- TRAIN ----------------
1148
+ model.train()
1149
+ running_loss = 0.0
1150
+ running_accuracy = 0.0
1151
+ running_dice = 0.0
1152
+ running_iou = 0.0
1153
+
1154
+ for images, masks in tqdm(train_loader, desc=f"[Epoch {epoch+1}/{epochs}] Train"):
1155
+ images, masks = images.to(device), masks.to(device)
1156
+ optimizer.zero_grad()
1157
+
1158
+ with autocast():
1159
+ outputs = model(images)
1160
+ # outputs est de type MaskFormerForInstanceSegmentationOutput
1161
+ class_queries = outputs.class_queries_logits # (N, Q, num_labels)
1162
+ masks_queries = outputs.masks_queries_logits # (N, Q, h, w)
1163
+
1164
+ # On upsample les masques pour correspondre à la taille des ground truth
1165
+ masks_queries = F.interpolate(
1166
+ masks_queries,
1167
+ size=(masks.shape[-2], masks.shape[-1]),
1168
+ mode='bilinear',
1169
+ align_corners=False
1170
+ )
1171
+
1172
+ # On agrège en un tenseur (N, C, H, W)
1173
+ aggregated_logits = maskformer_aggregator(class_queries, masks_queries)
1174
+
1175
+ loss = criterion(aggregated_logits, masks)
1176
+
1177
+ scaler.scale(loss).backward()
1178
+ scaler.step(optimizer)
1179
+ scaler.update()
1180
+
1181
+ running_loss += loss.item()
1182
+
1183
+ # Métriques
1184
+ metrics_batch = compute_batch_metrics(aggregated_logits, masks, num_classes)
1185
+ running_accuracy += metrics_batch["accuracy"]
1186
+ running_dice += metrics_batch["dice"]
1187
+ running_iou += metrics_batch["iou"]
1188
+
1189
+ avg_train_loss = running_loss / len(train_loader)
1190
+ avg_train_accuracy = running_accuracy / len(train_loader)
1191
+ avg_train_dice = running_dice / len(train_loader)
1192
+ avg_train_iou = running_iou / len(train_loader)
1193
+
1194
+ # ---------------- VAL ----------------
1195
+ model.eval()
1196
+ val_running_loss = 0.0
1197
+ val_running_accuracy = 0.0
1198
+ val_running_dice = 0.0
1199
+ val_running_iou = 0.0
1200
+
1201
+ with torch.no_grad():
1202
+ for images, masks in tqdm(val_loader, desc=f"[Epoch {epoch+1}/{epochs}] Val"):
1203
+ images, masks = images.to(device), masks.to(device)
1204
+
1205
+ with autocast():
1206
+ outputs = model(images)
1207
+ class_queries = outputs.class_queries_logits
1208
+ masks_queries = outputs.masks_queries_logits
1209
+
1210
+ masks_queries = F.interpolate(
1211
+ masks_queries,
1212
+ size=(masks.shape[-2], masks.shape[-1]),
1213
+ mode='bilinear',
1214
+ align_corners=False
1215
+ )
1216
+ aggregated_logits = maskformer_aggregator(class_queries, masks_queries)
1217
+
1218
+ loss_val = criterion(aggregated_logits, masks)
1219
+
1220
+ val_running_loss += loss_val.item()
1221
+ val_metrics = compute_batch_metrics(aggregated_logits, masks, num_classes)
1222
+ val_running_accuracy += val_metrics["accuracy"]
1223
+ val_running_dice += val_metrics["dice"]
1224
+ val_running_iou += val_metrics["iou"]
1225
+
1226
+ avg_val_loss = val_running_loss / len(val_loader)
1227
+ avg_val_accuracy = val_running_accuracy / len(val_loader)
1228
+ avg_val_dice = val_running_dice / len(val_loader)
1229
+ avg_val_iou = val_running_iou / len(val_loader)
1230
+
1231
+ scheduler.step(avg_val_loss)
1232
+
1233
+ epoch_time = time.time() - epoch_start
1234
+ peak_mem = torch.cuda.max_memory_allocated(device=device) / (1024 ** 2)
1235
+
1236
+ log["epoch"].append(epoch + 1)
1237
+ log["train_loss"].append(avg_train_loss)
1238
+ log["val_loss"].append(avg_val_loss)
1239
+ log["train_accuracy"].append(avg_train_accuracy)
1240
+ log["train_dice_coef"].append(avg_train_dice)
1241
+ log["train_mean_iou"].append(avg_train_iou)
1242
+ log["val_accuracy"].append(avg_val_accuracy)
1243
+ log["val_dice_coef"].append(avg_val_dice)
1244
+ log["val_mean_iou"].append(avg_val_iou)
1245
+ log["epoch_time_s"].append(epoch_time)
1246
+ log["peak_gpu_mem_mb"].append(peak_mem)
1247
+
1248
+ print(
1249
+ f"Epoch {epoch+1} | "
1250
+ f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | "
1251
+ f"Train Dice: {avg_train_dice:.4f} | Val Dice: {avg_val_dice:.4f} | "
1252
+ f"Train IoU: {avg_train_iou:.4f} | Val IoU: {avg_val_iou:.4f} | "
1253
+ f"Time: {epoch_time:.1f}s | GPU: {peak_mem:.1f} MB"
1254
+ )
1255
+
1256
+ total_time = time.time() - start_time
1257
+ df = pd.DataFrame(log)
1258
+ df["temps_total_sec"] = total_time
1259
+ csv_path = f"../resultats_modeles/{model_name}_log.csv"
1260
+ df.to_csv(csv_path, index=False)
1261
+
1262
+ # Sauvegarde du modèle
1263
+ torch.save(model.state_dict(), f"../resultats_modeles/{model_name}.pth")
1264
+
1265
+ # Génération d’un graphique Dice/IoU
1266
+ plt.figure(figsize=(12, 5))
1267
+
1268
+ # Plot Dice
1269
+ plt.subplot(1, 2, 1)
1270
+ plt.plot(df["epoch"], df["train_dice_coef"], label="Train Dice", color="blue")
1271
+ plt.plot(df["epoch"], df["val_dice_coef"], label="Val Dice", color="orange")
1272
+ plt.title("Dice Coefficient")
1273
+ plt.xlabel("Epoch")
1274
+ plt.ylabel("Dice")
1275
+ plt.legend()
1276
+ plt.grid(True)
1277
+
1278
+ # Plot IoU
1279
+ plt.subplot(1, 2, 2)
1280
+ plt.plot(df["epoch"], df["train_mean_iou"], label="Train IoU", color="blue")
1281
+ plt.plot(df["epoch"], df["val_mean_iou"], label="Val IoU", color="orange")
1282
+ plt.title("Mean IoU")
1283
+ plt.xlabel("Epoch")
1284
+ plt.ylabel("IoU")
1285
+ plt.legend()
1286
+ plt.grid(True)
1287
+
1288
+ plt.tight_layout()
1289
+ png_path = f"../resultats_modeles/{model_name}_dice_iou.png"
1290
+ plt.savefig(png_path, dpi=100)
1291
+ plt.close()
1292
+
1293
+ print(f"✅ Entraînement {model_name} terminé en {total_time:.1f} secondes.")
1294
+ print(f"📁 Logs : {csv_path}")
1295
+ print(f"📁 Modèle : ../resultats_modeles/{model_name}.pth")
1296
+ print(f"📊 Graphique Dice/IoU sauvegardé : {png_path}")
1297
+
1298
+ def training_for_mask2former(
1299
+ model,
1300
+ train_loader,
1301
+ val_loader,
1302
+ model_name="mask2former",
1303
+ epochs=10,
1304
+ lr=1e-4,
1305
+ num_classes=8
1306
+ ):
1307
+ import torch
1308
+ import torch.nn as nn
1309
+ import torch.optim as optim
1310
+ import torch.optim.lr_scheduler as lr_sched
1311
+ from torch.cuda.amp import autocast, GradScaler
1312
+ from tqdm import tqdm
1313
+ import pandas as pd
1314
+ import matplotlib.pyplot as plt
1315
+ import os
1316
+ import time
1317
+ import torch.nn.functional as F
1318
+
1319
+ from fonctions import maskformer_aggregator
1320
+
1321
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1322
+ model.to(device)
1323
+
1324
+ def compute_batch_metrics(pred_logits, target, nb_classes):
1325
+ pred = torch.argmax(pred_logits, dim=1)
1326
+ correct = (pred == target).sum().item()
1327
+ total = target.numel()
1328
+ accuracy = correct / total
1329
+
1330
+ dice_list = []
1331
+ iou_list = []
1332
+ for c in range(nb_classes):
1333
+ pred_c = (pred == c)
1334
+ target_c = (target == c)
1335
+ inter = (pred_c & target_c).sum().item()
1336
+ pred_area = pred_c.sum().item()
1337
+ target_area = target_c.sum().item()
1338
+ union = pred_area + target_area - inter
1339
+
1340
+ iou_c = 1.0 if union == 0 else inter / union
1341
+ denom = pred_area + target_area
1342
+ dice_c = 1.0 if denom == 0 else (2.0 * inter / denom)
1343
+
1344
+ dice_list.append(dice_c)
1345
+ iou_list.append(iou_c)
1346
+
1347
+ mean_dice = sum(dice_list) / len(dice_list)
1348
+ mean_iou = sum(iou_list) / len(iou_list)
1349
+ return {"accuracy": accuracy, "dice": mean_dice, "iou": mean_iou}
1350
+
1351
+ criterion = nn.CrossEntropyLoss()
1352
+ optimizer = optim.Adam(model.parameters(), lr=lr)
1353
+ scheduler = lr_sched.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)
1354
+ scaler = GradScaler()
1355
+
1356
+ os.makedirs("../resultats_modeles", exist_ok=True)
1357
+
1358
+ log = {
1359
+ "epoch": [],
1360
+ "train_loss": [],
1361
+ "val_loss": [],
1362
+ "train_accuracy": [],
1363
+ "train_dice_coef": [],
1364
+ "train_mean_iou": [],
1365
+ "val_accuracy": [],
1366
+ "val_dice_coef": [],
1367
+ "val_mean_iou": [],
1368
+ "epoch_time_s": [],
1369
+ "peak_gpu_mem_mb": []
1370
+ }
1371
+
1372
+ start_time = time.time()
1373
+
1374
+     for epoch in range(epochs):
+         # Skip CUDA-only calls when training on CPU
+         if device.type == "cuda":
+             torch.cuda.reset_peak_memory_stats(device=device)
+         epoch_start = time.time()
+
+         # ---------------- TRAIN ----------------
+         model.train()
+         running_loss = 0.0
+         running_accuracy = 0.0
+         running_dice = 0.0
+         running_iou = 0.0
+
+         for images, masks in tqdm(train_loader, desc=f"[Epoch {epoch+1}/{epochs}] Train"):
+             images, masks = images.to(device), masks.to(device)
+             optimizer.zero_grad()
+
+             with autocast():
+                 outputs = model(images)
+                 # outputs is a Mask2FormerForUniversalSegmentationOutput
+                 class_queries = outputs.class_queries_logits  # (N, Q, num_labels)
+                 masks_queries = outputs.masks_queries_logits  # (N, Q, h, w)
+
+                 masks_queries = F.interpolate(
+                     masks_queries,
+                     size=(masks.shape[-2], masks.shape[-1]),
+                     mode='bilinear',
+                     align_corners=False
+                 )
+
+                 aggregated_logits = maskformer_aggregator(class_queries, masks_queries)
+                 loss = criterion(aggregated_logits, masks)
+
+             scaler.scale(loss).backward()
+             scaler.step(optimizer)
+             scaler.update()
+
+             running_loss += loss.item()
+             metrics_batch = compute_batch_metrics(aggregated_logits, masks, num_classes)
+             running_accuracy += metrics_batch["accuracy"]
+             running_dice += metrics_batch["dice"]
+             running_iou += metrics_batch["iou"]
+
+         avg_train_loss = running_loss / len(train_loader)
+         avg_train_accuracy = running_accuracy / len(train_loader)
+         avg_train_dice = running_dice / len(train_loader)
+         avg_train_iou = running_iou / len(train_loader)
+
+         # ---------------- VAL ----------------
+         model.eval()
+         val_running_loss = 0.0
+         val_running_accuracy = 0.0
+         val_running_dice = 0.0
+         val_running_iou = 0.0
+
+         with torch.no_grad():
+             for images, masks in tqdm(val_loader, desc=f"[Epoch {epoch+1}/{epochs}] Val"):
+                 images, masks = images.to(device), masks.to(device)
+
+                 with autocast():
+                     outputs = model(images)
+                     class_queries = outputs.class_queries_logits
+                     masks_queries = outputs.masks_queries_logits
+
+                     masks_queries = F.interpolate(
+                         masks_queries,
+                         size=(masks.shape[-2], masks.shape[-1]),
+                         mode='bilinear',
+                         align_corners=False
+                     )
+                     aggregated_logits = maskformer_aggregator(class_queries, masks_queries)
+
+                     loss_val = criterion(aggregated_logits, masks)
+
+                 val_running_loss += loss_val.item()
+                 val_metrics = compute_batch_metrics(aggregated_logits, masks, num_classes)
+                 val_running_accuracy += val_metrics["accuracy"]
+                 val_running_dice += val_metrics["dice"]
+                 val_running_iou += val_metrics["iou"]
+
+         avg_val_loss = val_running_loss / len(val_loader)
+         avg_val_accuracy = val_running_accuracy / len(val_loader)
+         avg_val_dice = val_running_dice / len(val_loader)
+         avg_val_iou = val_running_iou / len(val_loader)
+
+         scheduler.step(avg_val_loss)
+
+         epoch_time = time.time() - epoch_start
+         peak_mem = (
+             torch.cuda.max_memory_allocated(device=device) / (1024 ** 2)
+             if device.type == "cuda" else 0.0
+         )
+
+         log["epoch"].append(epoch + 1)
+         log["train_loss"].append(avg_train_loss)
+         log["val_loss"].append(avg_val_loss)
+         log["train_accuracy"].append(avg_train_accuracy)
+         log["train_dice_coef"].append(avg_train_dice)
+         log["train_mean_iou"].append(avg_train_iou)
+         log["val_accuracy"].append(avg_val_accuracy)
+         log["val_dice_coef"].append(avg_val_dice)
+         log["val_mean_iou"].append(avg_val_iou)
+         log["epoch_time_s"].append(epoch_time)
+         log["peak_gpu_mem_mb"].append(peak_mem)
+
+         print(
+             f"Epoch {epoch+1} | "
+             f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | "
+             f"Train Dice: {avg_train_dice:.4f} | Val Dice: {avg_val_dice:.4f} | "
+             f"Train IoU: {avg_train_iou:.4f} | Val IoU: {avg_val_iou:.4f} | "
+             f"Time: {epoch_time:.1f}s | GPU: {peak_mem:.1f} MB"
+         )
+
+     total_time = time.time() - start_time
+     df = pd.DataFrame(log)
+     df["temps_total_sec"] = total_time
+     csv_path = f"../resultats_modeles/{model_name}_log.csv"
+     df.to_csv(csv_path, index=False)
+     torch.save(model.state_dict(), f"../resultats_modeles/{model_name}.pth")
+
+     # Plot the Dice/IoU curves
+     plt.figure(figsize=(12, 5))
+
+     plt.subplot(1, 2, 1)
+     plt.plot(df["epoch"], df["train_dice_coef"], label="Train Dice", color="blue")
+     plt.plot(df["epoch"], df["val_dice_coef"], label="Val Dice", color="orange")
+     plt.title("Dice Coefficient")
+     plt.xlabel("Epoch")
+     plt.ylabel("Dice")
+     plt.legend()
+     plt.grid(True)
+
+     plt.subplot(1, 2, 2)
+     plt.plot(df["epoch"], df["train_mean_iou"], label="Train IoU", color="blue")
+     plt.plot(df["epoch"], df["val_mean_iou"], label="Val IoU", color="orange")
+     plt.title("Mean IoU")
+     plt.xlabel("Epoch")
+     plt.ylabel("IoU")
+     plt.legend()
+     plt.grid(True)
+
+     plt.tight_layout()
+     png_path = f"../resultats_modeles/{model_name}_dice_iou.png"
+     plt.savefig(png_path, dpi=100)
+     plt.close()
+
+     print(f"✅ Training of {model_name} finished in {total_time:.1f} seconds.")
+     print(f"📁 Logs: {csv_path}")
+     print(f"📁 Model: ../resultats_modeles/{model_name}.pth")
+     print(f"📊 Dice/IoU plot saved: {png_path}")
+
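+ # For reference, a minimal sketch of a MaskFormer-style aggregation. The
+ # real `maskformer_aggregator` used above is defined earlier in this module;
+ # the helper below (note the hypothetical name) only illustrates the usual
+ # semantics: softmax class scores weighted by sigmoid mask scores, summed
+ # over the queries to give dense per-pixel class logits.
+ def _maskformer_aggregator_sketch(class_queries_logits, masks_queries_logits):
+     import torch
+     class_probs = class_queries_logits.softmax(dim=-1)  # (N, Q, C)
+     mask_probs = masks_queries_logits.sigmoid()         # (N, Q, H, W)
+     # Weighted sum over the Q queries -> (N, C, H, W)
+     return torch.einsum("nqc,nqhw->nchw", class_probs, mask_probs)
+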
+ def show_predictions_maskformer(
+     model,
+     dataset,
+     num_images=3,
+     num_classes=8
+ ):
+     """
+     Displays a few predictions next to the ground-truth masks from a PyTorch
+     dataset, for a MaskFormer-like model (with class_queries_logits and
+     masks_queries_logits).
+
+     1) Retrieve `class_queries_logits` and `masks_queries_logits`.
+     2) Upsample masks_queries_logits to the size of the target mask.
+     3) Aggregate via maskformer_aggregator to get a (N, C, H, W) tensor.
+     4) Take the argmax to get an (H, W) map for display.
+     """
+
+     import torch
+     import matplotlib.pyplot as plt
+     import numpy as np
+     from torch.cuda.amp import autocast
+     import torch.nn.functional as F
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.eval().to(device)
+
+     # Import the aggregator defined earlier
+     # (the one that combines class_queries_logits and masks_queries_logits)
+     from fonctions import maskformer_aggregator
+
+     fig, axes = plt.subplots(num_images, 3, figsize=(12, 4 * num_images))
+
+     for i in range(num_images):
+         idx = np.random.randint(0, len(dataset))
+         image, mask_gt = dataset[idx]  # (3, H, W), (H, W)
+
+         image_t = image.unsqueeze(0).to(device)  # (1, 3, H, W)
+         mask_gt_np = mask_gt.numpy()  # (H, W)
+
+         with torch.no_grad(), autocast():
+             outputs = model(image_t)
+             # Retrieve the logits
+             class_queries = outputs.class_queries_logits  # (1, Q, num_labels)
+             masks_queries = outputs.masks_queries_logits  # (1, Q, h, w)
+
+             # Upsample masks_queries to the size of the GT mask
+             masks_queries = F.interpolate(
+                 masks_queries,
+                 size=(mask_gt_np.shape[0], mask_gt_np.shape[1]),
+                 mode='bilinear',
+                 align_corners=False
+             )
+
+             # Aggregation => (1, C, H, W)
+             aggregated_logits = maskformer_aggregator(class_queries, masks_queries)
+             # Argmax => (H, W)
+             pred = torch.argmax(aggregated_logits, dim=1).squeeze(0).cpu().numpy()
+
+         # Display
+         if num_images == 1:
+             # Single image => axes is a 1D array of 3 subplots
+             ax_img, ax_gt, ax_pred = axes
+         else:
+             ax_img, ax_gt, ax_pred = axes[i]
+
+         ax_img.imshow(image.permute(1, 2, 0).cpu().numpy())
+         ax_img.set_title("Image")
+         ax_img.axis("off")
+
+         ax_gt.imshow(mask_gt_np, cmap="tab10", vmin=0, vmax=num_classes-1)
+         ax_gt.set_title("GT mask")
+         ax_gt.axis("off")
+
+         ax_pred.imshow(pred, cmap="tab10", vmin=0, vmax=num_classes-1)
+         ax_pred.set_title("Predicted mask")
+         ax_pred.axis("off")
+
+     plt.tight_layout()
+     plt.show()
+
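+ # Hedged usage sketch (hypothetical checkpoint name; the loaders/datasets
+ # are assumed to yield (image_tensor, mask_tensor) pairs as consumed by the
+ # training loop above):
+ def _example_mask2former_run(train_loader, val_loader, val_dataset):
+     from transformers import Mask2FormerForUniversalSegmentation
+     model = Mask2FormerForUniversalSegmentation.from_pretrained(
+         "facebook/mask2former-swin-tiny-ade-semantic"
+     )
+     training_for_mask2former(model, train_loader, val_loader,
+                              model_name="mask2former", epochs=10)
+     show_predictions_maskformer(model, val_dataset, num_images=3)
+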
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import os
+
+ def comparer_modeles(list_csv_files, model_names=None):
+     """
+     Compares several models on the training metrics (loss, dice, iou,
+     accuracy) and shows a bar chart of the total training time.
+
+     Args:
+         list_csv_files (list): paths to the CSV log files.
+         model_names (list): short names for the legend; must have the same
+             length as list_csv_files. If None, the file names are used.
+     """
+     import os
+     import pandas as pd
+     import matplotlib.pyplot as plt
+
+     if model_names is None:
+         model_names = [os.path.splitext(os.path.basename(csv_file))[0] for csv_file in list_csv_files]
+
+     # Load each CSV into a DataFrame, stored in a dict
+     model_data = {}
+     for csv_file, name in zip(list_csv_files, model_names):
+         df = pd.read_csv(csv_file)
+         model_data[name] = df
+
+     # Predefined colors for consistency across subplots
+     color_list = ["red", "blue", "green", "purple", "orange", "black"]
+     # Figure layout: 3 rows x 2 columns -> 5 subplots (the last spans a full row)
+     fig = plt.figure(figsize=(14, 14))
+
+     # -- SUBPLOT 1: Loss (top left) --
+     ax1 = plt.subplot2grid((3, 2), (0, 0))
+     ax1.set_title("Loss comparison")
+     ax1.set_xlabel("Epochs")
+     ax1.set_ylabel("Loss")
+     for i, (name, df) in enumerate(model_data.items()):
+         c = color_list[i % len(color_list)]
+         if "train_loss" in df.columns and "val_loss" in df.columns:
+             ax1.plot(df["epoch"], df["train_loss"], label=f"{name} Train Loss", color=c, linestyle="--")
+             ax1.plot(df["epoch"], df["val_loss"], label=f"{name} Val Loss", color=c, linestyle="-")
+     ax1.grid(True)
+     ax1.legend()
+
+     # -- SUBPLOT 2: Accuracy (top right) --
+     ax2 = plt.subplot2grid((3, 2), (0, 1))
+     ax2.set_title("Accuracy comparison")
+     ax2.set_xlabel("Epochs")
+     ax2.set_ylabel("Accuracy")
+     for i, (name, df) in enumerate(model_data.items()):
+         c = color_list[i % len(color_list)]
+         if "train_accuracy" in df.columns and "val_accuracy" in df.columns:
+             ax2.plot(df["epoch"], df["train_accuracy"], label=f"{name} Train Acc", color=c, linestyle="--")
+             ax2.plot(df["epoch"], df["val_accuracy"], label=f"{name} Val Acc", color=c, linestyle="-")
+     ax2.grid(True)
+     ax2.legend()
+
+     # -- SUBPLOT 3: Dice (middle left) --
+     ax3 = plt.subplot2grid((3, 2), (1, 0))
+     ax3.set_title("Dice Coefficient comparison")
+     ax3.set_xlabel("Epochs")
+     ax3.set_ylabel("Dice Coefficient")
+     for i, (name, df) in enumerate(model_data.items()):
+         c = color_list[i % len(color_list)]
+         if "train_dice_coef" in df.columns and "val_dice_coef" in df.columns:
+             ax3.plot(df["epoch"], df["train_dice_coef"], label=f"{name} Train Dice", color=c, linestyle="--")
+             ax3.plot(df["epoch"], df["val_dice_coef"], label=f"{name} Val Dice", color=c, linestyle="-")
+     ax3.grid(True)
+     ax3.legend()
+
+     # -- SUBPLOT 4: Mean IoU (middle right) --
+     ax4 = plt.subplot2grid((3, 2), (1, 1))
+     ax4.set_title("Mean IoU comparison")
+     ax4.set_xlabel("Epochs")
+     ax4.set_ylabel("Mean IoU")
+     for i, (name, df) in enumerate(model_data.items()):
+         c = color_list[i % len(color_list)]
+         if "train_mean_iou" in df.columns and "val_mean_iou" in df.columns:
+             ax4.plot(df["epoch"], df["train_mean_iou"], label=f"{name} Train IoU", color=c, linestyle="--")
+             ax4.plot(df["epoch"], df["val_mean_iou"], label=f"{name} Val IoU", color=c, linestyle="-")
+     ax4.grid(True)
+     ax4.legend()
+
+     # -- SUBPLOT 5: total training time (bar chart, bottom row) --
+     ax5 = plt.subplot2grid((3, 2), (2, 0), colspan=2)
+     ax5.set_title("Total training time (minutes)")
+     training_times = []
+     for i, (name, df) in enumerate(model_data.items()):
+         if "temps_total_sec" in df.columns:
+             total_time_sec = df["temps_total_sec"].iloc[-1]
+             total_time_min = total_time_sec / 60
+         else:
+             total_time_min = 0
+         training_times.append((name, total_time_min))
+
+     x_labels = [t[0] for t in training_times]
+     y_values = [t[1] for t in training_times]
+     bars = ax5.bar(x_labels, y_values, color=color_list[:len(y_values)])
+     for bar in bars:
+         height = bar.get_height()
+         ax5.text(bar.get_x() + bar.get_width() / 2, height + 0.1, f"{height:.2f} min",
+                  ha='center', va='bottom')
+     ax5.set_ylabel("Time (minutes)")
+     ax5.grid(True, axis='y')
+
+     plt.tight_layout()
+     plt.show()
+
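+ # Usage sketch (hypothetical CSV names, matching the paths written by the
+ # training functions above):
+ #   comparer_modeles(
+ #       ["../resultats_modeles/segformer_b5_log.csv",
+ #        "../resultats_modeles/mask2former_log.csv"],
+ #       model_names=["SegFormer", "Mask2Former"],
+ #   )
+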
+ # ------------------------------------------------------------------
+ # FUNCTIONS TO SIMULATE RAIN AND COMPARE PREDICTIONS
+ # ------------------------------------------------------------------
+
+ import albumentations as A
+ from torchvision import transforms
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+ import io
+ import matplotlib.pyplot as plt
+
+ # Global transform (rain effect)
+ rain_transform = A.Compose([
+     A.RandomRain(
+         brightness_coefficient=0.9,
+         drop_length=20,
+         drop_width=1,
+         blur_value=3,
+         rain_type='heavy'
+     )
+ ])
+
+ def apply_rain_effect(image_pil: Image.Image) -> Image.Image:
+     """
+     Applies the rain effect to a PIL image and returns a new PIL image.
+     """
+     # Convert to NumPy
+     image_np = np.array(image_pil)
+
+     # Apply the Albumentations transform
+     augmented = rain_transform(image=image_np)
+     rain_np = augmented['image']
+
+     # Convert back to PIL
+     rain_pil = Image.fromarray(rain_np)
+     return rain_pil
+
+ def predict_mask(model, image_pil, device="cpu", num_classes=8):
+     """
+     Uses 'model' (PyTorch) to predict the mask of the PIL image.
+     Returns a NumPy array (H, W) with the predicted classes [0..7].
+     """
+     # PIL -> Tensor conversion
+     transform = transforms.ToTensor()  # [0..1], shape (3, H, W)
+     image_tensor = transform(image_pil).unsqueeze(0).to(device)
+
+     model.eval()
+     with torch.no_grad():
+         outputs = model(image_tensor)
+         # E.g. for a SegFormer, the logits live in outputs.logits
+         if hasattr(outputs, "logits"):
+             logits = outputs.logits
+         elif isinstance(outputs, dict):
+             logits = outputs["out"]
+         else:
+             logits = outputs
+
+         # Upsample => size of the original image
+         _, _, h_img, w_img = image_tensor.shape
+         logits = F.interpolate(
+             logits,
+             size=(h_img, w_img),
+             mode='bilinear',
+             align_corners=False
+         )
+
+         # argmax => (H, W)
+         pred_mask = logits.argmax(dim=1).squeeze(0).cpu().numpy()
+
+     return pred_mask
+
+ def compare_rain_predictions(
+     baseline_model,
+     new_model,
+     image_path,
+     device="cpu",
+     size=(256, 256)
+ ):
+     """
+     1) Loads the original image.
+     2) Resizes it to `size` and applies the rain effect.
+     3) Predicts the mask with baseline_model and new_model.
+     4) Returns a matplotlib figure with 4 columns:
+        - original image
+        - "rain" image
+        - baseline mask
+        - new model mask
+     """
+     # 1) Load and resize the image
+     pil_image = Image.open(image_path).convert("RGB").resize(size)
+
+     # 2) Apply the rain effect
+     rain_pil = apply_rain_effect(pil_image)
+
+     # 3) Predictions
+     mask_old = predict_mask(baseline_model, rain_pil, device=device)
+     mask_new = predict_mask(new_model, rain_pil, device=device)
+
+     # 4) Display
+     fig, axs = plt.subplots(1, 4, figsize=(16, 5))
+     axs[0].imshow(np.array(pil_image))
+     axs[0].set_title("Original")
+     axs[1].imshow(np.array(rain_pil))
+     axs[1].set_title("Rain")
+     axs[2].imshow(mask_old, cmap="magma", vmin=0, vmax=7)
+     axs[2].set_title("Mask (baseline)")
+     axs[3].imshow(mask_new, cmap="magma", vmin=0, vmax=7)
+     axs[3].set_title("Mask (new)")
+
+     for ax in axs:
+         ax.axis("off")
+     plt.tight_layout()
+     return fig
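+
+ # Usage sketch (hypothetical image path; both models already loaded and on
+ # the device passed as `device`):
+ #   fig = compare_rain_predictions(baseline_model, new_model,
+ #                                  "rue.png", device="cuda")
+ #   fig.savefig("comparaison_pluie.png")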
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ fastapi
+ uvicorn[standard]
+ torch
+ transformers
+ Pillow
+ opencv-python
+ numpy
+ albumentations
+ torchvision
+ pandas
+ matplotlib
+ python-multipart
segformer_b5.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fef41a23f62352d996221a6440c3f1c9bd96e2286342b904655c4b63c67ff93a
+ size 338889838