If you’re planning to fine-tune a trained model on a different dataset, chances are you’re going to freeze some of the early layers and only update the later layers. I won’t go into the details of why you may want to freeze some layers and which ones should be frozen, but I’ll show you how to do it in PyTorch. Let’s get started!

We first need a pre-trained model to start with. The models subpackage in the torchvision package provides definitions for many of the popular model architectures for image classification. You can construct these models by simply calling their constructors, which initialize the models with random weights. To use the pre-trained weights from the PyTorch Model Zoo, call the constructor with the pretrained=True argument (torchvision 0.13 and later deprecate this argument in favor of weights, e.g. weights=models.VGG16_Weights.IMAGENET1K_V1). Let’s load the pretrained VGG16 model:

import torch
import torch.nn as nn
import torchvision.models as models

vgg16 = models.vgg16(pretrained=True)

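This will start downloading the pretrained weights into your computer’s PyTorch cache folder. In recent PyTorch versions this is the .cache/torch/hub/checkpoints folder under your home directory (older versions used .cache/torch/checkpoints), and the location can be changed with the TORCH_HOME environment variable.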
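If you are not sure where the weights ended up, torch.hub.get_dir() returns the hub directory PyTorch uses, and downloaded checkpoints are stored in its checkpoints subfolder. A minimal sketch, assuming PyTorch 1.6 or later (where torch.hub.get_dir() is available):

```python
import os

import torch

# torch.hub.get_dir() honours the TORCH_HOME environment variable;
# downloaded model weights are stored in its "checkpoints" subfolder.
checkpoint_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
print(checkpoint_dir)
```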

There are multiple ways to look into the model and see its modules and layers. One way is the .modules() member function, which returns an iterator over all the modules of the model, starting with the model itself. The .modules() function recursively goes through all the modules and submodules of the model:

print(list(vgg16.modules()))
[VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
), Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (18): ReLU(inplace=True)
  (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU(inplace=True)
  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU(inplace=True)
  (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (25): ReLU(inplace=True)
  (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (27): ReLU(inplace=True)
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU(inplace=True)
  (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
), Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), AdaptiveAvgPool2d(output_size=(7, 7)), Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
), Linear(in_features=25088, out_features=4096, bias=True), ReLU(inplace=True), Dropout(p=0.5, inplace=False), Linear(in_features=4096, out_features=4096, bias=True), ReLU(inplace=True), Dropout(p=0.5, inplace=False), Linear(in_features=4096, out_features=1000, bias=True)]

That’s a lot of information spewed out onto the screen! Let’s use the .named_modules() function instead, which yields (name, module) tuples, and print only the names (the model itself comes first, under an empty name):

for (name, module) in vgg16.named_modules():
    print(name)
features
features.0
features.1
features.2
features.3
features.4
features.5
features.6
features.7
features.8
features.9
features.10
features.11
features.12
features.13
features.14
features.15
features.16
features.17
features.18
features.19
features.20
features.21
features.22
features.23
features.24
features.25
features.26
features.27
features.28
features.29
features.30
avgpool
classifier
classifier.0
classifier.1
classifier.2
classifier.3
classifier.4
classifier.5
classifier.6

That’s much better! We can see that the top-level modules are features, avgpool and classifier. We can also see that the features and classifier modules consist of 31 and 7 layers, respectively. These layers have no descriptive names; they are only identified by their index within the module. If you want an even more concise view of the network, you can use the .named_children() function, which lists only the immediate child modules instead of recursing into them:

for (name, module) in vgg16.named_children():
    print(name)
features
avgpool
classifier

Now let’s see which layers are under the features module. Here we use the .children() function to get the layers of the features module, since we don't need their names (which are just their indices):

for (name, module) in vgg16.named_children():
    if name == 'features':
        for layer in module.children():
            print(layer)
Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

We can go even deeper and look at the parameters of each layer. Let’s print the parameters (weight and bias) of the first layer in the features module:

for (name, module) in vgg16.named_children():
    if name == 'features':
        for layer in module.children():
            for param in layer.parameters():
                print(param)
            break
Parameter containing:
tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],
          [-5.8312e-01,  3.5655e-01,  7.6566e-01],
          [-6.9022e-01, -4.8019e-02,  4.8409e-01]],

          [[ 1.7548e-01,  9.8630e-03, -8.1413e-02],
          [ 4.4089e-02, -7.0323e-02, -2.6035e-01],
          [ 1.3239e-01, -1.7279e-01, -1.3226e-01]],

          [[ 3.1303e-01, -1.6591e-01, -4.2752e-01],
          [ 4.7519e-01, -8.2677e-02, -4.8700e-01],
          [ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],


        [[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],
          [-4.2805e-01, -2.4349e-01,  2.4628e-01],
          [-2.5066e-01,  1.4177e-01, -5.4864e-03]],

          [[-1.4076e-01, -2.1903e-01,  1.5041e-01],
          [-8.4127e-01, -3.5176e-01,  5.6398e-01],
          [-2.4194e-01,  5.1928e-01,  5.3915e-01]],

          [[-3.1432e-01, -3.7048e-01, -1.3094e-01],
          [-4.7144e-01, -1.5503e-01,  3.4589e-01],
          [ 5.4384e-02,  5.8683e-01,  4.9580e-01]]],


        [[[ 1.7715e-01,  5.2149e-01,  9.8740e-03],
          [-2.7185e-01, -7.1709e-01,  3.1292e-01],
          [-7.5753e-02, -2.2079e-01,  3.3455e-01]],

          [[ 3.0924e-01,  6.7071e-01,  2.0546e-02],
          [-4.6607e-01, -1.0697e+00,  3.3501e-01],
          [-8.0284e-02, -3.0522e-01,  5.4460e-01]],

          [[ 3.1572e-01,  4.2335e-01, -3.4976e-01],
          [ 8.6354e-02, -4.6457e-01,  1.1803e-02],
          [ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],


        ...,


        [[[ 7.7599e-02,  1.2692e-01,  3.2305e-02],
          [ 2.2131e-01,  2.4681e-01, -4.6637e-02],
          [ 4.6407e-02,  2.8246e-02,  1.7528e-02]],

          [[-1.8327e-01, -6.7425e-02, -7.2120e-03],
          [-4.8855e-02,  7.0427e-03, -1.2883e-01],
          [-6.4601e-02, -6.4566e-02,  4.4235e-02]],

          [[-2.2547e-01, -1.1931e-01, -2.3425e-02],
          [-9.9171e-02, -1.5143e-02,  9.5385e-04],
          [-2.6137e-02,  1.3567e-03,  1.4282e-01]]],


        [[[ 1.6520e-02, -3.2225e-02, -3.8450e-03],
          [-6.8206e-02, -1.9445e-01, -1.4166e-01],
          [-6.9528e-02, -1.8340e-01, -1.7422e-01]],

          [[ 4.2781e-02, -6.7529e-02, -7.0309e-03],
          [ 1.1765e-02, -1.4958e-01, -1.2361e-01],
          [ 1.0205e-02, -1.0393e-01, -1.1742e-01]],

          [[ 1.2661e-01,  8.5046e-02,  1.3066e-01],
          [ 1.7585e-01,  1.1288e-01,  1.1937e-01],
          [ 1.4656e-01,  9.8892e-02,  1.0348e-01]]],


        [[[ 3.2176e-02, -1.0766e-01, -2.6388e-01],
          [ 2.7957e-01, -3.7416e-02, -2.5471e-01],
          [ 3.4872e-01,  3.0041e-02, -5.5898e-02]],

          [[ 2.5063e-01,  1.5543e-01, -1.7432e-01],
          [ 3.9255e-01,  3.2306e-02, -3.5191e-01],
          [ 1.9299e-01, -1.9898e-01, -2.9713e-01]],

          [[ 4.6032e-01,  4.3399e-01,  2.8352e-01],
          [ 1.6341e-01, -5.8165e-02, -1.9196e-01],
          [-1.9521e-01, -4.5630e-01, -4.2732e-01]]]], requires_grad=True)
Parameter containing:
tensor([ 0.4034,  0.3778,  0.4644, -0.3228,  0.3940, -0.3953,  0.3951, -0.5496,
          0.2693, -0.7602, -0.3508,  0.2334, -1.3239, -0.1694,  0.3938, -0.1026,
          0.0460, -0.6995,  0.1549,  0.5628,  0.3011,  0.3425,  0.1073,  0.4651,
          0.1295,  0.0788, -0.0492, -0.5638,  0.1465, -0.3890, -0.0715,  0.0649,
          0.2768,  0.3279,  0.5682, -1.2640, -0.8368, -0.9485,  0.1358,  0.2727,
          0.1841, -0.5325,  0.3507, -0.0827, -1.0248, -0.6912, -0.7711,  0.2612,
          0.4033, -0.4802, -0.3066,  0.5807, -1.3325,  0.4844, -0.8160,  0.2386,
          0.2300,  0.4979,  0.5553,  0.5230, -0.2182,  0.0117, -0.5516,  0.2108],
        requires_grad=True)

Now that we have access to all the modules, layers and their parameters, we can freeze them by setting each parameter’s requires_grad flag to False. This stops autograd from computing gradients for these parameters during the backward pass, so their .grad stays None and the optimizer leaves them unchanged.
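To see this on something small, the sketch below uses a hypothetical toy network (not VGG16) and freezes its first layer. After a backward pass, the frozen parameters have no gradient, while the trainable ones do:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A hypothetical toy model standing in for a real network.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Freeze the first Linear layer.
for param in model[0].parameters():
    param.requires_grad = False

loss = model(torch.randn(8, 4)).sum()
loss.backward()

print(model[0].weight.grad)          # None: no gradient was computed
print(model[2].weight.grad is None)  # False: this layer is still trainable
```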

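Now let’s freeze all the parameters in the features module. Note that the ReLU and MaxPool2d layers have no parameters, so for them the loop only prints the message: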

for (name, module) in vgg16.named_children():
    if name == 'features':
        for layer_counter, layer in enumerate(module.children()):
            for param in layer.parameters():
                param.requires_grad = False

            print('Layer "{}" in module "{}" was frozen!'.format(layer_counter, name))
Layer "0" in module "features" was frozen!
Layer "1" in module "features" was frozen!
Layer "2" in module "features" was frozen!
Layer "3" in module "features" was frozen!
Layer "4" in module "features" was frozen!
Layer "5" in module "features" was frozen!
Layer "6" in module "features" was frozen!
Layer "7" in module "features" was frozen!
Layer "8" in module "features" was frozen!
Layer "9" in module "features" was frozen!
Layer "10" in module "features" was frozen!
Layer "11" in module "features" was frozen!
Layer "12" in module "features" was frozen!
Layer "13" in module "features" was frozen!
Layer "14" in module "features" was frozen!
Layer "15" in module "features" was frozen!
Layer "16" in module "features" was frozen!
Layer "17" in module "features" was frozen!
Layer "18" in module "features" was frozen!
Layer "19" in module "features" was frozen!
Layer "20" in module "features" was frozen!
Layer "21" in module "features" was frozen!
Layer "22" in module "features" was frozen!
Layer "23" in module "features" was frozen!
Layer "24" in module "features" was frozen!
Layer "25" in module "features" was frozen!
Layer "26" in module "features" was frozen!
Layer "27" in module "features" was frozen!
Layer "28" in module "features" was frozen!
Layer "29" in module "features" was frozen!
Layer "30" in module "features" was frozen!

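Now that some of the parameters are frozen, it’s a good idea to pass only the parameters with requires_grad=True to the optimizer. The built-in optimizers already skip parameters whose .grad is None, but filtering makes the intent explicit. We can do this with a lambda function when constructing the optimizer: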

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, vgg16.parameters()), lr=0.001)
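To check that frozen parameters really stay put, the sketch below applies the same filtering pattern to a hypothetical toy model (kept small so it runs instantly) and takes one training step on random data. The frozen weights are unchanged afterwards, while the trainable ones move:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A hypothetical toy model; the first Linear layer plays the role of "features".
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
for param in model[0].parameters():
    param.requires_grad = False

# Same filtering pattern as above: only trainable parameters reach the optimizer.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Snapshot one frozen and one trainable parameter before the update.
frozen_before = model[0].weight.detach().clone()
trainable_before = model[2].bias.detach().clone()

# One training step on random inputs and labels.
inputs = torch.randn(16, 4)
targets = torch.randint(0, 3, (16,))
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()

print(torch.equal(model[0].weight, frozen_before))   # True
print(torch.equal(model[2].bias, trainable_before))  # False
```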

You can now start training your partially frozen model!