PyTorch DistributedDataParallel Example In Azure ML - Multi-Node Multi-GPU Distributed Training

In this post, we will discuss how to leverage PyTorch's DistributedDataParallel (DDP) implementation to run distributed training in Azure Machine Learning using the Python SDK.

A number of steps need to be performed to transform single-process model training into distributed training with DistributedDataParallel. These steps include assigning GPU devices, initializing a process group, creating a DistributedSampler, setting up the model, and more.

Note that DistributedDataParallel is only one of the options; other ways of implementing distributed training in Azure ML are described in the Distributed GPU training guide.

Contents:

- Overview
- Prerequisites
- Training Script
- Submit Job Script
- Helper Functions

Overview

The code in this post is mainly based on the cifar-distributed example referenced in the documentation. The original code has been modified, refactored, and enriched with explanations and links.

Our example consists of the following three files located in the same directory:

- train.py - the training script
- submit_job.py - the script that configures and submits the Azure ML job
- utils.py - helper functions used by the training script

Each of these files is discussed in its own section: training script, submit job script, and helper functions.

NOTE: Training script code can be used in any environment, not only in Azure ML.

Prerequisites

To follow and run the code presented in this post, we'll need to have a few things set up:

- an Azure Machine Learning workspace (referred to as aml-contoso in the code below);
- a compute cluster with GPU nodes in that workspace (referred to as train-cluster in the submit script);
- the azureml-core Python package installed in our local environment to configure and submit the job.

Now, given we have everything set up, let’s get started!

Training Script

The training script code in this section is common to any DistributedDataParallel implementation and is not specific to Azure Machine Learning, so we can use it when running training on a local machine or any other platform.
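
For example, to try the script locally (outside Azure ML), one option is PyTorch's torchrun launcher, which sets the same WORLD_SIZE, RANK, and LOCAL_RANK environment variables that we read in subsection B. A minimal sketch, assuming a single machine with 2 GPUs and the argument values used later in this post:

torchrun --nproc_per_node=2 train.py --epochs 20 --batch_size 32 --learning_rate 0.001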

In subsections A-G, we will discuss the parts of the training setup that are specific to distributed training and thus differ from single-process/single-GPU training. These are the things you might need to update in your non-distributed training script to make it work with DistributedDataParallel.

The full code for the training script is available at the end of the section: train.py.

A. Specify GPU To Use

Since our nodes can have multiple GPUs, we need to let each process know which GPU to use; otherwise, they will all use the same GPU, which is not desirable.

The CUDA_VISIBLE_DEVICES environment variable controls which GPUs CUDA will see. In the code snippet below, we specify that CUDA in this process will see only the single GPU whose index equals LOCAL_RANK. This works because GPUs are numbered starting from zero, the same way as local ranks.

IMPORTANT: CUDA_VISIBLE_DEVICES must be set before importing the torch module, otherwise, the value set in the environment variable won’t be respected.

NOTE: Alternatively, we could specify the GPU when creating the device variable discussed in subsection B: torch.device('cuda', local_rank).

import os
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['LOCAL_RANK']
import torch

B. Prepare For Distributed Training

Azure Machine Learning will launch the processes and set the following environment variables for each of them:

- WORLD_SIZE - the total number of processes participating in the training;
- RANK - the global rank of the current process, from 0 to WORLD_SIZE - 1;
- LOCAL_RANK - the rank of the process within its node;
- MASTER_ADDR and MASTER_PORT - the address and port of the master node, used by init_process_group to establish communication.

Notes regarding the code snippet below:

- is_distributed is True only when more than one process participates in the training, so the same script also works for a plain single-process run;
- the nccl backend is the recommended backend for distributed training on CUDA GPUs;
- init_process_group picks up the rank, world size, and master address/port from the environment variables set by Azure Machine Learning.

world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
is_distributed = world_size > 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if is_distributed:
    torch.distributed.init_process_group(backend="nccl")
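
To make these variables concrete, here is how they would be populated for the 2-node, 4-process job that we configure later in this post (an illustration assuming 2 GPUs per node):

# Node 0, process 0: WORLD_SIZE=4, RANK=0, LOCAL_RANK=0
# Node 0, process 1: WORLD_SIZE=4, RANK=1, LOCAL_RANK=1
# Node 1, process 0: WORLD_SIZE=4, RANK=2, LOCAL_RANK=0
# Node 1, process 1: WORLD_SIZE=4, RANK=3, LOCAL_RANK=1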

C. Perform Certain Tasks Only In Specific Processes

Some tasks don't need to run in every process; it is enough to execute them in just one or a few of the processes. Typical examples are metrics reporting, data preparation, and model export.

Usually, the following two types of processes are used:

- the main process - the process with rank = 0; there is only one such process in the whole job;
- local main processes - the processes with local_rank = 0; there is one such process on each node.

For example, in our case only one process on each node needs to download training data. This can be achieved by performing the task only in processes with local_rank = 0.

HINT: If you want all other processes to wait until the training data is downloaded before they continue, torch.distributed.barrier can provide this synchronization. See the example below.

if local_rank == 0: # Only one process to download training data
    train_set = load_train_set('./data', download=True)
if is_distributed: # All processes continue only after all of them reach this point
    torch.distributed.barrier()
if local_rank != 0: # Other processes just load data which is already downloaded
    train_set = load_train_set('./data', download=False)

Other good examples are test set evaluation and model export, which in our case are done only by the main process. The following code is executed at the end of the training; see train.py for the full code.

if rank == 0:
    test_set = load_test_set('./data', True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False)
    evaluate(model, test_loader, device)
    save_model(model)
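
Optionally, once training and evaluation are finished, the process group can be shut down explicitly. This is a small, optional addition (the script in this post also works without it, since the processes exit right afterwards):

if is_distributed:
    torch.distributed.destroy_process_group()  # clean up the process group before exiting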

D. Create Distributed Sampler and Data Loader

Presumably we have a lot of training data, since we decided to use data parallelism for our training. But how does each process know which subset of the data to use? We don't want every process to work with the entire training dataset.

DistributedSampler (the first line of the snippet below) lets each process load only the subset of the training data that is assigned exclusively to it. In other words, the distributed sampler makes sure that the data is split evenly between the training processes and that this assignment changes between epochs (see subsection G).

NOTE: In the example below, each DDP process works on batch_size training samples at a time, so a single step processes batch_size * world_size samples in total (e.g. 32 * 4 = 128 with the job configuration used later in this post).

train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) if is_distributed else None
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    sampler=train_sampler)

E. Initialize Model Using DistributedDataParallel

As usual, we first initialize our model and move it to the corresponding GPU device.

Next, if doing distributed training, we wrap our model with the DistributedDataParallel class. Additionally, we specify the GPU device of the model (device_ids) and of its output (output_device).

Note that the initial model (created on the first line of the snippet) will be accessible as model.module after wrapping with DDP.

model = get_model().to(device)
if is_distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)
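
The model.module attribute is useful, for example, if you want to save the trained weights without the DDP wrapper (the wrapped model's state_dict prefixes every key with "module."). Below is a minimal sketch of unwrapping before saving, shown only as an illustration and not part of the helper functions used in this post:

# Save the underlying module so state_dict keys don't get the "module." prefix added by DDP
model_to_save = model.module if is_distributed else model
torch.save(model_to_save.state_dict(), 'model.pt')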

F. Set Learning Rate and Optimizer

This is a standard step for any training. Here, we instantiate the SGD optimizer and the CrossEntropyLoss loss function, which are commonly used.

One thing worth mentioning is the learning rate. Since DistributedDataParallel averages gradients across processes, some people suggest that the learning rate should be scaled by world_size.

However, the PyTorch documentation contains a note about gradients saying that in most cases we can treat DDP and non-DDP models the same, i.e. use the same learning rate for the same batch size. See this PyTorch forum thread for more discussion.

# learning_rate = args.learning_rate * world_size
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
criterion = torch.nn.CrossEntropyLoss()

G. Update Distributed Sampler On Each Epoch

Remember that we created a distributed sampler in subsection D to divide the training data between DDP processes? Additionally, we want to reshuffle the training data so that on each epoch a process gets a different set of samples.

DistributedSampler’s set_epoch method updates the random seed used to shuffle the sampler. Simply invoke this method at the beginning of each epoch.

for epoch in range(args.epochs):
    if is_distributed:
        train_sampler.set_epoch(epoch)
    train_model(model, train_loader, criterion, optimizer, device)

Code: train.py

Here is the full code of train.py with comments.

# =============== train.py ===============

# A. Specify GPU To Use
import os
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['LOCAL_RANK']
import torch
import argparse

# These functions are implemented in Helper Functions section of this post
from utils import (
    load_train_set,
    load_test_set,
    get_model,
    train_model,
    evaluate,
    save_model
)

# These arguments are passed when submitting the job using submit_job.py
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--learning_rate', type=float)
    return parser.parse_args()

def main(args):
    # B. Prepare For Distributed Training
    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    is_distributed = world_size > 1
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if is_distributed:
        torch.distributed.init_process_group(backend="nccl")

    # C. Perform Certain Tasks Only In Specific Processes
    if local_rank == 0:
        train_set = load_train_set('./data', download=True)
    if is_distributed:
        torch.distributed.barrier()
    if local_rank != 0:
        train_set = load_train_set('./data', download=False)

    # D. Create Distributed Sampler and Data Loader
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) if is_distributed else None
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler)
    
    # E. Initialize Model Using DistributedDataParallel
    model = get_model().to(device)
    if is_distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)
    
    # F. Set Learning Rate and Optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    
    # G. Update Distributed Sampler On Each Epoch
    for epoch in range(args.epochs):
        if is_distributed:
            train_sampler.set_epoch(epoch)
        train_model(model, train_loader, criterion, optimizer, device)
    
    # C. Perform Certain Tasks Only In Specific Processes
    # Evaluate and save the model only in the main process (with rank 0)
    # Note that it is also possible to perform evaluation using multiple processes in parallel if needed
    if rank == 0:
        test_set = load_test_set('./data', True)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False)
        evaluate(model, test_loader, device)
        save_model(model)

if __name__ == '__main__':
    args = parse_args()
    main(args)

Submit Job Script

We configure and submit our Azure Machine Learning job using the submit_job.py script. To run it, we need to have the azureml-core Python package installed in our local environment, as was already mentioned in the prerequisites section.

When the job is submitted, Azure Machine Learning will do the following:

- allocate (and, if needed, scale up) the requested number of nodes on the compute cluster;
- prepare the environment on each node (the curated PyTorch environment in our case);
- copy the contents of source_directory to the nodes;
- start the requested number of processes and set the WORLD_SIZE, RANK, and LOCAL_RANK environment variables for each of them;
- run train.py with the specified arguments in every process.

The full code is available later at submit_job.py. In the next subsections, we will break down the script into parts and discuss each of them in detail.

To run the script, use the simple command shown below. Note that you might also need to authenticate to Azure first, e.g. using the az login command.

python submit_job.py

1. Get Workspace and Environment

First, we use the Workspace.get method to retrieve our existing workspace by name, subscription ID, and resource group. Alternatively, we could create a config file and load it with the Workspace.from_config method (a sketch of this is shown after the snippet below).

Next, we get a curated Azure ML environment using the Environment.get method. You can also create custom environments if needed; see environments.

ws = Workspace.get('aml-contoso', subscription_id='075f60ad-3c3c-4e38-b796-20aa693e6c94', resource_group='rg-azureml')
env = Environment.get(ws, 'AzureML-pytorch-1.10-ubuntu18.04-py38-cuda11-gpu')
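
If you prefer the config-file approach mentioned above, a minimal sketch could look like this (the config.json values are placeholders for your own workspace details):

# config.json placed next to submit_job.py (or under a .azureml/ folder):
# {
#     "subscription_id": "<subscription-id>",
#     "resource_group": "rg-azureml",
#     "workspace_name": "aml-contoso"
# }
from azureml.core import Workspace
ws = Workspace.from_config()  # searches the current and parent directories for config.json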

2. Configure Number of Processes and Nodes

To run distributed training, we specify how many nodes and processes our job should use; this is done with the PyTorchConfiguration class. We pass the following parameters:

- process_count - the total number of processes across all nodes, normally equal to the total number of GPUs (4 in our case);
- node_count - the number of nodes to use (2 in our case, i.e. 2 processes per node).

pytorch_config = PyTorchConfiguration(process_count=4, node_count=2)

3. Configure Training Job

Here, we create an instance of the ScriptRunConfig class which contains information about the job:

- source_directory - the local directory with our code; its contents are copied to the compute nodes;
- script - the training script to run, relative to source_directory;
- arguments - the command-line arguments passed to the training script;
- compute_target - the name of the compute cluster to run on;
- environment - the environment obtained in step 1;
- distributed_job_config - the PyTorchConfiguration created in step 2.

arguments = [
    '--epochs', 20,
    '--batch_size', 32,
    '--learning_rate', 0.001,
]
src_config = ScriptRunConfig(
    source_directory='.',
    script='train.py',
    arguments=arguments,
    compute_target='train-cluster',
    environment=env,
    distributed_job_config=pytorch_config,
)

4. Submit Job

The last step is to specify a name for the experiment. The experiment name is a way to group jobs, so jobs with the same experiment name appear together. There is no need to create the experiment upfront; it will be created on the fly if it doesn't already exist.

The Experiment.submit method returns an object of the Run class, which can be used to monitor and manage our run. In the example below, we just print the URL of the run in Azure Machine Learning.

exp = Experiment(ws, 'ddp-example')
run = exp.submit(src_config)
print(run.get_portal_url())
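
If you also want the submit script to block until the job finishes and stream its logs to the console, the Run object provides a wait_for_completion method; an optional addition could be:

run.wait_for_completion(show_output=True)  # stream logs and wait for the job to finish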

Code: submit_job.py

Please find the full code of the submit_job.py file below.

# =============== submit_job.py ===============

from azureml.core import Workspace
from azureml.core import ScriptRunConfig, Experiment, Environment
from azureml.core.runconfig import PyTorchConfiguration

def main():
    # 1. Get Workspace and Environment
    ws = Workspace.get('aml-contoso', subscription_id='075f60ad-3c3c-4e38-b796-20aa693e6c94', resource_group='rg-azureml')
    env = Environment.get(ws, 'AzureML-pytorch-1.10-ubuntu18.04-py38-cuda11-gpu')

    # 2. Configure Number of Processes and Nodes
    pytorch_config = PyTorchConfiguration(process_count=4, node_count=2)

    # 3. Configure Training Script Run
    arguments = [
        '--epochs', 20,
        '--batch_size', 32,
        '--learning_rate', 0.001,
    ]
    src_config = ScriptRunConfig(
        source_directory='.',
        script='train.py',
        arguments=arguments,
        compute_target='train-cluster',
        environment=env,
        distributed_job_config=pytorch_config,
    )

    # 4. Submit Job
    exp = Experiment(ws, 'ddp-example')
    run = exp.submit(src_config)
    print(run.get_portal_url())

if __name__ == '__main__':
    main()

Helper Functions

At the beginning of the train.py script, we imported several functions that are used throughout the training. Their implementation is very similar to the cifar-distributed example and is listed below in utils.py.

The functions implemented in utils.py are not important for the discussion of DistributedDataParallel, hence we extracted them into a separate Python module to keep the training script clean and easy to understand.

Code: utils.py

Here is the full code of our helper functions for model training.

# =============== utils.py ===============

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.fc1 = nn.Linear(128 * 6 * 6, 120)
        self.dropout = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 6 * 6)
        x = self.dropout(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def load_train_set(location, download):
    return torchvision.datasets.CIFAR10(root=location, train=True, download=download, transform=_get_transform())

def load_test_set(location, download):
    return torchvision.datasets.CIFAR10(root=location, train=False, download=download, transform=_get_transform())

def _get_transform():
    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

def get_model():
    return Net()

def train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    accuracy = correct / total
    print(f"Accuracy: {accuracy:.3f}")

def save_model(model):
    torch.save(model.state_dict(), 'model.pt')