Spaces:
Build error
Build error
from __future__ import annotations | |
import os | |
import io | |
import json | |
import time | |
import hashlib | |
from enum import Enum | |
from typing import Optional, Callable, Tuple | |
import numpy as np | |
try: | |
import faiss # type: ignore | |
except Exception as e: | |
raise RuntimeError( | |
"Failed to import faiss. Ensure 'faiss-cpu' is in requirements.txt." | |
) from e | |
from dataset_utils import SampleAccessor | |
from encoders import SiglipEncoder | |
class IndexStatus(Enum):
    """Outcome reported by the index-management helpers in this module.

    Each member's value is the same string as its name.
    """

    CREATED = "CREATED"              # ensure_index built a new index and wrote it to disk
    LOADED = "LOADED"                # not returned by the code shown; presumably "read from disk" - confirm
    SKIPPED_FOUND = "SKIPPED_FOUND"  # ensure_index found both index files and did nothing
    UPDATED = "UPDATED"              # not returned by the code shown; presumably "existing index modified" - confirm
def index_signature_from_env(
    dataset_name: str,
    split: str,
    max_samples: int,
    image_col: str,
    text_col: str,
    ckpt_local_dir: str,  # <-- keep this parameter name consistent across the codebase
) -> str:
    """Return a 16-hex-character cache key for one (dataset slice, checkpoint) pair.

    The key is the first 16 hex chars of a SHA-1 over a JSON payload made of the
    dataset selection (name, split, sample cap, image/text column names) plus a
    fingerprint of the local checkpoint directory:

    * ``ckpt``  - last path component of ``ckpt_local_dir`` (the commit id when
      the directory is a ``.../snapshots/<commit>`` folder), with or without a
      trailing separator;
    * ``cfg``   - first 10 hex chars of the MD5 of ``config.json``, or
      ``"nocfg"`` if that file is absent;
    * ``model`` - first 12 hex chars of the SHA-1 of the first weights file
      found (``model.safetensors``, then ``pytorch_model.bin``), or
      ``"nomodel"`` if neither exists.

    Files are hashed in fixed-size chunks so multi-GB weight files are never
    loaded into memory at once.

    Returns:
        The signature string, suitable as a file-name stem.
    """
    chunk_size = 1 << 20  # 1 MiB per read while hashing

    def _digest(path, new_hash):
        # Incremental hashing: same digest as hashing the whole file at once.
        h = new_hash()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    cfg_path = os.path.join(ckpt_local_dir, "config.json")
    cfg_hash = _digest(cfg_path, hashlib.md5)[:10] if os.path.isfile(cfg_path) else "nocfg"

    # Commit id: .../snapshots/<commit>[/]. normpath drops a trailing separator,
    # so the commit is picked up whether or not the caller passes one (taking
    # basename(dirname(...)) only worked with a trailing slash).
    commit = os.path.basename(os.path.normpath(ckpt_local_dir))

    # (Optional) fingerprint one model weights file.
    model_hash = "nomodel"
    for fn in ("model.safetensors", "pytorch_model.bin"):
        p = os.path.join(ckpt_local_dir, fn)
        if os.path.isfile(p):
            model_hash = _digest(p, hashlib.sha1)[:12]
            break

    payload = {
        "dataset": dataset_name,
        "split": split,
        "max_samples": int(max_samples),
        "ckpt": commit,
        "cfg": cfg_hash,
        "model": model_hash,
        "image_col": image_col,
        "text_col": text_col,
    }
    return hashlib.sha1(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:16]
