from __future__ import annotations
import os
from dataclasses import dataclass
from typing import List, Tuple, Optional, Callable
import torch
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor
from huggingface_hub import snapshot_download


def _pick_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Apple Silicon
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _pick_dtype(device: torch.device) -> torch.dtype:
    if device.type == "cuda":
        # Prefer bf16 if supported; else fp16
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    if device.type == "mps":
        # mps prefers float32 accuracy
        return torch.float32
    return torch.float32


@dataclass
class SiglipEncoder:
    model: AutoModel
    processor: AutoProcessor
    device: torch.device
    dtype: torch.dtype
    ckpt_dir: str
    ckpt_local_dir: str

    @classmethod
    def from_checkpoint_dir(
        cls,
        ckpt_dir: str,
        log: Optional[Callable[[str], None]] = None,
        token: Optional[str] = None,
    ) -> "SiglipEncoder":
        """Download (or reuse) a checkpoint snapshot and build an encoder on the best available device."""
        device = _pick_device()
        dtype = _pick_dtype(device)
        if log:
            log(f"Loading processor/model from {ckpt_dir} (device={device}, dtype={dtype})")
        local_dir = snapshot_download(repo_id=ckpt_dir, token=token)
        processor = AutoProcessor.from_pretrained(local_dir, trust_remote_code=True, token=token)
        model = AutoModel.from_pretrained(local_dir, trust_remote_code=True, token=token)
        model.to(device)
        model.eval()
        return cls(model=model, processor=processor, device=device, dtype=dtype,
                   ckpt_dir=ckpt_dir, ckpt_local_dir=local_dir)

    # ---------- Embedding helpers ----------
    def _maybe_autocast(self):
        # CUDA AMP context when running in half precision
        if self.device.type == "cuda" and self.dtype in (torch.float16, torch.bfloat16):
            return torch.autocast(device_type="cuda", dtype=self.dtype)

        # for mps/cpu, no autocast by default; return a no-op context manager
        class DummyCtx:
            def __enter__(self): return None
            def __exit__(self, *args): return False

        return DummyCtx()

    def _normalize(self, x: np.ndarray) -> np.ndarray:
        norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
        return x / norms

    def _pool_mean(self, last_hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        # mean pooling with attention mask
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-6)
        return summed / counts

    def _forward_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Try common signatures: get_image_features or forward(...).image_embeds
        # Fallback: mean pool last_hidden_state of vision tower.
        if hasattr(self.model, "get_image_features"):
            return self.model.get_image_features(pixel_values=pixel_values)
        out = self.model(pixel_values=pixel_values)
        if hasattr(out, "image_embeds") and out.image_embeds is not None:
            return out.image_embeds
        if hasattr(out, "last_hidden_state"):
            return out.last_hidden_state.mean(dim=1)
        raise RuntimeError("Unable to extract image embeddings from model outputs.")

    def _forward_text(self, **text_inputs) -> torch.Tensor:
        # Try get_text_features or forward(...).text_embeds
        # Fallback: masked mean pooling of last_hidden_state.
        if hasattr(self.model, "get_text_features"):
            return self.model.get_text_features(**text_inputs)
        out = self.model(**text_inputs)
        if hasattr(out, "text_embeds") and out.text_embeds is not None:
            return out.text_embeds
        if hasattr(out, "last_hidden_state"):
            return self._pool_mean(out.last_hidden_state, text_inputs.get("attention_mask"))
        raise RuntimeError("Unable to extract text embeddings from model outputs.")

    @torch.no_grad()
    def encode_images(self, images: List[Image.Image], batch_size: int = 64) -> np.ndarray:
        """Encode a list of PIL images to L2-normalized embeddings."""
        feats: List[np.ndarray] = []
        with self._maybe_autocast():
            for i in range(0, len(images), batch_size):
                batch = images[i : i + batch_size]
                # Ensure RGB
                batch = [im.convert("RGB") if im.mode != "RGB" else im for im in batch]
                inputs = self.processor(images=batch, return_tensors="pt")
                pixel_values = inputs["pixel_values"].to(
                    self.device,
                    dtype=self.dtype if self.device.type == "cuda" else torch.float32,
                )
                embs = self._forward_image(pixel_values)  # (B, D)
                embs = embs.float().cpu().numpy()
                feats.append(embs)
        feats_np = np.concatenate(feats, axis=0)
        return self._normalize(feats_np)

    @torch.no_grad()
    def encode_texts(self, texts: List[str], batch_size: int = 128) -> np.ndarray:
        """Encode a list of texts to L2-normalized embeddings."""
        feats: List[np.ndarray] = []
        with self._maybe_autocast():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                embs = self._forward_text(**inputs)  # (B, D)
                embs = embs.float().cpu().numpy()
                feats.append(embs)
        feats_np = np.concatenate(feats, axis=0)
        return self._normalize(feats_np)
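

# --- Usage sketch (illustrative only) ---
# A minimal example of how this encoder might be driven, assuming a SigLIP-style
# checkpoint id; the repo id and image path below are placeholders, not values
# taken from this repository.
if __name__ == "__main__":
    encoder = SiglipEncoder.from_checkpoint_dir("google/siglip-base-patch16-224", log=print)
    img = Image.open("example.jpg")  # placeholder path
    img_emb = encoder.encode_images([img])  # (1, D), L2-normalized
    txt_emb = encoder.encode_texts(["a photo of a cat", "a photo of a dog"])  # (2, D)
    # Because both sides are L2-normalized, cosine similarity is a plain dot product.
    sims = img_emb @ txt_emb.T  # (1, 2)
    print(sims)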