# SigLip-Fashion-Retrieval / index_builder.py
# Last committed by: vungocthach1112
# Latest commit: cff407d ("fix snapshot indexes in FAISS")
from __future__ import annotations
import os
import io
import json
import time
import hashlib
from enum import Enum
from typing import Optional, Callable, Tuple
import numpy as np
try:
import faiss # type: ignore
except Exception as e:
raise RuntimeError(
"Failed to import faiss. Ensure 'faiss-cpu' is in requirements.txt."
) from e
from dataset_utils import SampleAccessor
from encoders import SiglipEncoder
class IndexStatus(Enum):
    """Outcome reported by index-management operations.

    Every member's value is its own name spelled as a string.
    """

    # A new index was encoded and written to disk.
    CREATED = "CREATED"
    # Not returned by any code visible in this module; presumably used by
    # callers after loading an existing index -- verify against call sites.
    LOADED = "LOADED"
    # Index and metadata files were already present, so nothing was built.
    SKIPPED_FOUND = "SKIPPED_FOUND"
    # Not returned by any code visible in this module; presumably reserved
    # for incremental updates -- verify against call sites.
    UPDATED = "UPDATED"
def _hex_digest_of_file(path: str, hasher, chunk_size: int = 1 << 20) -> str:
    """Feed the file at *path* into *hasher* in chunks; return its hex digest."""
    with open(path, "rb") as f:
        # Stream the file so multi-gigabyte weight files are never fully
        # resident in memory; the digest equals hashing f.read() in one go.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def index_signature_from_env(
    dataset_name: str,
    split: str,
    max_samples: int,
    image_col: str,
    text_col: str,
    ckpt_local_dir: str,  # <-- use this parameter name consistently
) -> str:
    """Return a 16-hex-character signature for an index configuration.

    The signature is the truncated SHA-1 of a JSON payload combining the
    dataset settings with fingerprints of the checkpoint directory.

    Args:
        dataset_name: Dataset identifier, stored under ``"dataset"``.
        split: Dataset split name.
        max_samples: Sample cap; coerced with ``int()``.
        image_col: Name of the image column.
        text_col: Name of the text column.
        ckpt_local_dir: Local checkpoint directory, expected to look like
            ``.../snapshots/<commit>`` (a trailing separator is allowed).

    Returns:
        The first 16 hex characters of the payload's SHA-1 digest.

    Notes:
        * ``"cfg"`` is the first 10 hex chars of the MD5 of ``config.json``,
          or ``"nocfg"`` when that file is absent.
        * ``"model"`` is the first 12 hex chars of the SHA-1 of the first
          existing file among ``model.safetensors`` and
          ``pytorch_model.bin``, or ``"nomodel"`` when neither exists.
        * ``"ckpt"`` is the last path component of ``ckpt_local_dir`` after
          normalisation, so the result no longer depends on whether the
          path ends with a separator. Previously a path without a trailing
          separator produced the parent directory name instead of the
          commit id.
    """
    cfg_path = os.path.join(ckpt_local_dir, "config.json")
    cfg_hash = "nocfg"
    if os.path.isfile(cfg_path):
        cfg_hash = _hex_digest_of_file(cfg_path, hashlib.md5())[:10]

    # commit id: .../snapshots/<commit>/ -- normpath strips any trailing
    # separator so basename always yields the final component.
    commit = os.path.basename(os.path.normpath(ckpt_local_dir))

    # (optional) hash a single model weights file
    model_hash = "nomodel"
    for fn in ("model.safetensors", "pytorch_model.bin"):
        p = os.path.join(ckpt_local_dir, fn)
        if os.path.isfile(p):
            model_hash = _hex_digest_of_file(p, hashlib.sha1())[:12]
            break

    payload = {
        "dataset": dataset_name,
        "split": split,
        "max_samples": int(max_samples),
        "ckpt": commit,
        "cfg": cfg_hash,
        "model": model_hash,
        "image_col": image_col,
        "text_col": text_col,
    }
    encoded = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha1(encoded).hexdigest()[:16]