def _index_paths(index_dir: str, signature: str): | |
os.makedirs(index_dir, exist_ok=True) | |
idx_path = os.path.join(index_dir, f"{signature}.faiss") | |
meta_path = os.path.join(index_dir, f"{signature}.meta.json") | |
return idx_path, meta_path | |
def _maybe_gpu(index, prefer_gpu: bool): | |
"""If FAISS GPU is available and prefer_gpu, move index to GPU; else return as-is.""" | |
if not prefer_gpu: | |
return index | |
try: | |
import faiss # noqa | |
if faiss.get_num_gpus() > 0: | |
res = faiss.StandardGpuResources() | |
return faiss.index_cpu_to_gpu(res, 0, index) | |
except Exception: | |
pass | |
return index | |
def _normalize_rows(x: np.ndarray) -> np.ndarray: | |
norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12 | |
return x / norms | |
def ensure_index(
    accessor: SampleAccessor,
    encoder: SiglipEncoder,
    index_dir: str,
    signature: str,
    log: Optional[Callable[[str], None]] = None,
    *,
    batch_size: int = 512,
) -> IndexStatus:
    """Create the FAISS index for *signature* if it is not on disk yet.

    If both ``<signature>.faiss`` and ``<signature>.meta.json`` already exist in
    *index_dir*, nothing is encoded and ``IndexStatus.SKIPPED_FOUND`` is
    returned. Otherwise every image is encoded in batches with
    ``encoder.encode_images``, the embeddings go into an ``IndexFlatIP`` (a CPU
    index), and the index plus a JSON metadata file are saved.

    Each file is first written under a ``.tmp`` name and then moved into place
    with ``os.replace``. The metadata file is written last. The two-file
    existence check above therefore only succeeds after a completed build.

    Args:
        accessor: Dataset view; ``len(accessor)`` rows, images fetched via
            ``accessor.batched_images(start, end)``.
        encoder: Image encoder. ``encode_images`` is assumed to return a
            ``(B, D)`` array of L2-normalized embeddings. This is not checked
            here; cosine == inner product only holds if it is true.
        index_dir: Directory holding the index files (created if missing).
        signature: Cache key used as the file stem.
        log: Optional callback receiving progress messages.
        batch_size: Number of images per ``encode_images`` call.

    Returns:
        ``IndexStatus.SKIPPED_FOUND`` if the files existed, else
        ``IndexStatus.CREATED``.

    Raises:
        ValueError: *batch_size* < 1, or the accessor is empty (no embedding
            is available to determine the index dimension).
        RuntimeError: the encoder produced a different number of embeddings
            than there are rows, which would misalign FAISS ids with dataset
            indices.
    """
    idx_path, meta_path = _index_paths(index_dir, signature)
    if os.path.isfile(idx_path) and os.path.isfile(meta_path):
        if log:
            log(f"Index already exists at {idx_path}")
        return IndexStatus.SKIPPED_FOUND

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    n = len(accessor)
    if n == 0:
        raise ValueError("Cannot build a FAISS index from an empty dataset.")

    # Encode all images in batches
    if log:
        log(f"Encoding {n} images to build index ...")
    feats = []
    t0 = time.time()
    for start in range(0, n, batch_size):
        end = min(n, start + batch_size)
        imgs = accessor.batched_images(start, end)
        feats.append(encoder.encode_images(imgs))  # assumed (B, D), L2-normalized
        if log:
            pct = (end / n) * 100.0
            log(f"Progress: {end}/{n} ({pct:.1f}%)")

    # FAISS requires a C-contiguous float32 matrix; no copy if already so.
    feats_np = np.ascontiguousarray(np.concatenate(feats, axis=0), dtype=np.float32)
    if feats_np.shape[0] != n:
        raise RuntimeError(
            f"Encoder returned {feats_np.shape[0]} embeddings for {n} images; "
            "FAISS ids would not line up with dataset rows."
        )
    dim = feats_np.shape[1]

    # Build cosine via inner-product on normalized vectors
    cpu_index = faiss.IndexFlatIP(dim)
    cpu_index.add(feats_np)

    # Save to disk (CPU index for compatibility), atomically.
    tmp_idx_path = idx_path + ".tmp"
    faiss.write_index(cpu_index, tmp_idx_path)
    os.replace(tmp_idx_path, idx_path)

    # Save meta information last: its presence marks the build as complete.
    meta = {
        "signature": signature,
        "size": int(n),
        "dim": int(dim),
        "created_at": time.time(),
        "index_path": os.path.basename(idx_path),
        "notes": "Embeddings are L2-normalized; cosine == inner product.",
    }
    tmp_meta_path = meta_path + ".tmp"
    with open(tmp_meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    os.replace(tmp_meta_path, meta_path)

    if log:
        log(f"Index built in {(time.time() - t0):.2f}s. Saved to {idx_path}")
    return IndexStatus.CREATED
def load_index_meta(index_dir: str, signature: str) -> Optional[dict]:
    """Read the JSON metadata file saved for *signature* in *index_dir*.

    Returns the parsed JSON object, or ``None`` when the ``.meta.json`` file
    does not exist. Resolving the paths via ``_index_paths`` also creates
    *index_dir* if it is missing.
    """
    meta_path = _index_paths(index_dir, signature)[1]
    if os.path.isfile(meta_path):
        with open(meta_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return None
def load_faiss_index(
    index_dir: str,
    signature: str,
    log: Optional[Callable[[str], None]] = None,
    *,
    device_pref: str = "auto",
):
    """Load the saved FAISS index for *signature*, optionally placing it on GPU.

    Args:
        index_dir: Directory holding ``<signature>.faiss`` / ``.meta.json``.
        signature: Cache key used as the file stem.
        log: Optional callback receiving status messages.
        device_pref: ``"auto"`` or ``"gpu"`` attempt a move to GPU 0 and fall
            back to CPU silently. Any other value (e.g. ``"cpu"``) keeps the
            index on CPU.

    Returns:
        ``(index, dim)``, or ``(None, None)`` when either file is missing.
        ``dim`` is read from the loaded index (``index.d``), so it always
        matches what ``index.search`` expects. A disagreeing ``"dim"`` in the
        metadata is reported through *log* rather than trusted.
    """
    idx_path, meta_path = _index_paths(index_dir, signature)
    if not (os.path.isfile(idx_path) and os.path.isfile(meta_path)):
        return None, None
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)

    cpu_idx = faiss.read_index(idx_path)
    dim = int(cpu_idx.d)
    meta_dim = meta.get("dim")
    if log and meta_dim is not None and int(meta_dim) != dim:
        log(f"Warning: meta dim={meta_dim} differs from index dim={dim}; using index dim.")

    prefer_gpu = device_pref in ("auto", "gpu")
    idx = _maybe_gpu(cpu_idx, prefer_gpu=prefer_gpu)
    if log:
        # _maybe_gpu hands back its input unchanged whenever it did not move
        # the index, so identity tells us where the index really lives.
        dev = "GPU" if idx is not cpu_idx else "CPU"
        log(f"Loaded FAISS index: {idx_path} (dim={dim}) | device_pref={device_pref} | device={dev}")
    return idx, dim
def search_faiss(index, query_embs: np.ndarray, top_k: int = 10):
    """Search a FAISS inner-product index with L2-normalized query embeddings.

    Args:
        index: FAISS index (CPU or GPU) exposing ``search(x, k)``.
        query_embs: ``(Q, D)`` array-like of query embeddings. A single ``(D,)``
            vector is also accepted and treated as one query. Rows are
            L2-normalized here, so scores are cosine similarities against an
            index of normalized vectors.
        top_k: Number of neighbours returned per query; must be >= 1.

    Returns:
        ``(scores, ids)`` exactly as returned by ``index.search``.

    Raises:
        ValueError: *query_embs* is not 1-D or 2-D, or *top_k* < 1. These are
            real exceptions rather than ``assert`` so they survive ``python -O``.
    """
    q = np.asarray(query_embs)
    if q.ndim == 1:
        q = q[np.newaxis, :]  # single query vector -> batch of one
    if q.ndim != 2:
        raise ValueError(f"query_embs must be 1-D or 2-D, got shape {q.shape}")
    k = int(top_k)
    if k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    # Ensure L2-normalized float32 (FAISS's required dtype).
    q = _normalize_rows(q.astype("float32", copy=False))
    scores, ids = index.search(q, k)
    return scores, ids