Last modified: 2024-10-10 16:12
Background
SimCLR was proposed by Chen et al. (ICML 2020; paper on arXiv). You can find concise pseudocode for their original approach in Alg. 1 of that paper.
Here, we develop a simplified version of SimCLR for our two half moons dataset in HW2.
Given a batch of \(N\) instances (unlabeled feature vectors), our method performs two steps:
- Augmentation Step: We create two augmented variants (nicknamed 'left' and 'right') of each instance's feature vector by adding zero-mean Gaussian noise, and then create embeddings of each, resulting in M=2N total embedding vectors (a small code sketch follows this list).
- Contrastive Loss Step: Evaluate the contrastive loss for each positive pair (left-to-right and right-to-left). Thus, the average is over M=2N terms.
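For concreteness, here is a minimal sketch of the augmentation step, assuming the batch is a PyTorch tensor; the helper name `augment` is ours for illustration, not something defined in HW2:

```python
import torch

def augment(x_NF, sigma):
    """Add zero-mean Gaussian noise (std sigma) to every feature vector in the batch."""
    return x_NF + sigma * torch.randn_like(x_NF)

# Two independent noise draws give the 'left' and 'right' variants of the same batch.
x_NF = torch.randn(5, 2)            # toy batch: N=5 instances, F=2 features
x_left = augment(x_NF, sigma=0.1)   # sigma value here is arbitrary, not the HW2 recommendation
x_right = augment(x_NF, sigma=0.1)
```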
Key differences from the original SimCLR
- Our notation indexes the total set of M=2N embeddings differently. We keep "left" variants at indices 1 ... N and the "right" variants at indices N+1 ... 2N. This makes the code a bit easier to read.
- Our MLPClassifier's encoder produces L2-normalized embeddings by construction, so embeddings for any input already lie on the unit circle (see the sketch after this list).
- To keep things simple, we do not have an extra "projection head". We compute the loss directly on the unit circle embeddings.
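To illustrate the indexing and normalization points above, here is a minimal sketch (assuming PyTorch). The function name `l2_normalize` is ours, purely for illustration; it is not the actual HW2 MLPClassifier code:

```python
import torch

def l2_normalize(h_ND, eps=1e-12):
    """Rescale each row to unit length, so every embedding lies on the unit circle when D=2."""
    return h_ND / (h_ND.norm(dim=1, keepdim=True) + eps)

# Suppose z_left and z_right each have shape (N, D).
N, D = 4, 2
z_left = l2_normalize(torch.randn(N, D))
z_right = l2_normalize(torch.randn(N, D))

# Stack so rows 0..N-1 are 'left' and rows N..2N-1 are 'right' (0-based indexing in code,
# vs. 1..N and N+1..2N in the math). The positive partner of row n is row n + N.
z = torch.vstack([z_left, z_right])   # shape (2N, D)
```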
Key hyperparameters include:
- \(\sigma\) (sigma) : standard deviation of the Gaussian noise
- \(\tau\) (temp) : temperature
Both \(\sigma\) and \(\tau\) need to be carefully set for a specific dataset to make things work smoothly. Please use the recommended values in hw2.ipynb.
Pseudocode: Simplified SimCLR loss for one batch
Inputs:
- N : num instances in batch
- F : num feature dimensions in raw input
- x_NF : unlabeled batch of feature vectors, as a tensor of shape (N, F)
- model (\(\theta\)) : parameters of the neural net encoder
- D : num embedding dimensions produced by the model (for HW2, D=2)
- sigma (\(\sigma\)) : standard deviation of the Gaussian used for augmentation
- temp (\(\tau\)) : temperature parameter
Return Values:
- total_loss : scalar loss for this batch, lower is better
Procedure:
# Part 1: Create left and right embeddings
for n in 1, 2, ... N:
    \(x^L_n \gets \text{Augment}(x_n, \sigma)\)
    \(x^R_n \gets \text{Augment}(x_n, \sigma)\)
    \(z^L_n \gets \text{Encode}(x^L_n, \theta)\)
    \(z^R_n \gets \text{Encode}(x^R_n, \theta)\)
\(z = \text{vstack}( z^L, z^R )\)

# Part 2: Evaluate contrastive loss for all left/right pairs
total_loss = 0.0
for n in 1, 2, ... N:
    \(\ell^L_n \gets \ell( z, n, n + N)\)
    \(\ell^R_n \gets \ell( z, n + N, n)\)
    total_loss = total_loss \( + \ell^L_n + \ell^R_n\)
return total_loss
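Below is one way Part 1 might look in PyTorch. This is a sketch under our own assumptions: `encoder` stands for any callable mapping a (N, F) tensor to L2-normalized (N, D) embeddings (such as the HW2 MLPClassifier's encoder), and the function name is ours. Part 2 is sketched after the loss definition in the next section.

```python
import torch

def make_left_right_embeddings(x_NF, encoder, sigma):
    """Part 1: augment each instance twice, encode both variants, stack into one tensor.

    Returns z of shape (2N, D): rows 0..N-1 are 'left', rows N..2N-1 are 'right'.
    """
    x_left = x_NF + sigma * torch.randn_like(x_NF)    # left augmentation
    x_right = x_NF + sigma * torch.randn_like(x_NF)   # right augmentation
    z_left = encoder(x_left)                          # (N, D), assumed already L2-normalized
    z_right = encoder(x_right)                        # (N, D), assumed already L2-normalized
    return torch.vstack([z_left, z_right])            # (2N, D)
```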
Contrastive loss: Temperature-scaled cross-entropy
Following Chen et al., we use the normalized temperature-scaled cross-entropy ("NT-Xent") loss.
For indices \(i,j\) that index distinct embeddings in the merged set \(z\) of \(M = 2N\) total embeddings, we compute the loss as

\[
\ell(z, i, j) = -\log \frac{\exp( z_i^\top z_j / \tau )}{\sum_{k=1, k \neq i}^{M} \exp( z_i^\top z_k / \tau )}
\]

Because our embeddings are L2-normalized, the dot product \(z_i^\top z_j\) is exactly the cosine similarity used by Chen et al.
Remember that this is about instance discrimination. We know (by construction) that a specific left embedding (at index \(n\)) has its correct "right" partner at index \(N+n\), and vice versa. This loss can be interpreted as:
- \(\ell(z, n, n+N)\) : the cross entropy of picking the correct partner for index \(n\), out of its \(M-1\) possible partners.
- \(\ell(z, n+N, n)\) : the cross entropy of picking the correct partner for index \(n+N\), out of its \(M-1\) possible partners.
We want each of these values to be as low as possible.
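Here is a minimal sketch (assuming PyTorch) of this pairwise loss and the Part 2 loop from the pseudocode above. The function names are ours, not required by HW2; z is the (2N, D) tensor of L2-normalized embeddings produced in Part 1:

```python
import torch

def nt_xent_pair_loss(z, i, j, temp):
    """NT-Xent loss l(z, i, j): negative log probability that row i picks its partner row j
    out of the other M-1 = 2N-1 embeddings, with dot-product similarities scaled by temp."""
    M = z.shape[0]
    sims = (z @ z[i]) / temp               # (M,) similarities of row i to every row
    mask = torch.ones(M, dtype=torch.bool)
    mask[i] = False                        # exclude the anchor i itself from the denominator
    return -(sims[j] - torch.logsumexp(sims[mask], dim=0))

def simclr_batch_loss(z, temp):
    """Part 2: accumulate the loss over both directions, matching the pseudocode's sum."""
    N = z.shape[0] // 2
    total_loss = 0.0
    for n in range(N):
        total_loss = total_loss + nt_xent_pair_loss(z, n, n + N, temp)   # left anchor
        total_loss = total_loss + nt_xent_pair_loss(z, n + N, n, temp)   # right anchor
    return total_loss
```

Computing the ratio via logsumexp, rather than exponentiating and dividing directly, avoids numerical overflow when the scaled similarities are large (e.g., with a small temperature).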