PyTorch on Aurora

PyTorch is a popular, open-source deep learning framework developed and released by Facebook. For more information, refer to the PyTorch home page. For troubleshooting on Aurora, please contact support@alcf.anl.gov.

Major changes in the frameworks module of Spring 2026 (frameworks/2025.3.1)

  • The torch_ccl module has been removed. import oneccl_bindings_for_pytorch as torch_ccl is no longer needed.
  • When initializing torch.distributed, the backend must be changed from ccl to xccl.
  • import intel_extension_for_pytorch as ipex is now deprecated. The vendor is upstreaming all of the functionality from IPEX to the mainline PyTorch distribution. If you experience performance variations after removing the import, please switch back to importing it.
  • horovod support for PyTorch has been removed.
  • ONEAPI_DEVICE_SELECTOR is now set to "opencl:gpu;level_zero:gpu". If this causes any issues, please revert to Level Zero only with export ONEAPI_DEVICE_SELECTOR="level_zero:gpu"

Provided Installation

PyTorch is already installed on Aurora with GPU support and available through the frameworks module. To use it from a compute node, please load the following modules:

module load frameworks

Then, you can import PyTorch in Python as usual (below showing results from the frameworks/2025.3.1 module):

>>> import torch
>>> torch.__version__
'2.10.0a0+git449b176'

A simple but useful check is to use PyTorch to query device information on a compute node. You can do this the following way:

get-device-info.py
import torch

print(f"GPU availability: {torch.xpu.is_available()}")
print(f'Number of tiles = {torch.xpu.device_count()}')
current_tile = torch.xpu.current_device()
print(f'Current tile = {current_tile}')
print(f'Current device ID = {torch.xpu.device(current_tile)}')
print(f'Device properties = {torch.xpu.get_device_properties()}')
Example output:
GPU availability: True
Number of tiles = 12
Current tile = 0
Current device ID = <torch.xpu.device object at 0x154c8fad4d40>
Device properties = _XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', device_id=0xBD6, uuid=d20ebf0c-4ca0-6be7-0000-000000000001, driver_version='1.6.33578+42', total_memory=65520MB, max_compute_units=448, gpu_eu_count=448, gpu_subslice_count=56, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)

Tile-as-device setting for AI/ML workloads

Each Aurora node has 6 GPUs (also called "Devices" or "cards") and each GPU is composed of two tiles (also called "Sub-device"). By default, the frameworks module sets ZE_FLAT_DEVICE_HIERARCHY=FLAT, meaning that the 12 PVC tiles are exposed as devices (see more details on the Python page). This is the recommended setting for AI/ML workloads.

Using the entire PVC GPU as PyTorch devices

By default, each tile is mapped to one PyTorch device, giving a total of 12 devices per node, as seen above. To map a PyTorch device to an entire PVC GPU out of the 6 available on a compute node, set

export ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE

and mask the devices with

# To mask entire PVC GPUs
export ZE_AFFINITY_MASK=0,1

# or to mask particular tiles only (use syntax `Device.Sub-device`)
export ZE_AFFINITY_MASK=0.0,1.0

You can check that each PyTorch device is now mapped to one GPU with:

module load frameworks
ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE ZE_AFFINITY_MASK=0 python test_affinity.py
test_affinity.py
import torch
print(torch.xpu.device_count())
print(torch.xpu.get_device_properties())
Example output
_XpuDeviceProperties(name='Intel(R) Data Center GPU Max 1550', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', device_id=0xBD6, uuid=d20ebf0c-4ca0-6be7-0000-000000000000, driver_version='1.6.33578+42', total_memory=131040MB, max_compute_units=896, gpu_eu_count=896, gpu_subslice_count=112, max_work_group_size=1024, max_num_sub_groups=64, sub_group_sizes=[16 32], has_fp16=1, has_fp64=1, has_atomic64=1)

More information and details are available through the Level Zero Specification Documentation - Affinity Mask

Code changes to run PyTorch on Aurora GPUs

Here we list some common changes that you may need to make to your PyTorch code in order to use Intel GPUs.

  1. All API calls involving torch.cuda should be replaced with torch.xpu. For example:
    - torch.cuda.device_count()
    + torch.xpu.device_count()
    
  2. When moving tensors and models to the GPU, replace "cuda" with "xpu". For example:
    - model = model.to("cuda")
    + model = model.to("xpu")
    

Tip

A more portable solution to select the appropriate device is the following:

if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.xpu.is_available():
    device = torch.device('xpu')
else: 
    device = torch.device('cpu')
model = model.to(device)

Example: training a PyTorch model on a single GPU tile

Here is a simple script to train a dummy PyTorch model on the CPU:

pytorch_cpu.py
import torch

torch.manual_seed(0)

src = torch.rand((2048, 1, 512))
tgt = torch.rand((2048, 20, 512))
dataset = torch.utils.data.TensorDataset(src, tgt)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

model = torch.nn.Transformer(batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
model.train()

for epoch in range(10):
    for source, targets in loader:
        optimizer.zero_grad()

        output = model(source, targets)
        loss = criterion(output, targets)

        loss.backward()
        optimizer.step()

And here is the code to train the same model on a single GPU tile on Aurora, with new or modified lines highlighted:

pytorch_xpu.py
import torch
device = torch.device('xpu')

torch.manual_seed(0)

src = torch.rand((2048, 1, 512))
tgt = torch.rand((2048, 20, 512))
dataset = torch.utils.data.TensorDataset(src, tgt)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

model = torch.nn.Transformer(batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
model.train()
model = model.to(device)
criterion = criterion.to(device)

for epoch in range(10):
    for source, targets in loader:
        source = source.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        output = model(source, targets)
        loss = criterion(output, targets)

        loss.backward()
        optimizer.step()

PyTorch Best Practices on Aurora

When running PyTorch applications, we have found the following practices to be generally, if not universally, useful and encourage you to try some of these techniques to boost performance of your own applications.

  1. Use Reduced Precision. Reduced Precision is available on Intel Max 1550 and is supported with PyTorch operations. In general, the way to do this is via the PyTorch Automatic Mixed Precision package (AMP), as described in the mixed precision documentation. In PyTorch, users generally need to manage casting and loss scaling manually, though context managers and function decorators can provide easy tools to do this.

  2. PyTorch has a JIT module as well as backends to support op fusion, similar to TensorFlow's tf.function tools. See TorchScript for more information.

  3. torch.compile is available for Intel Max 1550 GPU and can be used to speed up training and inference. See PyTorch Docs for more information.

  4. For convolutional neural networks, using channels_last (NHWC) memory format gives better performance. More info here and here

Distributed Training on multiple GPUs

Distributed training with PyTorch on Aurora is facilitated through PyTorch's Distributed Data Parallel (DDP). Horovod is no longer supported in recent frameworks modules.

Distributed Data Parallel (DDP)

Code changes to train on multiple GPUs using DDP

The key steps in performing distributed training are:

  1. Initialize PyTorch's distributed process group with backend='xccl'
  2. Use DistributedSampler to partition the training data among the ranks
  3. Pin each rank to a GPU
  4. Wrap the model in DDP to keep it in sync across the ranks
  5. Rescale the learning rate
  6. Use set_epoch for shuffling data across epochs

Here is the code to train the same dummy PyTorch model on multiple GPUs, where new or modified lines have been highlighted:

pytorch_ddp.py
from mpi4py import MPI
import os, socket
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# DDP: Set environment variables used by PyTorch
SIZE = MPI.COMM_WORLD.Get_size()
RANK = MPI.COMM_WORLD.Get_rank()
LOCAL_RANK = os.environ.get('PALS_LOCAL_RANKID')
os.environ['RANK'] = str(RANK)
os.environ['WORLD_SIZE'] = str(SIZE)
MASTER_ADDR = socket.gethostname() if RANK == 0 else None
MASTER_ADDR = MPI.COMM_WORLD.bcast(MASTER_ADDR, root=0)
os.environ['MASTER_ADDR'] = f"{MASTER_ADDR}.hsn.cm.aurora.alcf.anl.gov"
os.environ['MASTER_PORT'] = str(2345)
print(f"DDP: Hi from rank {RANK} of {SIZE} with local rank {LOCAL_RANK}. {MASTER_ADDR}")

# DDP: initialize distributed communication with xccl backend
torch.distributed.init_process_group(backend='xccl', init_method='env://', rank=int(RANK), world_size=int(SIZE))

# DDP: pin GPU to local rank.
torch.xpu.set_device(int(LOCAL_RANK))
device = torch.device('xpu')
torch.manual_seed(0)

src = torch.rand((2048, 1, 512))
tgt = torch.rand((2048, 20, 512))
dataset = torch.utils.data.TensorDataset(src, tgt)
# DDP: use DistributedSampler to partition the training data
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True, num_replicas=SIZE, rank=RANK, seed=0)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=32)

model = torch.nn.Transformer(batch_first=True)
# DDP: scale learning rate by the number of GPUs.
optimizer = torch.optim.Adam(model.parameters(), lr=(0.001*SIZE))
criterion = torch.nn.CrossEntropyLoss()
model.train()
model = model.to(device)
criterion = criterion.to(device)
# DDP: wrap the model in DDP
model = DDP(model)

for epoch in range(10):
    # DDP: set epoch to sampler for shuffling
    sampler.set_epoch(epoch)

    for source, targets in loader:
        source = source.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        output = model(source, targets)
        loss = criterion(output, targets)

        loss.backward()
        optimizer.step()

# DDP: cleanup
torch.distributed.destroy_process_group()

CPU bindings for best performance on Aurora

For good performance, it is important to set the appropriate CPU affinity when launching training scripts with mpiexec. When using all 12 PVC tiles on each node, the following setting is recommended:

export CPU_BIND="verbose,list:4-7:8-11:12-15:16-19:20-23:24-27:56-59:60-63:64-67:68-71:72-75:76-79" # (1)! 
mpiexec ... --cpu-bind=${CPU_BIND} python pytorch_ddp.py
  1. 12 processes per node, evenly split across the 2 CPU sockets, with each rank having 4 cores available

Distributed Training with Multiple CCSs

The Intel PVC GPUs contain 4 Compute Command Streamers (CCSs) on each tile, which can be used to group Execution Units (EUs) into common pools. These pools can then be accessed by separate processes thereby enabling distributed training with multiple MPI processes per tile. This feature on PVC is similar to MPS on NVIDIA GPUs and can be beneficial for increasing computational throughput when training or performing inference with smaller models which do not require the entire memory of a PVC tile. For more information, see the section on using multiple CCSs under the Running Jobs on Aurora page.

For DDP, distributed training with multiple CCSs can be enabled programmatically within the user code by explicitly setting the xpu device in PyTorch, for example

import os
from argparse import ArgumentParser
import torch

parser = ArgumentParser(description='CCS Test')
parser.add_argument('--ppd', default=1, type=int, choices=[1,2,4], 
                    help='Number of MPI processes per GPU device') # (1)!
args = parser.parse_args()

local_rank = int(os.environ.get('PALS_LOCAL_RANKID'))
if torch.xpu.is_available():
    xpu_id = local_rank//args.ppd if torch.xpu.device_count()>1 else 0
    assert xpu_id>=0 and xpu_id<torch.xpu.device_count(), \
           f"Assert failed: xpu_id={xpu_id} and {torch.xpu.device_count()} available devices"
    torch.xpu.set_device(xpu_id)
  1. PVC GPUs allow the use of 1, 2, or 4 CCSs on each tile

and then adding the proper environment variables and mpiexec settings in the run script. For example, to run distributed training with 48 MPI processes per node exposing 4 CCSs per tile, set

export ZEX_NUMBER_OF_CCS=0:4,1:4,2:4,3:4,4:4,5:4,6:4,7:4,8:4,9:4,10:4,11:4
BIND_LIST="1:2:4:6:8:10:12:14:16:18:20:22:24:26:28:30:32:34:36:38:40:42:44:46:53:54:56:58:60:62:64:66:68:70:72:74:76:78:80:82:84:86:88:90:92:94:96:98"
mpiexec --pmi=pmix --envall -n 48 --ppn 48 \
    --cpu-bind=verbose,list:${BIND_LIST} \
    python training_script.py --ppd=4

Alternatively, users can use the following modified GPU affinity script in their mpiexec command in order to bind multiple MPI processes to each tile by setting ZE_AFFINITY_MASK

gpu_affinity_ccs.sh
#!/bin/bash

num_ccs=$1 # (1)!
shift
gpu_id=$(( PALS_LOCAL_RANKID / num_ccs ))
export ZE_AFFINITY_MASK=$gpu_id
exec "$@"
  1. Note that the script takes the number of CCSs exposed as a command line argument
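
As a hypothetical usage of this script (the process counts and CCS settings below are illustrative: 2 CCSs per tile and 2 ranks per tile, i.e. 24 ranks per node), the launch line could look like:

```shell
# Illustrative only: expose 2 CCSs per tile, 2 MPI ranks per tile
export ZEX_NUMBER_OF_CCS=0:2,1:2,2:2,3:2,4:2,5:2,6:2,7:2,8:2,9:2,10:2,11:2
mpiexec --pmi=pmix --envall -n 24 --ppn 24 \
    ./gpu_affinity_ccs.sh 2 python training_script.py --ppd=2
```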

Checking PVC usage with xpu-smi

Users can check the correct placement of MPI ranks on the different tiles by connecting to the compute node in use and executing

module load xpu-smi
watch -n 0.1 xpu-smi stats -d <GPU_ID> # (1)!

  1. In this case, GPU_ID refers to one of the 6 GPUs on each node, not an individual tile

and checking the GPU and memory utilization of both tiles.

Alternatively, execute

/soft/tools/igt-gpu-tools/master-2022.05.26/bin/intel_gpu_top -d drm:/dev/dri/card0 # (1)!
  1. card0 refers to GPU 0, card1 for GPU 1, etc.

and press 1 on the keyboard to see the utilization of the CCSs on the selected GPU.

Multiple CCSs and oneCCL

  • When performing distributed training with multiple CCSs exposed, collective communications with the oneCCL backend are delegated to the CPU. This is done in the background by oneCCL, so no change to the user's code is required to move data between host and device; however, it may impact the performance of the collectives at scale.
  • When using PyTorch DDP, the model must be offloaded to the XPU device after wrapping it with DDP() to avoid hangs.