from __future__ import annotations

import os
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Tuple, Optional, Callable

import torch
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor
from huggingface_hub import snapshot_download


def _pick_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Apple Silicon
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _pick_dtype(device: torch.device) -> torch.dtype:
    if device.type == "cuda":
        # Prefer bf16 if supported; else fp16
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    if device.type == "mps":
        # MPS is most reliable in float32
        return torch.float32
    return torch.float32


@dataclass
class SiglipEncoder:
    model: AutoModel
    processor: AutoProcessor
    device: torch.device
    dtype: torch.dtype
    ckpt_dir: str
    ckpt_local_dir: str

    @classmethod
    def from_checkpoint_dir(
        cls,
        ckpt_dir: str,
        log: Optional[Callable[[str], None]] = None,
        token: Optional[str] = None,
    ) -> "SiglipEncoder":
        device = _pick_device()
        dtype = _pick_dtype(device)
        if log:
            log(f"Loading processor/model from {ckpt_dir} (device={device}, dtype={dtype})")
        local_dir = snapshot_download(repo_id=ckpt_dir, token=token)
        processor = AutoProcessor.from_pretrained(local_dir, trust_remote_code=True, token=token)
        model = AutoModel.from_pretrained(local_dir, trust_remote_code=True, token=token)
        model.to(device)
        model.eval()
        return cls(
            model=model,
            processor=processor,
            device=device,
            dtype=dtype,
            ckpt_dir=ckpt_dir,
            ckpt_local_dir=local_dir,
        )

    # ---------- Embedding helpers ----------

    def _maybe_autocast(self):
        # CUDA AMP context; for mps/cpu, no autocast by default.
        if self.device.type == "cuda" and self.dtype in (torch.float16, torch.bfloat16):
            return torch.autocast(device_type="cuda", dtype=self.dtype)
        return nullcontext()

    def _normalize(self, x: np.ndarray) -> np.ndarray:
        norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
        return x / norms

    def _pool_mean(
        self, last_hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # Mean pooling with attention mask
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-6)
        return summed / counts

    def _forward_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Try common signatures: get_image_features or forward(...).image_embeds.
        # Fallback: mean pool last_hidden_state of the vision tower.
        if hasattr(self.model, "get_image_features"):
            return self.model.get_image_features(pixel_values=pixel_values)
        out = self.model(pixel_values=pixel_values)
        if hasattr(out, "image_embeds") and out.image_embeds is not None:
            return out.image_embeds
        if hasattr(out, "last_hidden_state"):
            return out.last_hidden_state.mean(dim=1)
        raise RuntimeError("Unable to extract image embeddings from model outputs.")

    def _forward_text(self, **text_inputs) -> torch.Tensor:
        if hasattr(self.model, "get_text_features"):
            return self.model.get_text_features(**text_inputs)
        out = self.model(**text_inputs)
        if hasattr(out, "text_embeds") and out.text_embeds is not None:
            return out.text_embeds
        if hasattr(out, "last_hidden_state"):
            return self._pool_mean(out.last_hidden_state, text_inputs.get("attention_mask"))
        raise RuntimeError("Unable to extract text embeddings from model outputs.")

    @torch.no_grad()
    def encode_images(self, images: List[Image.Image], batch_size: int = 64) -> np.ndarray:
        """Encode a list of PIL images to L2-normalized embeddings."""
        feats: List[np.ndarray] = []
        with self._maybe_autocast():
            for i in range(0, len(images), batch_size):
                batch = images[i : i + batch_size]
                # Ensure RGB
                batch = [im.convert("RGB") if im.mode != "RGB" else im for im in batch]
                inputs = self.processor(images=batch, return_tensors="pt")
                pixel_values = inputs["pixel_values"].to(
                    self.device,
                    dtype=self.dtype if self.device.type == "cuda" else torch.float32,
                )
                embs = self._forward_image(pixel_values)  # (B, D)
                embs = embs.float().cpu().numpy()
                feats.append(embs)
        feats_np = np.concatenate(feats, axis=0)
        return self._normalize(feats_np)

    @torch.no_grad()
    def encode_texts(self, texts: List[str], batch_size: int = 128) -> np.ndarray:
        """Encode a list of texts to L2-normalized embeddings."""
        feats: List[np.ndarray] = []
        with self._maybe_autocast():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                embs = self._forward_text(**inputs)  # (B, D)
                embs = embs.float().cpu().numpy()
                feats.append(embs)
        feats_np = np.concatenate(feats, axis=0)
        return self._normalize(feats_np)
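

# Usage sketch (illustrative, not part of the module above): load an encoder from a
# Hugging Face SigLIP checkpoint and embed a few images and texts. The repo id
# "google/siglip-base-patch16-224" and the image file names are assumptions;
# substitute your own checkpoint and data.
if __name__ == "__main__":
    encoder = SiglipEncoder.from_checkpoint_dir("google/siglip-base-patch16-224")
    images = [Image.open(p) for p in ["cat.jpg", "dog.jpg"]]  # hypothetical files
    texts = ["a photo of a cat", "a photo of a dog"]
    img_embs = encoder.encode_images(images)  # shape (2, D), L2-normalized
    txt_embs = encoder.encode_texts(texts)    # shape (2, D), L2-normalized
    # Because both sets of embeddings are L2-normalized, cosine similarity
    # reduces to a plain dot product.
    sims = img_embs @ txt_embs.T
    print(sims)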