# This script walks through the procedure from Appendix A step by step.
# It demonstrates that image-image similarities can be recovered from text-image similarities alone,
# provided the image embeddings are unit-normalized.
import torch
import torch.nn.functional as F
########################
# Main Procedure #
########################
N, d = 42, 3  # number of embeddings (N) and their dimension (d); adjust here
# The linear system below determines the d(d+1)/2 free entries of a symmetric
# d x d matrix; this script requires strictly more equations (N) than that.
# Raise instead of assert so the guard is not stripped under `python -O`.
n_sym = d * (d + 1) // 2
if N <= n_sym:
    raise ValueError(f'N={N} <= d(d+1)/2={n_sym}, make sure N > d(d+1)/2')

# Setup: random unit-norm embeddings, one per row.
# torch.rand samples from [0, 1), so all embeddings lie in the positive orthant.
X_T = F.normalize(torch.rand(N, d), dim=1)  # text embeddings
X_I = F.normalize(torch.rand(N, d), dim=1)  # image embeddings
S_inter = X_T @ X_I.T  # text-image similarities, N x N with rank <= d

# Decompose: the top-d right singular vectors V of S_inter span the column space
# of X_I, so X_I = V @ R for some d x d matrix R and therefore
# X_I @ X_I.T = V @ Q @ V.T with Q = R @ R.T. Only Q is unknown.
U, Sigma, Vt = torch.linalg.svd(S_inter)
U, V = U[:, :d], Vt.T[:, :d]  # reduced SVD (rank d)
# pairwise_products[n, j, k] = V[n, j] * V[n, k]
pairwise_products = V.unsqueeze(2) * V.unsqueeze(1)

# Solve: unit-norm image embeddings mean diag(X_I @ X_I.T) = 1, i.e. for every n
#   V[n] @ Q @ V[n] = sum_{j,k} V[n, j] * V[n, k] * Q[j, k] = 1,
# which is linear in the entries of Q: A @ Q.flatten() = 1_N.
A = pairwise_products.view(N, d * d)
b = torch.ones(N)
# A is rank-deficient (columns (j, k) and (k, j) coincide). 'gelsd' returns the
# minimum-norm solution, which lies in the row space of A (flattened symmetric
# matrices), so the recovered Q is symmetric.
x = torch.linalg.lstsq(A, b, driver='gelsd').solution

# Recover the image-image similarities from V and Q alone.
Q = x.view(d, d)
S_intra_recovered = V @ Q @ V.T

# Check against the ground truth, which the procedure never looked at.
S_intra_true = X_I @ X_I.T
print('RESULT: recovery error:', (S_intra_recovered - S_intra_true).abs().max())
########################
# Analysis #
########################
print('-----------')
print('ANALYSIS')
# S_inter = X_T @ X_I.T: its left singular vectors (U) relate to X_T and its
# right singular vectors (V) relate to X_I, which is exactly what the checks test.
print('1. Validate "The columns of U (resp. V) span the same space as the columns of X_T (resp. X_I)":')
def equal_column_space(A: torch.Tensor, B: torch.Tensor) -> bool:
    """Return True if the columns of A and B span the same subspace.

    Theorem: col(A) == col(B) if and only if
    rank(A) == rank(B) == rank([A | B]).
    (rank([A | B]) == rank(A) means col(B) is contained in col(A), and vice versa.)

    Ranks are numerical ranks from ``torch.linalg.matrix_rank`` with its
    default tolerance.

    Args:
        A: 2-D tensor of shape (m, p).
        B: 2-D tensor of shape (m, q); must have the same number of rows as A.

    Returns:
        A plain Python bool.
    """
    rank_A = int(torch.linalg.matrix_rank(A))
    rank_B = int(torch.linalg.matrix_rank(B))
    rank_A_aug_B = int(torch.linalg.matrix_rank(torch.cat([A, B], dim=1)))
    return rank_A == rank_B == rank_A_aug_B
