Created: October 10, 2021 5:20 PM
This is the report for a class project on code review and reproduction. We use the most-starred PyTorch implementation of SimCLR.
All source code and rights belong to sthalles/SimCLR and Google Research.
Contents
- Introduction
- Run SimCLR Code
- Code Architecture
- Major Components (§2.1)
- SimCLR Algorithm to Code
- Feature Evaluation (§2.3, §4.2)
- Q&A
- Reference
Introduction
- Annotating data is expensive.
- Training a model in an unsupervised way lets us use unlabeled data to learn representations, which can then serve as a proxy for improving performance on downstream tasks.
- Current research on unsupervised representation learning is finally catching up with supervised methods.
- Contrastive learning learns representations by pulling similar elements close together and pushing dissimilar elements apart.
What makes the model learn good representations?
- Composition of multiple data augmentation operations is crucial in defining the contrastive prediction tasks that yield effective representations.
- Introducing a learnable nonlinear transformation between the representation and the contrastive loss substantially improves the quality of the learned representations.
- Representation learning with contrastive cross entropy loss benefits from normalized embeddings and an appropriately adjusted temperature parameter.
- Contrastive learning benefits from larger batch sizes, longer training, and deeper and wider networks.
Run SimCLR Code
- Train Backbone ResNet Model: Run Python code
# Create conda env
$ conda env create --name simclr --file env.yml
# Activate conda env
$ conda activate simclr
# Run SimCLR
$ python run.py
- Feature evaluation on downstream tasks: Run on Colab
Code Architecture
└─ data_aug
│ └─ contrastive_learning_dataset.py # Load datasets and data augmentation
│ └─ gaussian_blur.py # Gaussian blur
│ └─ view_generator.py # Generate two cropped images
└─ exceptions
│ └─ exceptions.py # User-defined exceptions
└─ feature_eval
│ └─ mini_batch_logistic_regression_evaluator.ipynb # Colab code for evaluation
└─ models
│ └─ resnet_simclr.py # Build ResNet backbone (pretrained=False) and add the MLP projection head
└─ env.yml # Setting environment (conda)
└─ run.py # Run training
└─ simclr.py # SimCLR trainer
└─ utils.py # Save and evaluation utils
Major Components (§2.1)
1. Composition of data augmentations (§3)
plays a critical role in defining effective predictive tasks.
- View generator Code
- Generate n_views cropped images of target image x.
class ContrastiveLearningViewGenerator(object):
    """Take two random crops of one image as the query and key."""
    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views
    def __call__(self, x):
        return [self.base_transform(x) for i in range(self.n_views)]
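As a quick check of what the generator returns, the sketch below reuses the class above with a stand-in transform. `toy_transform` is a hypothetical placeholder (random jitter on a list of numbers), not the torchvision pipeline, so the example runs without any image library.

```python
import random


class ContrastiveLearningViewGenerator(object):
    """Apply the same stochastic transform n_views times to one input."""

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transform(x) for _ in range(self.n_views)]


def toy_transform(x):
    # Hypothetical stand-in for the augmentation pipeline:
    # add independent random noise to every element.
    return [value + random.uniform(-0.1, 0.1) for value in x]


random.seed(0)
generator = ContrastiveLearningViewGenerator(toy_transform, n_views=2)
image = [0.2, 0.5, 0.8]  # pretend this is an image
views = generator(image)

print(len(views))            # number of views produced for one input
print(views[0] != views[1])  # the views differ because the transform is random
```

Because the transform is applied independently for each view, the two outputs come from the same input but are not identical, which is what makes them a positive pair.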
- Data augmentation Code
- Compose chained transforms together using Compose.
def get_simclr_pipeline_transform(size, s=1):
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.1 * size)),
                                          transforms.ToTensor()])
    return data_transforms
- You can see full illustration of transforms here.
2. A neural network base encoder f()
extracts representation vectors from augmented data examples.
class ResNetSimCLR(nn.Module):
    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}
        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features
        # add mlp projection head
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)
    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
        else:
            return model
3. Neural network projection head g()
maps representations to the space where contrastive loss is applied.
An MLP with one hidden layer is used to obtain $z_i = g(h_i) = W^{(2)} \sigma (W^{(1)} h_i)$, where $\sigma$ is a ReLU nonlinearity.
- Projection head g( ) Code
- Add Fully Connected (FC) layer (MLP projection head g) to backbone network (ResNet-50).
self.backbone = self._get_basemodel(base_model)
dim_mlp = self.backbone.fc.in_features
# add mlp projection head
self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)
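To make the formula $z_i = W^{(2)} \sigma (W^{(1)} h_i)$ concrete, here is a minimal plain-Python sketch with tiny hand-picked matrices. The helper names (`relu`, `matvec`, `projection_head`) and the weights are made up for illustration, and biases are omitted as in the formula, whereas `nn.Linear` in the code above includes them by default.

```python
def relu(v):
    """Element-wise ReLU, the nonlinearity sigma in the formula."""
    return [max(0.0, x) for x in v]


def matvec(W, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def projection_head(h, W1, W2):
    """z = W2 * ReLU(W1 * h), with biases omitted as in the formula."""
    return matvec(W2, relu(matvec(W1, h)))


h = [1.0, -2.0]                    # representation from the encoder f
W1 = [[1.0, 0.0], [0.0, 1.0]]      # hidden layer keeps the size (dim_mlp -> dim_mlp)
W2 = [[1.0, 1.0]]                  # output layer maps to out_dim = 1
z = projection_head(h, W1, W2)

print(z)  # [1.0]: the negative component is zeroed by ReLU before the second layer
```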
4. Contrastive loss function
Given a set $\{\tilde x^k\}$ including a positive pair of examples $\tilde x_i$ and $\tilde x_j$, the contrastive prediction task aims to identify $\tilde x_j$ in $\{\tilde x^k\}_{k \neq i}$ for a given $\tilde x_i$.
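For reference, the loss that the SimCLR paper derives from this prediction task (NT-Xent) for a positive pair $(i, j)$ can be written as below. Here $z$ is the projection-head output, $\tau$ the temperature, and $N$ the batch size, so that $2N$ augmented examples are compared. The indicator $\mathbb{1}_{[k \neq i]}$ excludes the similarity of an example with itself.
$$ \ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)} $$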
Introducing a learnable nonlinear transformation between the representation and the contrastive loss substantially improves the quality of the learned representations
SimCLR Algorithm to Code
1. Initialize components and arguments
Code: run.py
- Load the dataset from the given path and define the valid datasets. Choose either CIFAR-10 or STL-10.
# load dataset while setting data path
dataset = ContrastiveLearningDataset(args.data)
# load valid dataset with defined transform
train_dataset = dataset.get_dataset(args.dataset_name, args.n_views)
- Put dataset in torch DataLoader.
# put dataset in torch DataLoader
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.workers, pin_memory=True, drop_last=True
)
- Load base encoder f (ResNet-50), optimizer (LARS), and scheduler (cosine decay schedule without restarts).
# load base encoder f (ResNet)
model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
# LARS optimizer
# To overcome the optimization difficulties of large-batch training,
# Layer-wise Adaptive Rate Scaling (LARS) was used.
# * LARS uses a separate learning rate for each layer rather than for each weight,
#   which leads to better stability.
# * The magnitude of the update is controlled with respect to
#   the weight norm for better control of training speed.
base_optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
T_max=len(train_loader),
eta_min=0,
last_epoch=-1)
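To illustrate the comment about layer-wise learning rates, here is a minimal plain-Python sketch of the per-layer scaling factor described in the LARS paper. It is not the code of the `LARS` wrapper used above: `layerwise_lr`, its `weight_decay` argument, and the placement of `eps` are assumptions made for this illustration.

```python
import math


def l2_norm(values):
    """Euclidean norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def layerwise_lr(weights, grads, trust_coef=0.001, weight_decay=0.0, eps=1e-8):
    """Local learning-rate factor for one layer, following the LARS paper:
    trust_coef * ||w|| / (||g|| + weight_decay * ||w||).
    eps is added here only to avoid division by zero (an assumption)."""
    w_norm = l2_norm(weights)
    g_norm = l2_norm(grads)
    return trust_coef * w_norm / (g_norm + weight_decay * w_norm + eps)


# Two layers with the same gradient norm (5.0) but different weight norms.
small_layer = layerwise_lr(weights=[0.3, 0.4], grads=[3.0, 4.0])    # ||w|| = 0.5
large_layer = layerwise_lr(weights=[30.0, 40.0], grads=[3.0, 4.0])  # ||w|| = 50

print(small_layer)  # about 0.001 * 0.5 / 5
print(large_layer)  # about 0.001 * 50 / 5
```

The layer with the larger weight norm receives a proportionally larger step, which is the sense in which the update magnitude is controlled with respect to the weight norm.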
- Initialize SimCLR model and train it.
# It's a no-op if the 'gpu_index' argument is a negative integer or None.
with torch.cuda.device(args.gpu_index):
    # initialize SimCLR model
    simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
    # train SimCLR model
    simclr.train(train_loader)
2. Train SimCLR model
Code: simclr.py — train()
- Initialize variables
class SimCLR(object):
    def __init__(self, *args, **kwargs):
        self.args = kwargs['args']
        self.model = kwargs['model'].to(self.args.device)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.writer = SummaryWriter()
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
    def train(self, train_loader):
        scaler = GradScaler(enabled=self.args.fp16_precision)
        # save config file
        save_config_file(self.writer.log_dir, self.args)
        n_iter = 0 # global optimization step
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        # disable_cuda is True when the GPU is NOT used, so log its negation
        logging.info(f"Training with gpu: {not self.args.disable_cuda}.")
        # ... (train code described as below) ...
        n_iter += 1
- Iterate epoch and train_loader for each epoch
for epoch_counter in range(self.args.epochs):
    for images, _ in tqdm(train_loader):
        images = torch.cat(images, dim=0)
        images = images.to(self.args.device)
- Produce image representation features (z) from self.model
with autocast(enabled=self.args.fp16_precision):
    # backbone model + MLP outputs
    features = self.model(images)
- Calculate Info NCE loss from features (← g ← f)
# calculate Info NCE loss
logits, labels = self.info_nce_loss(features)
loss = self.criterion(logits, labels)
- Update networks f and g to minimize the loss
self.optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
- Calculate top-1 and top-5 accuracy, and log training loss and learning rate
if n_iter % self.args.log_every_n_steps == 0:
    top1, top5 = accuracy(logits, labels, topk=(1, 5))
    self.writer.add_scalar('loss', loss, global_step=n_iter)
    self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
    self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
    # get_last_lr() returns the most recently computed learning rate
    self.writer.add_scalar('learning_rate', self.scheduler.get_last_lr()[0], global_step=n_iter)
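- Warmup for 10 epochs: scheduler.step() is skipped during the first 10 epochs, so the learning rate stays at its initial value and cosine annealing only starts afterwards.
❌ CosineAnnealingLR without warmup
✔️ CosineAnnealingLR with warmup
# warmup for the first 10 epochs
if epoch_counter >= 10:
    self.scheduler.step()
logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")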
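The effect of the `epoch_counter >= 10` gate on the learning rate can be sketched with the closed-form cosine annealing formula. This is a hedged illustration in plain Python rather than the PyTorch scheduler itself: `lr_at_epoch`, `hold_epochs`, and the value `t_max=100` are assumptions chosen for the example (the training code above passes `T_max=len(train_loader)`).

```python
import math


def lr_at_epoch(epoch, base_lr=0.1, eta_min=0.0, t_max=100, hold_epochs=10):
    """Learning rate in effect during `epoch` when the scheduler is stepped
    once at the end of every epoch whose index is >= hold_epochs."""
    steps_taken = max(0, epoch - hold_epochs)  # number of scheduler.step() calls so far
    cosine = 0.5 * (1 + math.cos(math.pi * steps_taken / t_max))
    return eta_min + (base_lr - eta_min) * cosine


for epoch in (0, 10, 11, 60, 110):
    print(epoch, round(lr_at_epoch(epoch), 6))
```

The rate stays at `base_lr` through epoch 10, then follows half a cosine period and reaches `eta_min` after `t_max` scheduler steps.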
- Save checkpoints
# save model checkpoints
checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
save_checkpoint({
'epoch': self.args.epochs,
'arch': self.args.arch,
'state_dict': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
}, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")
3. Calculate Info NCE loss
Code: simclr.py — info_nce_loss()
- Calculate cosine similarity
$$ \mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert} $$
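A small plain-Python sketch of this formula follows. It also shows why the code first L2-normalizes the features and then takes a plain dot product: both routes give the same value. The helper names `cosine_similarity` and `normalize` are chosen for this example only.

```python
import math


def cosine_similarity(u, v):
    """sim(u, v) = (u . v) / (||u|| * ||v||)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def normalize(u):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(a * a for a in u))
    return [a / norm for a in u]


u, v = [3.0, 4.0], [4.0, 3.0]
direct = cosine_similarity(u, v)
# Normalize first, then take the dot product (the order used in info_nce_loss).
via_normalize = sum(a * b for a, b in zip(normalize(u), normalize(v)))

print(direct, via_normalize)  # both are 24/25 = 0.96 up to rounding
```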
- Define indicator functions as matrices
💭 Remind: Given a set $\{\tilde x^k\}$ including a positive pair of examples $\tilde x_i$ and $\tilde x_j$, the contrastive prediction task aims to identify $\tilde x_j$ in $\{\tilde x^k\}_{k \neq i}$ for a given $\tilde x_i$.
- mask : evaluates to 1 iff k=i (1 for the same sample as the anchor image $\tilde x_i$, 0 for every other sample)
→ we will use ~mask which evaluates to 1 iff k≠i
- labels : evaluates to 1 iff k=i (anchor image itself) or k=j (positive samples)
→ we will use labels[~mask] which evaluates to 1 iff k=j
def info_nce_loss(self, features):
    """
    features: FC layer output (z)
    return logits, labels
    """
    labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(self.args.device)
    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)
    # assert similarity_matrix.shape == (
    #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
    # assert similarity_matrix.shape == labels.shape
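The indicator matrix built at the start of info_nce_loss is easier to see on a tiny case. The following plain-Python sketch mirrors the construction for batch_size=2 and n_views=2 using lists instead of tensors. The variable name `image_ids` is introduced here for illustration.

```python
batch_size, n_views = 2, 2

# Mirrors torch.cat([torch.arange(batch_size)] * n_views):
# the original-image index of every row in the concatenated batch.
image_ids = list(range(batch_size)) * n_views  # [0, 1, 0, 1]

# labels[a][b] is 1 when rows a and b come from the same original image.
labels = [[1 if a == b else 0 for b in image_ids] for a in image_ids]

for row in labels:
    print(row)
# [1, 0, 1, 0]
# [0, 1, 0, 1]
# [1, 0, 1, 0]
# [0, 1, 0, 1]
```

Each row has a 1 on the diagonal (k=i) and a second 1 at the position of the other view of the same image (k=j).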
- Discard diagonal elements
# discard the main diagonal from both: labels and similarities matrix
mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
labels = labels[~mask].view(labels.shape[0], -1)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
# assert similarity_matrix.shape == labels.shape
- Select negatives and positives
- logits : [positives, negatives]
- labels : [0, 0, ..., 0] indicates the indices of positive pairs (always located at column 0)
- logits and labels are used to calculate CrossEntropyLoss
# select and combine multiple positives
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
# select only the negatives
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
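The selection logic can be traced by hand on a toy example. The sketch below reproduces the steps (drop the diagonal, split positives from negatives, put the positive first) with plain Python lists. The similarity values are made up, and the loop form is only an illustration of what the boolean-mask indexing does.

```python
# Toy 4x4 similarity matrix for batch_size=2, n_views=2.
# Rows 0 and 2 are views of image 0; rows 1 and 3 are views of image 1.
sim = [
    [1.0, 0.1, 0.9, 0.2],
    [0.1, 1.0, 0.3, 0.8],
    [0.9, 0.3, 1.0, 0.4],
    [0.2, 0.8, 0.4, 1.0],
]
image_ids = [0, 1, 0, 1]
n = len(image_ids)

logits = []
for a in range(n):
    # discard the diagonal entry (k == i)
    others = [b for b in range(n) if b != a]
    # the remaining entry from the same image is the positive, the rest are negatives
    positives = [sim[a][b] for b in others if image_ids[b] == image_ids[a]]
    negatives = [sim[a][b] for b in others if image_ids[b] != image_ids[a]]
    logits.append(positives + negatives)

targets = [0] * n  # the positive always sits in column 0

for row in logits:
    print(row)
```

Every row ends up with 2N-1 entries, and the entry at index 0 is the similarity to the other view of the same image.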
- Explanation
- The cosine similarities of the positive pairs are placed in the first column of a matrix, and those of the negative pairs fill the remaining columns. Each row of this matrix can then be treated as a vector of logits whose 0th element is the "right" class, so applying cross-entropy to each row trains the model to pick that class (the positive pair). That is why labels is an all-zero vector. (Github Issue 1, Github Issue 2)
- Temperature-scaled logits for contrastive loss
- if $\tau < 1$ → sharpening ✔️
- if $\tau > 1$ → smoothing
# normalized temperature-scaled logits for CE loss
logits = logits / self.args.temperature
return logits, labels
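The sharpening effect of a small temperature can be checked numerically. The sketch below applies a softmax and the cross-entropy for target index 0 to one row of logits at two temperatures. The helper functions and the example value 0.07 are assumptions for illustration, not values taken from the training configuration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of numbers."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_target0(logits):
    """Cross-entropy loss when the correct class is index 0."""
    return -math.log(softmax(logits)[0])


row = [0.9, 0.1, 0.2]  # [positive, negative, negative] cosine similarities

p_smooth = softmax([x / 1.0 for x in row])[0]    # tau = 1
p_sharp = softmax([x / 0.07 for x in row])[0]    # tau < 1
loss_smooth = cross_entropy_target0([x / 1.0 for x in row])
loss_sharp = cross_entropy_target0([x / 0.07 for x in row])

print(p_smooth, loss_smooth)
print(p_sharp, loss_sharp)
```

Since cosine similarities are bounded in [-1, 1], dividing by a temperature below 1 widens the gaps between logits, so the softmax concentrates more mass on the largest one.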
4. Remove FC weights and bias
Code: mini_batch_logistic_regression_evaluator.ipynb
- Remove fc.weight and fc.bias from state_dict.
- But we still have the names and architecture of fc.weight and fc.bias in model.named_parameters().
checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
    if k.startswith('backbone.'):
        if not k.startswith('backbone.fc'):
            # remove prefix
            state_dict[k[len("backbone."):]] = state_dict[k]
    # delete the old key; projection-head (backbone.fc) entries are dropped entirely
    del state_dict[k]
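What the loop does to the keys can be checked on a plain dictionary. In the sketch below the dictionary is a stand-in for `checkpoint['state_dict']`: the string values replace real tensors and the key names are illustrative examples of the `backbone.` naming.

```python
# Plain dict standing in for checkpoint['state_dict']; values would be tensors.
state_dict = {
    'backbone.conv1.weight': 'w_conv1',
    'backbone.layer1.0.conv1.weight': 'w_layer1',
    'backbone.fc.0.weight': 'w_head_hidden',  # projection head, first Linear
    'backbone.fc.2.weight': 'w_head_out',     # projection head, last Linear
}

for k in list(state_dict.keys()):
    if k.startswith('backbone.') and not k.startswith('backbone.fc'):
        # keep encoder weights under the name a plain ResNet expects
        state_dict[k[len('backbone.'):]] = state_dict[k]
    # drop the old key; this also discards the projection head
    del state_dict[k]

print(sorted(state_dict))  # ['conv1.weight', 'layer1.0.conv1.weight']
```

Only the encoder weights survive, now without the `backbone.` prefix, so the `fc` layer of the evaluation model keeps its fresh initialization.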
Feature Evaluation (§2.3, §4.2)
Code: Feature Evaluation Code
1. Load data
def get_stl10_data_loaders(download, batch_size=256):
    train_dataset = datasets.STL10('./data', split='train', download=download,
                                   transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=10, drop_last=False, shuffle=True)
    test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                             num_workers=10, drop_last=False, shuffle=False)
    return train_loader, test_loader
def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
    train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                     transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=10, drop_last=False, shuffle=True)
    test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                    transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                             num_workers=10, drop_last=False, shuffle=False)
    return train_loader, test_loader
if config.dataset_name == 'cifar10':
    train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config.dataset_name == 'stl10':
    train_loader, test_loader = get_stl10_data_loaders(download=True)
print("Dataset:", config.dataset_name)
- Freeze all layers except for the last FC layer
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2 # fc.weight, fc.bias
2. Load trained SimCLR backbone model
if config.arch == 'resnet18':
    model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == 'resnet50':
    model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)
checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
    if k.startswith('backbone.'):
        if not k.startswith('backbone.fc'):
            # remove prefix ex. backbone.layer -> layer
            state_dict[k[len("backbone."):]] = state_dict[k]
    # delete the old key; projection-head (backbone.fc) entries are dropped entirely
    del state_dict[k]
3. Define evaluation metric: top-k accuracy
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
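As a sanity check of the metric, here is a minimal plain-Python sketch of top-k accuracy on a four-sample toy batch. `topk_accuracy` is a hypothetical helper written for this illustration, and the scores are made up. It returns a percentage, like the function above.

```python
def topk_accuracy(scores, targets, k):
    """Percentage of rows whose target index is among the k highest scores."""
    hits = 0
    for row, target in zip(scores, targets):
        # class indices sorted from highest to lowest score
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        hits += target in ranked[:k]
    return 100.0 * hits / len(targets)


scores = [
    [0.1, 0.7, 0.2],  # ranks classes 1, 2, 0
    [0.5, 0.3, 0.2],  # ranks classes 0, 1, 2
    [0.2, 0.3, 0.5],  # ranks classes 2, 1, 0
    [0.4, 0.5, 0.1],  # ranks classes 1, 0, 2
]
targets = [1, 2, 2, 0]

top1 = topk_accuracy(scores, targets, 1)
top2 = topk_accuracy(scores, targets, 2)

print(top1)  # 50.0: samples 0 and 2 are correct at rank 1
print(top2)  # 75.0: sample 3 is additionally correct within the top 2
```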
4. Train FC layer and evaluate model
epochs = 100
for epoch in range(epochs):
    top1_train_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(train_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_train_accuracy += top1[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    top1_train_accuracy /= (counter + 1)
    top1_accuracy = 0
    top5_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(test_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        logits = model(x_batch)
        top1, top5 = accuracy(logits, y_batch, topk=(1,5))
        top1_accuracy += top1[0]
        top5_accuracy += top5[0]
    top1_accuracy /= (counter + 1)
    top5_accuracy /= (counter + 1)
    print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")
5. Reproduce results
arch='resnet18'
dataset_name='stl10'
Thank you for reading 😊
Q&A
Q) When the similarities are selected, do you mean that image labels are used to pick out the positive samples?
A) Image labels are not used to split the similarities into positive and negative. Instead, the similarities are selected by applying the indicator (labels) to them. The similarities used are computed as, for example, sim(img1, img1').
For instance, given img1, sim(img1, img1') and sim(img1', img1) are selected as the positive samples.
- ' denotes an augmented image, and sim denotes the similarity score.
Q) My understanding was that the positive is the sample itself and everything else is negative. Is that wrong, and are samples with the same label classified as positive instead?
A) A positive pair consists of the sample itself (the given anchor image) and its augmented image. The similarities of an original image with itself and of an augmented image with itself, sim(img1, img1) and sim(img1', img1'), are excluded when the loss is computed.
→ This question probably arises from the variable name labels chosen by the code author. The labels in the middle of the code is an indicator function, and only the labels just before the return statement are labels in the usual sense.
Q) Using the image representations obtained with SimCLR, are there any measurements showing that representations of similar images have a high similarity while those of completely different images have a low similarity?
A) We would have to print intermediate results to know for sure, but the Info NCE loss is computed and minimized precisely so that representations of similar images get a high similarity and representations of different images get a low similarity. Whether representation learning ultimately worked well (= whether labels are predicted correctly) can be seen from the top-k accuracy in the feature evaluation section.
Q) I have a question about why using the projection head for contrastive learning gives better performance. Would using a vector that has been embedded once more through an additional MLP bring as much improvement on deeper networks (ResNet101, ResNet152, etc.) as it does on ResNet50?
A) Please refer to Table 1 of the paper "Big Self-Supervised Models are Strong Semi-Supervised Learners". It shows that performance improves as larger models and deeper projection heads are used.
(Ref) https://arxiv.org/pdf/2006.10029.pdf
Reference
- Annotated Code https://github.com/iamchosenlee/SimCLR-1
- Original Code https://github.com/sthalles/SimCLR
- Blog Post “Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations” https://sthalles.github.io/simple-self-supervised-learning/
- Paper “A Simple Framework for Contrastive Learning of Visual Representations” (SimCLR) https://arxiv.org/pdf/2002.05709.pdf