# Is CLIP Really Misaligned? / span_projection.py
# SpanProjection class — drop-in PCA⁻ projection for CLIP / SigLIP embeddings
import torch
from tqdm import tqdm
import os
import requests


def fit_pca(feats, reduce_factor=2):
    """Return the leading principal directions of ``feats`` as rows.

    The covariance matrix of the feature dimensions is eigendecomposed and
    the eigenvectors belonging to the largest eigenvalues are kept.

    Args:
        feats: 2-D tensor whose rows are samples and whose columns are
            feature dimensions (it is transposed before ``torch.cov``).
        reduce_factor: Divisor applied to the feature dimension; the number
            of kept components is ``feats.shape[1] // reduce_factor``.

    Returns:
        A float32 tensor of shape
        ``(feats.shape[1] // reduce_factor, feats.shape[1])`` with one
        eigenvector per row, ordered by decreasing eigenvalue.
    """
    feat_dim = feats.shape[1]
    keep = feat_dim // reduce_factor

    covariance = torch.cov(feats.T)
    eigenvalues, eigenvectors = torch.linalg.eigh(covariance)

    # Rank the eigenvalues on the CPU in float32, largest first.
    _, order = torch.sort(eigenvalues.float().cpu(), descending=True)

    # Eigenvectors are stored as columns; pick the top ones, then lay them
    # out as rows.
    top_directions = eigenvectors[:, order[:keep]]
    return top_directions.float().T


def transform_pca(span, feats_tensor):
    """Project features onto the subspace spanned by the rows of ``span``.

    The features are mapped down with ``span`` (d -> k) and then mapped back
    into the original space (k -> d), i.e. ``feats @ U.T @ U``. Mapping back
    is not strictly necessary, but it keeps the output in the same space as
    the input.

    Args:
        span: Basis with one direction per row, shape ``(k, d)``. May be a
            ``torch.Tensor`` or anything ``torch.as_tensor`` accepts (nested
            list, NumPy array, ...). It is moved/cast to the device and
            dtype of ``feats_tensor``.
        feats_tensor: Tensor whose last dimension has size ``d``. Any number
            of leading dimensions is supported, including none (a single
            ``(d,)`` vector).

    Returns:
        A tensor with the same shape, dtype and device as ``feats_tensor``.
    """
    U = torch.as_tensor(span, device=feats_tensor.device, dtype=feats_tensor.dtype)  # (k, d)
    # Right-multiplying keeps the feature axis last, so matmul broadcasting
    # handles 1-D, 2-D and batched inputs alike; the previous double `.T`
    # formulation only worked for 2-D inputs.
    reduced = feats_tensor @ U.T  # (..., k)
    return reduced @ U  # (..., d)


class SpanProjection:
    """Project embeddings onto a PCA subspace fitted on ImageNet class-name text features.

    On construction, text features for the ImageNet class names are loaded
    from ``span_feats_cache`` (or extracted with ``clip_model`` and cached
    there), and ``fit_pca`` is run on them. ``project`` then applies
    ``transform_pca`` with the fitted span.
    """

    # Number of tokenized texts sent through the text encoder per forward pass.
    _TEXT_CHUNK_SIZE = 64
    # Plain-text list of ImageNet class names, one per line.
    _IMAGENET_CLASSES_URL = 'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt'
    # Seconds to wait for the class-name download before giving up.
    _DOWNLOAD_TIMEOUT_S = 30

    def __init__(self, clip_model, tokenize_fn, span_feats_cache='imagenet_text_features.pt', device=None):
        """Load (or extract) the text features and fit the projection span.

        Args:
            clip_model: Model exposing ``encode_text``.
            tokenize_fn: Callable turning a list of strings into model input.
            span_feats_cache: Path of the file used to cache the extracted
                text features.
            device: Device for tokenized inputs and loaded features. When
                not given, CUDA is used if available, otherwise CPU.
        """
        default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device if device else default_device
        self.clip_model = clip_model
        self.tokenize_fn = tokenize_fn
        self.span_feats_cache = span_feats_cache

        # Pick the forwarding routine from the model's type name: anything
        # containing 'CLIP' or 'SLIP' uses the CLIP path, the rest SigLIP.
        model_type_name = str(type(self.clip_model))
        if 'CLIP' in model_type_name or 'SLIP' in model_type_name:
            self.fwd_texts = self.fwd_texts_clip
        else:
            self.fwd_texts = self.fwd_texts_siglip

        self.fullspan = self.fwd_imagenet_texts()
        self.span = fit_pca(self.fullspan)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path``, if it has one."""
        parent = os.path.dirname(path)
        # A bare file name has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError, so only create a directory when one is named.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @torch.no_grad()
    def fwd_texts_clip(self, texts):
        """Encode ``texts`` with a CLIP-style tokenizer/model, in chunks.

        Args:
            texts: Sequence of strings passed to ``tokenize_fn``.

        Returns:
            The per-chunk ``encode_text`` outputs concatenated with
            ``torch.cat``.
        """
        tokenized = self.tokenize_fn(texts).to(self.device)
        chunk_starts = range(0, len(tokenized), self._TEXT_CHUNK_SIZE)
        text_feats = [
            self.clip_model.encode_text(tokenized[start:start + self._TEXT_CHUNK_SIZE])
            for start in tqdm(chunk_starts, 'Forwarding Texts')
        ]
        return torch.cat(text_feats)

    @torch.no_grad()
    def fwd_texts_siglip(self, texts):
        """Encode ``texts`` with a SigLIP-style tokenizer/model in one pass.

        The tokenizer is called with ``padding="max_length"``,
        ``max_length=64`` and ``return_tensors="pt"``; its output is moved to
        ``self.device`` and handed to ``encode_text`` unchunked.
        NOTE(review): the keyword arguments look like a Hugging Face
        tokenizer interface -- confirm against the caller.
        """
        tokenized = self.tokenize_fn(texts, padding="max_length", max_length=64, return_tensors="pt").to(self.device)
        return self.clip_model.encode_text(tokenized)

    @torch.no_grad()
    def fwd_imagenet_texts(self):
        """Return float32 text features for the ImageNet class names.

        If ``span_feats_cache`` does not exist yet, the class names are
        downloaded, encoded with ``self.fwd_texts`` and saved to that path.
        The features are then always read back from the cache file.

        Returns:
            The cached features loaded onto ``self.device`` as float32.

        Raises:
            requests.RequestException: If the class-name download fails,
                times out or returns an HTTP error status.
        """
        if not os.path.exists(self.span_feats_cache):
            print(f'{self.span_feats_cache=} not exists, extracting text features')
            response = requests.get(self._IMAGENET_CLASSES_URL, timeout=self._DOWNLOAD_TIMEOUT_S)
            # Fail loudly instead of encoding and caching an error page.
            response.raise_for_status()
            imagenet_classes = response.text.strip().split('\n')
            text_feats = self.fwd_texts(imagenet_classes)
            self._ensure_parent_dir(self.span_feats_cache)
            torch.save(text_feats, self.span_feats_cache)
        return torch.load(self.span_feats_cache, map_location=self.device).to(torch.float32)

    @torch.no_grad()
    def project(self, features):
        """Project ``features`` onto the fitted span via ``transform_pca``."""
        return transform_pca(self.span, features)