# Check which matrices share a column space. Since S_inter = X_T @ X_I.T, the
# left singular vectors U are expected to span col(X_T), and the right singular
# vectors V are expected to span col(X_I).
# NOTE: the f-string `=` specifier prints each expression's source text, so these
# lines must stay verbatim to keep the printed output unchanged.
print(f'{equal_column_space(X_T, S_inter)=}, True expected') # columns of S_inter = X_T @ X_I.T are combinations of X_T's columns (assuming X_I has full column rank d)
print(f'{equal_column_space(U, S_inter)=}, True expected') # U = top-d left singular vectors of the rank-d S_inter
print(f'{equal_column_space(U, X_T)=}, True expected') # follows from the two checks above
print(f'{equal_column_space(X_I, S_inter.T)=}, True expected') # the row space of S_inter is the column space of X_I
print(f'{equal_column_space(V, S_inter.T)=}, True expected') # V = top-d right singular vectors of S_inter
print(f'{equal_column_space(V, X_I)=}, True expected') # follows from the two checks above
print(f'{equal_column_space(V, S_inter)=}, False expected') # sanity check: col(X_I) should differ from col(S_inter) = col(X_T)
# Section 2 checks that the system A @ x = b (b = all ones) solved in the main procedure is consistent.
print('-----------')
print('2. Validate "b = 1_N is in column space of A":')
def in_column_space(A: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if b lies in the column space of A.

    Theorem: b is in col(A) if and only if rank(A) == rank([A | b]).
    Appending b cannot lower the rank, and it raises the rank exactly when b
    adds a new direction.

    Ranks are numerical ranks from ``torch.linalg.matrix_rank`` with its
    default tolerance.

    Args:
        A: 2-D tensor of shape (m, p).
        b: 1-D tensor of shape (m,), or 2-D tensor of shape (m, k). In the 2-D
           case, True means every column of b lies in col(A).

    Returns:
        A plain Python bool.
    """
    # Treat a single vector as a one-column matrix so it can be appended to A.
    b_cols = b[:, None] if b.ndim == 1 else b
    rank_A = int(torch.linalg.matrix_rank(A))
    rank_A_aug_b = int(torch.linalg.matrix_rank(torch.cat([A, b_cols], dim=1)))
    return rank_A == rank_A_aug_b
# Setup A: row n holds all pairwise products V[n, j] * V[n, k] for j, k < d,
# i.e. the flattened outer product of V[n] with itself (same A as the main procedure).
A = (V[:,None,:] * V[:,:,None]).reshape(N, d*d)
# Columns (j, k) and (k, j) are identical, so only the d(d+1)/2 columns with j <= k
# can be independent; that is the expected rank.
print(f'{torch.linalg.matrix_rank(A)=}, expected {d*(d+1)//2}')
b = torch.ones(N)
print(f'{in_column_space(A, b)=}, True expected')
# Why?
# - A is built from V.
# - From the previous validation, the column spaces of V and X_I are the same,
#   so the same construction applied to X_I gives a matrix whose column space matches A's
#   (checked at the end).
A_with_X_I_instead_of_V = (X_I[:,None,:] * X_I[:,:,None]).reshape(N, d*d)
print(f'{in_column_space(A_with_X_I_instead_of_V, b)=}, True expected')
# Why is b in the column space of X_I**2 (the element-wise square)?
# -> because every row of X_I has unit norm, the sum of all columns of X_I**2 is 1_N:
print(f'{(X_I**2).sum(axis=-1)=}')
# hence:
print(f'{in_column_space(X_I**2, b)=}, True expected')
# The columns of A_with_X_I_instead_of_V with j == k (selected by the flattened
# identity mask) are exactly the self-squares X_I[:, j]**2, so they also sum to 1_N:
print(f'{(A_with_X_I_instead_of_V * torch.eye(d).flatten()).sum(axis=1)=}')
# such that b is in the column space of A_with_X_I_instead_of_V,
# and since A_with_X_I_instead_of_V and A share the column space, also in the column space of A.
print(f'{equal_column_space(A_with_X_I_instead_of_V, A)=}, True expected') # sanity check