from __future__ import annotations

import os
from dataclasses import dataclass
from typing import List, Tuple, Optional, Callable

import torch
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor
from huggingface_hub import snapshot_download


def _pick_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Apple Silicon
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _pick_dtype(device: torch.device) -> torch.dtype:
    if device.type == "cuda":
        # Prefer bf16 if supported; else fp16
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    if device.type == "mps":
        # MPS prefers float32 for accuracy
        return torch.float32
    return torch.float32


@dataclass
class SiglipEncoder:
    model: AutoModel
    processor: AutoProcessor
    device: torch.device
    dtype: torch.dtype
    ckpt_dir: str
    ckpt_local_dir: str

    @classmethod
    def from_checkpoint_dir(
        cls,
        ckpt_dir: str,
        log: Optional[Callable[[str], None]] = None,
        token: Optional[str] = None,
    ) -> "SiglipEncoder":
        device = _pick_device()
        dtype = _pick_dtype(device)
        if log:
            log(f"Loading processor/model from {ckpt_dir} (device={device}, dtype={dtype})")
        local_dir = snapshot_download(repo_id=ckpt_dir, token=token)
        processor = AutoProcessor.from_pretrained(local_dir, trust_remote_code=True, token=token)
        model = AutoModel.from_pretrained(local_dir, trust_remote_code=True, token=token)
        model.to(device)
        model.eval()
        return cls(model=model, processor=processor, device=device, dtype=dtype,
                   ckpt_dir=ckpt_dir, ckpt_local_dir=local_dir)

    # ---------- Embedding helpers ----------
    def _maybe_autocast(self):
        # CUDA AMP context
        if self.device.type == "cuda" and self.dtype in (torch.float16, torch.bfloat16):
            return torch.autocast(device_type="cuda", dtype=self.dtype)

        # For MPS/CPU, no autocast by default
        class DummyCtx:
            def __enter__(self): return None
            def __exit__(self, *args): return False

        return DummyCtx()

    def _normalize(self, x: np.ndarray) -> np.ndarray:
        # L2-normalize each row; epsilon guards against division by zero
        norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
        return x / norms

    def _pool_mean(self, last_hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        # Mean pooling with attention mask
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-6)
        return summed / counts

    def _forward_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Try common signatures: get_image_features or forward(...).image_embeds.
        # Fallback: mean pool last_hidden_state of the vision tower.
        if hasattr(self.model, "get_image_features"):
            return self.model.get_image_features(pixel_values=pixel_values)
        out = self.model(pixel_values=pixel_values)
        if hasattr(out, "image_embeds") and out.image_embeds is not None:
            return out.image_embeds
        if hasattr(out, "last_hidden_state"):
            return out.last_hidden_state.mean(dim=1)
        raise RuntimeError("Unable to extract image embeddings from model outputs.")

    def _forward_text(self, **text_inputs) -> torch.Tensor:
        if hasattr(self.model, "get_text_features"):
            return self.model.get_text_features(**text_inputs)
        out = self.model(**text_inputs)
        if hasattr(out, "text_embeds") and out.text_embeds is not None:
            return out.text_embeds
        if hasattr(out, "last_hidden_state"):
            return self._pool_mean(out.last_hidden_state, text_inputs.get("attention_mask"))
        raise RuntimeError("Unable to extract text embeddings from model outputs.")

    def encode_images(self, images: List[Image.Image], batch_size: int = 64) -> np.ndarray:
        """Encode a list of PIL images to L2-normalized embeddings."""
        feats: List[np.ndarray] = []
        with torch.no_grad(), self._maybe_autocast():
            for i in range(0, len(images), batch_size):
                batch = images[i : i + batch_size]
                # Ensure RGB
                batch = [im.convert("RGB") if im.mode != "RGB" else im for im in batch]
                inputs = self.processor(images=batch, return_tensors="pt")
                pixel_values = inputs["pixel_values"].to(
                    self.device, dtype=self.dtype if self.device.type == "cuda" else torch.float32
                )
                embs = self._forward_image(pixel_values)  # (B, D)
                embs = embs.float().cpu().numpy()
                feats.append(embs)
        feats_np = np.concatenate(feats, axis=0)
        return self._normalize(feats_np)

    def encode_texts(self, texts: List[str], batch_size: int = 128) -> np.ndarray:
        """Encode a list of texts to L2-normalized embeddings."""
        feats: List[np.ndarray] = []
        with torch.no_grad(), self._maybe_autocast():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                embs = self._forward_text(**inputs)  # (B, D)
                embs = embs.float().cpu().numpy()
                feats.append(embs)
        feats_np = np.concatenate(feats, axis=0)
        return self._normalize(feats_np)
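

# Usage sketch (illustrative only): one way this encoder could be wired up for a simple
# text-to-image retrieval check. The repo id "google/siglip-base-patch16-224" and the
# glob pattern below are placeholder assumptions, not values taken from this project.
if __name__ == "__main__":
    import glob

    # Assumed checkpoint repo id; replace with the checkpoint this Space actually uses.
    encoder = SiglipEncoder.from_checkpoint_dir("google/siglip-base-patch16-224", log=print)

    # Assumed local demo images.
    paths = sorted(glob.glob("samples/*.jpg"))
    images = [Image.open(p) for p in paths]

    img_embs = encoder.encode_images(images)                # (N, D), L2-normalized
    txt_embs = encoder.encode_texts(["a photo of a dog"])   # (1, D), L2-normalized

    # Since both sides are L2-normalized, cosine similarity is just a dot product.
    sims = img_embs @ txt_embs.T                            # (N, 1)
    best = int(np.argmax(sims[:, 0]))
    print(f"Best match: {paths[best]} (score={sims[best, 0]:.3f})")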