

[코드 리뷰] SimCLR Code Review

Created: October 10, 2021 5:20 PM

This is a report from a class project on code review and reproduction. We use the PyTorch implementation of SimCLR with the most GitHub stars.
All source codes and rights belong to sthalles/SimCLR and Google Research.


Introduction

  • Annotating data is expensive.
  • Training a model in an unsupervised way allows us to utilize unlabeled data to learn representations that can be used as a proxy for improving downstream performance.
  • Current research on unsupervised representation learning is finally catching up with supervised methods.
  • Contrastive learning learns representations by pulling similar elements together and pushing dissimilar elements apart.

What makes the model learn good representations?

  • Composition of multiple data augmentation operations is crucial in defining the contrastive prediction tasks that yield effective representations.
  • Introducing a learnable nonlinear transformation between the representation and the contrastive loss substantially improves the quality of the learned representations.
  • Representation learning with contrastive cross entropy loss benefits from normalized embeddings and an appropriately adjusted temperature parameter.
  • Contrastive learning benefits from larger batch sizes, longer training, and deeper and wider networks.

Run SimCLR Code

Image credit @google-research/simclr

  1. Train Backbone ResNet Model: Run Python code
    # Create conda env
    $ conda env create --name simclr --file env.yml
    # Activate conda env
    $ conda activate simclr
    # Run SimCLR
    $ python run.py
  2. Feature evaluation on downstream tasks: Run on Colab

Code Architecture

├─ data_aug
│  ├─ contrastive_learning_dataset.py  # Load datasets and data augmentation
│  ├─ gaussian_blur.py                 # Gaussian blur
│  └─ view_generator.py                # Generate two cropped images
├─ exceptions
│  └─ exceptions.py                    # User-defined exceptions
├─ feature_eval
│  └─ mini_batch_logistic_regression_evaluator.ipynb   # Colab code for evaluation
├─ models
│  └─ resnet_simclr.py                 # Load pre-trained model (ResNet) and add FC
├─ env.yml              # Environment setting (conda)
├─ run.py               # Run training
├─ simclr.py            # SimCLR trainer
└─ utils.py             # Save and evaluation utils

Major Components (§2.1)

Image credit @Amit Chaudhary https://amitness.com/2020/03/illustrated-simclr/

1. Composition of data augmentations (§3)

plays a critical role in defining effective predictive tasks.

  • View generator Code
    Figure 3. By "randomly cropping" images, they sample contrastive prediction tasks that include global to local view ($B \to A$) or adjacent view ($D \to C$) prediction.
    • Generate n_views cropped images of the target image x (see the usage sketch after the code below).
from torchvision import transforms
from data_aug.gaussian_blur import GaussianBlur

class ContrastiveLearningViewGenerator(object):
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        # apply the same augmentation pipeline n_views times to get n_views views
        return [self.base_transform(x) for _ in range(self.n_views)]

def get_simclr_pipeline_transform(size, s=1):
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    # s scales the color-jitter strength
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.1 * size)),
                                          transforms.ToTensor()])
    return data_transforms
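A minimal usage sketch as a sanity check (the file name example.jpg is hypothetical; size 96 matches STL-10's 96×96 images):

from PIL import Image

# wrap the SimCLR pipeline in the view generator and apply it to one image
transform = ContrastiveLearningViewGenerator(
    get_simclr_pipeline_transform(size=96), n_views=2)
img = Image.open('example.jpg')          # hypothetical input image
views = transform(img)                   # list of 2 independently augmented views
assert len(views) == 2
assert views[0].shape == views[1].shape  # both are (3, 96, 96) tensors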

  • You can see a full illustration of the transforms here.

2. A neural network base encoder f()

extracts representation vectors from augmented data examples.

import torch.nn as nn
import torchvision.models as models

from exceptions.exceptions import InvalidBackboneError

class ResNetSimCLR(nn.Module):

    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        # instantiate backbones from scratch; the final fc maps to out_dim
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head g(): Linear -> ReLU -> original fc
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
        else:
            return model

    def forward(self, x):
        return self.backbone(x)
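A minimal forward-pass sketch (batch and image size are assumptions for illustration):

import torch

# ResNet-18 backbone with a 128-d projection head: z = g(f(x))
model = ResNetSimCLR(base_model='resnet18', out_dim=128)
x = torch.randn(4, 3, 96, 96)  # hypothetical batch of four 96x96 images
z = model(x)
print(z.shape)  # torch.Size([4, 128])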

3. Neural network projection head g()

maps representations to the space where contrastive loss is applied.

An MLP with one hidden layer to obtain $z_i = g(h_i) = W^{(2)} \sigma (W^{(1)} h_i)$, where $\sigma$ is ReLU.

self.backbone = self._get_basemodel(base_model)
dim_mlp = self.backbone.fc.in_features

# add mlp projection head
self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

4. Contrastive loss function

Given a set $\{\tilde x_k\}$ including a positive pair of examples $\tilde x_i$ and $\tilde x_j$, the contrastive prediction task aims to identify $\tilde x_j$ in $\{\tilde x_k\}_{k \neq i}$ for a given $\tilde x_i$.

Introducing a learnable nonlinear transformation between the representation and the contrastive loss substantially improves the quality of the learned representations.
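For reference, the resulting NT-Xent loss for a positive pair $(i, j)$, as defined in the paper ($2N$ augmented examples per batch, temperature $\tau$):

$$ \ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)} $$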

SimCLR Algorithm to Code

1. Initialize components and arguments

Code: run.py

  • Load the dataset from the data path and get the dataset with the defined transform. Pick either CIFAR-10 or STL-10.
# load dataset while setting data path
dataset = ContrastiveLearningDataset(args.data)

# load valid dataset with defined transform
train_dataset = dataset.get_dataset(args.dataset_name, args.n_views)
  • Put dataset in torch DataLoader.
# put dataset in torch DataLoader
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, drop_last=True
)
  • Load base encoder f (ResNet-50), optimizer (LARS), and scheduler (cosine decay schedule without restarts).
# load base encoder f (ResNet)
model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)

### LARS optimizer
# To overcome the optimization difficulties of large-batch training,
# Layer-wise Adaptive Rate Scaling (LARS) is used.
# * LARS uses a separate learning rate for each layer rather than for each weight,
#   which leads to better stability.
# * The magnitude of the update is controlled with respect to
#   the weight norm for better control of the training speed.
base_optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=len(train_loader),
                                                       eta_min=0,
                                                       last_epoch=-1)
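For reference, LARS computes a layer-wise local learning rate from a trust ratio, roughly of the form below (following the LARS paper; $\eta$ is the trust coefficient and $\beta$ the weight decay):

$$ \lambda^{(l)} = \eta \cdot \frac{\|w^{(l)}\|}{\|\nabla L(w^{(l)})\| + \beta \|w^{(l)}\|} $$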
  • Initialize SimCLR model and train it.
# It's a no-op if the 'gpu_index' argument is a negative integer or None.
with torch.cuda.device(args.gpu_index):
    # initialize SimCLR model
    simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
    # train SimCLR model
    simclr.train(train_loader)

2. Train SimCLR model

Code: simclr.py — train()

  • Initialize variables
class SimCLR(object):

    def __init__(self, *args, **kwargs):
        self.args = kwargs['args']
        self.model = kwargs['model'].to(self.args.device)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.writer = SummaryWriter()
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

    def train(self, train_loader):
        scaler = GradScaler(enabled=self.args.fp16_precision)

        # save config file
        save_config_file(self.writer.log_dir, self.args)

        n_iter = 0   # global optimization step
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        logging.info(f"Training with gpu: {self.args.disable_cuda}.")

        # ... (training loop described below) ...
        n_iter += 1
  • Iterate over epochs, and over train_loader within each epoch
for epoch_counter in range(self.args.epochs):
    for images, _ in tqdm(train_loader):
        images = torch.cat(images, dim=0)
        images = images.to(self.args.device)
  • Produce image representation features (z) from self.model
        with autocast(enabled=self.args.fp16_precision):
            # backbone model + MLP outputs
            features = self.model(images)
  • Calculate Info NCE loss from the features ($z = g(f(x))$)
            # calculate Info NCE loss
            logits, labels = self.info_nce_loss(features)
            loss = self.criterion(logits, labels)
  • Update networks f and g by minimizing the loss
        self.optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(self.optimizer)
        scaler.update()
  • Calculate top-1 and top-5 accuracy, and log the training loss and learning rate
        if n_iter % self.args.log_every_n_steps == 0:
            top1, top5 = accuracy(logits, labels, topk=(1, 5))
            self.writer.add_scalar('loss', loss, global_step=n_iter)
            self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
            self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
            self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
  • Warmup for 10 epochs

❌ CosineAnnealingLR without warmup

https://gaussian37.github.io/dl-pytorch-lr_scheduler/#cosineannealinglr-1

✔️ CosineAnnealingLR with warmup

https://www.researchgate.net/figure/Comparison-between-constant-lr-scheduler-and-cosine-annealing-lr-scheduler-with-linear_fig1_336936339

# warmup for the first 10 epochs
if epoch_counter >= 10:
    self.scheduler.step()
logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")
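For comparison, a minimal sketch of linear warmup followed by cosine decay using LambdaLR (a hypothetical helper, not part of the repo, which instead keeps the LR constant by skipping scheduler.step() for the first 10 epochs):

import math
import torch

def warmup_cosine_scheduler(optimizer, warmup_epochs=10, total_epochs=100):
    # linear warmup for warmup_epochs, then cosine decay to zero
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)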
  • Save checkpoints
# save model checkpoints
checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
save_checkpoint({
    'epoch': self.args.epochs,
    'arch': self.args.arch,
    'state_dict': self.model.state_dict(),
    'optimizer': self.optimizer.state_dict(),
}, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")

3. Calculate Info NCE loss

Code: simclr.py — info_nce_loss()

Image credit @Silva, Thalles Santos

  • Calculate cosine similarity

$$ \mathrm{sim}(u, v) = \frac{u^\top v}{\|u\| \, \|v\|} $$

  • Define indicator functions as matrices
💭 Reminder: Given a set $\{\tilde x_k\}$ including a positive pair of examples $\tilde x_i$ and $\tilde x_j$, the contrastive prediction task aims to identify $\tilde x_j$ in $\{\tilde x_k\}_{k \neq i}$ for a given $\tilde x_i$.
  • mask : evaluates to 1 iff k=i (1 for the same sample as the anchor image $\tilde x_i$, 0 for different samples)
    → we will use ~mask, which evaluates to 1 iff k≠i
  • labels : evaluates to 1 iff k=i (the anchor image itself) or k=j (the positive sample)
    → we will use labels[~mask], which evaluates to 1 iff k=j (see the worked example below)
def info_nce_loss(self, features):
    """
    features: FC layer output (z)
    return logits, labels
    """
    labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(self.args.device)

    # normalize so that the dot product equals cosine similarity
    features = F.normalize(features, dim=1)

    similarity_matrix = torch.matmul(features, features.T)
    # assert similarity_matrix.shape == (
    #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
    # assert similarity_matrix.shape == labels.shape
  • Discard diagonal elements
    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
    # assert similarity_matrix.shape == labels.shape
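A tiny worked example of the indicator matrices, assuming batch_size=2 and n_views=2 (4 samples ordered [x0, x1, x0', x1']):

import torch

batch_size, n_views = 2, 2
labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)])  # [0, 1, 0, 1]
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
# labels: 1 iff k comes from the same source image (k=i or k=j)
# [[1, 0, 1, 0],
#  [0, 1, 0, 1],
#  [1, 0, 1, 0],
#  [0, 1, 0, 1]]
mask = torch.eye(4, dtype=torch.bool)  # 1 iff k=i (the anchor itself)
print(labels[~mask].view(4, -1))       # 1 iff k=j (the positive of each anchor)
# [[0, 1, 0],
#  [0, 0, 1],
#  [1, 0, 0],
#  [0, 1, 0]]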
  • Select negatives and positives
    • logits : [positives, negatives]
    • label : [0, 0, ..., 0] indicates the index of the positive pair for each row (always column 0)
    • logits and labels are used to calculate CrossEntropyLoss

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
  • Explanation
    • They put the cos_sim values of the positive pairs in the first column of a matrix, and the cos_sim values of the negative pairs in the remaining columns. In this matrix you can consider each row as logits: the 0th element is the "right" class, so you can apply CE to each row and expect the model to pick the right "class" (the positive pair). That is why they have the zero-valued vector labels. — Github Issue 1, Github Issue 2
  • Temperature-scaled logits for contrastive loss
    • if $\tau < 1$ → sharpening ✔️
    • if $\tau > 1$ → smoothing

Image credit @Shixiang Gu, Categorical Reparameterization with Gumbel-Softmax

    # normalized temperature-scaled logits for CE loss
    logits = logits / self.args.temperature
    return logits, labels
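A quick numeric illustration of the temperature effect on the softmax (logits are made-up values):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])
for tau in (0.1, 1.0, 5.0):
    print(tau, F.softmax(logits / tau, dim=0))
# tau = 0.1 -> nearly one-hot (sharpened)
# tau = 5.0 -> nearly uniform (smoothed)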

4. Remove FC weights and bias

Code: mini_batch_logistic_regression_evaluator.ipynb

  • Remove the fc.weight and fc.bias entries from the checkpoint's state_dict.
  • The names and shapes of fc.weight and fc.bias still appear in model.named_parameters(), since only the checkpoint dict is modified, not the model architecture.
checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):
    if k.startswith('backbone.') and not k.startswith('backbone.fc'):
        # remove prefix, e.g. backbone.conv1 -> conv1
        state_dict[k[len("backbone."):]] = state_dict[k]
    # delete every original (prefixed) key, which also drops the fc weights
    del state_dict[k]
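A quick sanity check of the result (illustrative):

# every surviving key is a prefix-stripped backbone weight; no fc entries remain
assert not any(k.startswith('backbone.') for k in state_dict)
assert not any(k.startswith('fc.') for k in state_dict)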

Feature Evaluation (§2.3, §4.2)

Code: Feature Evaluation Code

1. Load data

def get_stl10_data_loaders(download, batch_size=256):
    train_dataset = datasets.STL10('./data', split='train', download=download,
                                   transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=10, drop_last=False, shuffle=True)

    test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                             num_workers=10, drop_last=False, shuffle=False)
    return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
    train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                     transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=10, drop_last=False, shuffle=True)

    test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                    transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                             num_workers=10, drop_last=False, shuffle=False)
    return train_loader, test_loader

if config.dataset_name == 'cifar10':
    train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config.dataset_name == 'stl10':
    train_loader, test_loader = get_stl10_data_loaders(download=True)

print("Dataset:", config.dataset_name)
  • Freeze all layers except for the last FC layer (note: model is loaded in step 2 below)
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

2. Load trained SimCLR backbone model

if config.arch == 'resnet18':
    model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == 'resnet50':
    model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)

checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):
    if k.startswith('backbone.') and not k.startswith('backbone.fc'):
        # remove prefix ex. backbone.layer -> layer
        state_dict[k[len("backbone."):]] = state_dict[k]
    del state_dict[k]

# load the stripped weights; only the new fc head remains randomly initialized
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']

3. Define evaluation metric: top-k accuracy

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
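A small usage sketch with made-up logits for two samples and three classes:

import torch

output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])
target = torch.tensor([1, 2])
print(accuracy(output, target, topk=(1, 2)))
# [tensor([50.]), tensor([50.])]: sample 0 is correct at top-1,
# while sample 1's target (class 2) is outside both top-1 and top-2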

4. Train FC layer and evaluate model

# define loss and optimizer for the linear head
# (assumed as in the evaluator notebook: Adam with lr=3e-4 and weight decay 0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)

epochs = 100

for epoch in range(epochs):
    top1_train_accuracy = 0

    for counter, (x_batch, y_batch) in enumerate(train_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)
        loss = criterion(logits, y_batch)

        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_train_accuracy += top1[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    top1_train_accuracy /= (counter + 1)
    top1_accuracy = 0
    top5_accuracy = 0

    for counter, (x_batch, y_batch) in enumerate(test_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)

        top1, top5 = accuracy(logits, y_batch, topk=(1, 5))
        top1_accuracy += top1[0]
        top5_accuracy += top5[0]

    top1_accuracy /= (counter + 1)
    top5_accuracy /= (counter + 1)
    print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

5. Reproduce results

arch='resnet18'
dataset_name='stl10'

 

Thank you for reading 😊


Q&A

Q) When selecting similarities, i.e. when picking the positive samples, are you saying the image labels are used?

A) When splitting the similarities into positives and negatives, we do not use the image labels; we select entries by multiplying the similarity matrix by the indicator (labels). The similarities used are computed as sim(img1, img1').

For example, given img1, sim(img1, img1') and sim(img1', img1) are selected as the positive samples.

  • ' denotes an augmented image, and sim denotes a similarity score.

 

Q) I understood that the positive is the sample itself and everything else is a negative. Is it instead the case that samples with the same label are classified as positives?

A) A positive pair is the sample itself (the given anchor image) together with its augmented image. Similarities between the original images themselves and between the augmented images themselves, sim(img1, img1) and sim(img1', img1'), are excluded when computing the loss.

→ The question above likely stems from the variable name labels chosen by the code author. The labels in the middle of the code is an indicator function; only the labels returned at the very end is a label in the conventional sense.

 

Q) When comparing representations of different images obtained with SimCLR, are there also measurement results showing that similar images get high similarity scores while completely different images get low ones?

A) One would have to probe intermediate results to verify this directly, but the Info NCE loss is computed and minimized precisely to make the similarity between representations of similar images high and the similarity between representations of different images low. Whether representation learning ultimately succeeded (i.e., whether labels are predicted well) can be seen from the top-k accuracy in the feature evaluation part.

 

Q) I have a question about the better performance obtained when contrastive learning is performed with the projection head in SimCLR. Would using the vector re-embedded once more through the additional MLP bring as much performance gain with deeper networks (ResNet-101, ResNet-152, etc.) as it does with ResNet-50?

A) Table 1 of the paper Big Self-Supervised Models are Strong Semi-Supervised Learners is a good reference. It shows that performance improves as larger models and deeper projection heads are used.
(Ref) https://arxiv.org/pdf/2006.10029.pdf

 