def _index_paths(index_dir: str, signature: str):
os.makedirs(index_dir, exist_ok=True)
idx_path = os.path.join(index_dir, f"{signature}.faiss")
meta_path = os.path.join(index_dir, f"{signature}.meta.json")
return idx_path, meta_path
def _maybe_gpu(index, prefer_gpu: bool):
"""If FAISS GPU is available and prefer_gpu, move index to GPU; else return as-is."""
if not prefer_gpu:
return index
try:
import faiss # noqa
if faiss.get_num_gpus() > 0:
res = faiss.StandardGpuResources()
return faiss.index_cpu_to_gpu(res, 0, index)
except Exception:
pass
return index
def _normalize_rows(x: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
return x / norms
def ensure_index(
    accessor: SampleAccessor,
    encoder: SiglipEncoder,
    index_dir: str,
    signature: str,
    log: Optional[Callable[[str], None]] = None,
    batch_size: int = 512,
) -> IndexStatus:
    """Create the FAISS index for *signature* if it is not already on disk.

    Args:
        accessor: Dataset accessor; must support ``len()`` and
            ``batched_images(start, end)``.
        encoder: Encoder exposing ``encode_images(images)``. Its output is
            expected to be a ``(B, D)`` array of L2-normalized embeddings
            (per the original author's note; not verified here).
        index_dir: Directory holding index and metadata files (created if
            missing).
        signature: Identifier used as the file-name stem.
        log: Optional callback receiving progress messages.
        batch_size: Number of images encoded per batch (default 512).

    Returns:
        ``IndexStatus.SKIPPED_FOUND`` when both the index and metadata
        files already exist, otherwise ``IndexStatus.CREATED``.

    Raises:
        ValueError: If ``batch_size`` is not positive or the accessor is
            empty (an index cannot be built from zero embeddings).
    """
    idx_path, meta_path = _index_paths(index_dir, signature)
    if os.path.isfile(idx_path) and os.path.isfile(meta_path):
        if log:
            log(f"Index already exists at {idx_path}")
        return IndexStatus.SKIPPED_FOUND

    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    n = len(accessor)
    if n == 0:
        # Fail early with a clear message instead of the opaque
        # "need at least one array to concatenate" from NumPy.
        raise ValueError("Cannot build a FAISS index from an empty dataset (0 samples).")

    # Encode all images in batches
    if log:
        log(f"Encoding {n} images to build index ...")
    feats = []
    t0 = time.time()
    for start in range(0, n, batch_size):
        end = min(n, start + batch_size)
        imgs = accessor.batched_images(start, end)
        emb = encoder.encode_images(imgs)  # expected (B, D), L2 normalized
        feats.append(emb)
        if log:
            pct = (end / n) * 100.0
            log(f"Progress: {end}/{n} ({pct:.1f}%)")

    feats_np = np.concatenate(feats, axis=0).astype("float32", copy=False)
    dim = feats_np.shape[1]

    # Build cosine via inner-product on normalized vectors
    cpu_index = faiss.IndexFlatIP(dim)
    cpu_index.add(feats_np)

    # Save to disk (CPU index for compatibility)
    faiss.write_index(cpu_index, idx_path)

    # Save meta information; written last so that the presence of both
    # files signals a completed build to the existence check above.
    meta = {
        "signature": signature,
        "size": int(n),
        "dim": int(dim),
        "created_at": time.time(),
        "index_path": os.path.basename(idx_path),
        "notes": "Embeddings are L2-normalized; cosine == inner product.",
    }
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    if log:
        log(f"Index built in {(time.time() - t0):.2f}s. Saved to {idx_path}")
    return IndexStatus.CREATED
def load_index_meta(index_dir: str, signature: str) -> Optional[dict]:
    """Return the parsed metadata JSON for *signature*, or ``None`` if absent.

    Path resolution goes through ``_index_paths``, which creates
    ``index_dir`` when it does not exist yet.
    """
    _, meta_file = _index_paths(index_dir, signature)
    if os.path.isfile(meta_file):
        with open(meta_file, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return None
def load_faiss_index(
    index_dir: str,
    signature: str,
    log: Optional[Callable[[str], None]] = None,
    *,
    device_pref: str = "auto",
):
    """Load a saved FAISS index and report its dimensionality.

    Args:
        index_dir: Directory holding the index and metadata files.
        signature: Identifier used as the file-name stem.
        log: Optional callback receiving status messages.
        device_pref: ``"auto"`` or ``"gpu"`` request a best-effort move to
            GPU; any other value keeps the index on CPU.

    Returns:
        ``(index, dim)`` on success, or ``(None, None)`` when either the
        index file or the metadata file is missing.

    Notes:
        ``dim`` is taken from the loaded index itself (``index.d``). If the
        metadata file records a different ``"dim"``, a warning is logged
        and the index's own value is returned, since that is the
        dimensionality queries must match.
    """
    idx_path, meta_path = _index_paths(index_dir, signature)
    if not (os.path.isfile(idx_path) and os.path.isfile(meta_path)):
        return None, None

    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)

    idx = faiss.read_index(idx_path)
    dim = int(idx.d)
    meta_dim = meta.get("dim")
    if meta_dim is not None and int(meta_dim) != dim and log:
        log(
            f"Warning: metadata dim={meta_dim} disagrees with index dim={dim}; "
            "using the index value."
        )

    prefer_gpu = device_pref in ("auto", "gpu")
    idx = _maybe_gpu(idx, prefer_gpu=prefer_gpu)
    if log:
        log(f"Loaded FAISS index: {idx_path} (dim={dim}) | device_pref={device_pref}")
    return idx, dim
def search_faiss(index, query_embs: np.ndarray, top_k: int = 10):
    """Search a FAISS inner-product index with L2-normalized queries.

    Args:
        index: Object exposing ``search(queries, k)`` (a FAISS index).
        query_embs: Query embeddings of shape ``(Q, D)``. A single 1-D
            vector of shape ``(D,)`` is also accepted and treated as one
            query. Array-likes are converted to ``float32``.
        top_k: Number of neighbours to retrieve per query.

    Returns:
        ``(scores, ids)`` exactly as returned by ``index.search``.

    Raises:
        ValueError: If ``query_embs`` is neither 1-D nor 2-D. (An explicit
            raise is used instead of ``assert`` so the check survives
            ``python -O``.)
    """
    q = np.asarray(query_embs, dtype="float32")
    if q.ndim == 1:
        # Promote a lone vector to a one-row batch.
        q = q[np.newaxis, :]
    if q.ndim != 2:
        raise ValueError(
            f"query_embs must be 1-D or 2-D, got an array with ndim={q.ndim}"
        )

    # Ensure L2-normalized, and hand FAISS a C-contiguous float32 buffer.
    q = np.ascontiguousarray(_normalize_rows(q), dtype="float32")
    scores, ids = index.search(q, int(top_k))
    return scores, ids