9.1. Ray Train#

Ray Train uses Ray's Actors and Tasks to wrap machine learning and deep learning training workflows, scaling single-machine jobs horizontally. In short, Ray Train wraps a single-machine training job inside an Actor; each Actor holds its own copy of the machine learning model and can complete the training task independently. By exploiting the horizontal scalability of Actors, Ray Train lets training jobs scale out across a Ray cluster.

Ray Train wraps common machine learning libraries such as PyTorch, PyTorch Lightning, HuggingFace Transformers, XGBoost, and LightGBM, and exposes them to users through a unified interface. Users do not need to write any Actor code; with only minor changes to an existing single-machine workflow, they can quickly switch to cluster mode. Taking PyTorch as an example, this section shows how to scale training horizontally with data parallelism. The principles of data parallelism are explained in section 12.2.

Key Steps#

Converting a single-machine PyTorch training script to Ray Train requires the following changes:

  • Define train_loop, a single-node training function that loads the data and updates the model parameters.

  • Define a ScalingConfig, which specifies how to scale the training job horizontally, including how many workers are needed and whether to use GPUs.

  • Define a Trainer, which glues train_loop and the ScalingConfig together, then call the Trainer.fit() method to run the training.

Figure 9.1 shows the key parts of adapting code to Ray Train.

../_images/ray-train-key-parts.svg

Fig. 9.1 Key parts of Ray Train#

Concretely, the code mainly consists of:

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

def train_loop():
    ...

scaling_config = ScalingConfig(num_workers=..., use_gpu=...)
trainer = TorchTrainer(train_loop_per_worker=train_loop, scaling_config=scaling_config)
result = trainer.fit()

Example: Image Classification#

Below is a complete training example. It uses the ResNet model provided by PyTorch [He et al., 2016]; readers can set the ScalingConfig according to the number of GPUs available in their own environment.

import os
import tempfile

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.models import resnet18

import ray
import ray.train.torch
from ray.train import Checkpoint

def train_func(model, optimizer, criterion, train_loader):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for data, target in train_loader:
        # No need to manually move images and labels to the target GPU;
        # `prepare_data_loader` takes care of this
        # data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test_func(model, data_loader):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            # data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total

data_dir = os.path.join(os.getcwd(), "../data")

def train_loop():
    # Load the data and apply transforms
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(), 
         torchvision.transforms.Normalize((0.5,), (0.5,))]
    )

    train_loader = DataLoader(
        torchvision.datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform),
        batch_size=128,
        shuffle=True)
    test_loader = DataLoader(
        torchvision.datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform),
        batch_size=128,
        shuffle=True)

    # 1. Distribute the data across multiple workers
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    test_loader = ray.train.torch.prepare_data_loader(test_loader)
    
    # The original ResNet is designed for 3-channel images
    # FashionMNIST has 1 channel, so modify the first layer of ResNet to fit this input
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    
    # 2. Distribute the model across multiple workers and GPUs
    model = ray.train.torch.prepare_model(model)
    criterion = nn.CrossEntropyLoss()
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Train for 10 epochs
    for epoch in range(10):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        train_func(model, optimizer, criterion, train_loader)
        acc = test_func(model, test_loader)
        
        # 3. Monitor training metrics and save checkpoints
        metrics = {"acc": acc, "epoch": epoch}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)

# 4. Configure `ScalingConfig`; Ray Train scales the training job out to the cluster according to it
scaling_config = ray.train.ScalingConfig(num_workers=4, use_gpu=True)

# 5. Launch parallel training with TorchTrainer
trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker=train_loop,
    scaling_config=scaling_config,
    run_config=ray.train.RunConfig(
        storage_path=os.path.join(data_dir, "torch_ckpt"),
        name="exp_fashionmnist_resnet18",
    )
)
result = trainer.fit()

Tune Status

Current time: 2024-04-10 09:41:32
Running for:  00:01:33.99
Memory:       31.5/90.0 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 1.0/64 CPUs, 4.0/4 GPUs (0.0/1.0 accelerator_type:TITAN)

Trial Status

Trial name                status      loc             iter  total time (s)  acc     epoch
TorchTrainer_3d3d1_00000  TERMINATED  10.0.0.3:49324  10    80.9687         0.8976  9
(RayTrainWorker pid=49399) Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=49400) [W Utils.hpp:133] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarInt)
(TorchTrainer pid=49324) Started distributed worker processes: 
(TorchTrainer pid=49324) - (ip=10.0.0.3, pid=49399) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=49324) - (ip=10.0.0.3, pid=49400) world_rank=1, local_rank=1, node_rank=0
(TorchTrainer pid=49324) - (ip=10.0.0.3, pid=49401) world_rank=2, local_rank=2, node_rank=0
(TorchTrainer pid=49324) - (ip=10.0.0.3, pid=49402) world_rank=3, local_rank=3, node_rank=0
(RayTrainWorker pid=49399) Moving model to device: cuda:0
(RayTrainWorker pid=49399) Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=49401) [rank2]:[W Utils.hpp:106] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarString)
(RayTrainWorker pid=49400) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000000)
(RayTrainWorker pid=49402) [W Utils.hpp:133] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarInt) [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=49402) [rank3]:[W Utils.hpp:106] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarString) [repeated 3x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8604, 'epoch': 0}
(RayTrainWorker pid=49399) {'acc': 0.8808, 'epoch': 1}
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000001) [repeated 4x across cluster]
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000002) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8852, 'epoch': 2}
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000003) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8964, 'epoch': 3}
(RayTrainWorker pid=49399) {'acc': 0.8972, 'epoch': 4}
(RayTrainWorker pid=49401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000004) [repeated 4x across cluster]
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000005) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8968, 'epoch': 5}
(RayTrainWorker pid=49401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000006) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8948, 'epoch': 6}
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000007) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.894, 'epoch': 7}
(RayTrainWorker pid=49401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000008) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.894, 'epoch': 8}
(RayTrainWorker pid=49401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000009) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8976, 'epoch': 9}
2024-04-10 09:41:32,109	WARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.
You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).
2024-04-10 09:41:32,112	INFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name' in 0.0057s.
2024-04-10 09:41:32,120	INFO tune.py:1048 -- Total run time: 94.05 seconds (93.99 seconds for the tuning loop).
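
When fit() returns, the training outcome can be inspected on the returned Result object, for example (a minimal sketch):

# Inspect the Result object returned by trainer.fit()
print(result.metrics)     # metrics from the last ray.train.report() call on the rank-0 worker
print(result.checkpoint)  # the latest persisted Checkpoint
print(result.path)        # directory under storage_path where results and checkpoints are stored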

Differences from Native PyTorch#

Differences from a Single-Machine Program#

Ray Train distributes the model and the data to multiple workers for you. Users only need to set model = ray.train.torch.prepare_model(model) and train_loader = ray.train.torch.prepare_data_loader(train_loader). After that, there is no need to explicitly call model.to("cuda"), nor to write code such as images, labels = images.to("cuda"), labels.to("cuda") to copy the model and data to the GPU.
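
Conceptually, the two prepare calls take care of device placement, DistributedDataParallel wrapping, and data sharding. A rough, hypothetical sketch of what they do (these helper names are not real Ray functions, and the actual implementation handles more cases, such as CPU-only training):

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def prepare_model_sketch(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical sketch: move the model to this worker's GPU and wrap it in DDP
    device = torch.device("cuda", torch.cuda.current_device())
    return DistributedDataParallel(model.to(device), device_ids=[device.index])

def prepare_data_loader_sketch(loader: DataLoader) -> DataLoader:
    # Hypothetical sketch: shard the dataset across workers with a DistributedSampler
    # (the real prepare_data_loader also moves each batch onto the worker's device)
    sampler = DistributedSampler(loader.dataset)
    return DataLoader(loader.dataset, batch_size=loader.batch_size, sampler=sampler)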

Differences from DistributedDataParallel#

PyTorch's DistributedDataParallel can also implement data parallelism. Ray Train hides the complex details of DistributedDataParallel: users only need to make small changes to their single-machine code and do not have to manage the torch.distributed environment (World) and processes (Rank) themselves. For the concepts of World and Rank, see section 11.2.
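
For reference, a minimal sketch of the torch.distributed boilerplate that Ray Train hides; the function name, master address/port, and backend below are illustrative assumptions, not code from the example above:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    # Manual process-group setup: every process must agree on a master address/port
    # and know its own rank; Ray Train performs the equivalent setup automatically
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... build the model, wrap it in DistributedDataParallel, and train ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # Manually spawn one process per GPU (or use the torchrun launcher)
    mp.spawn(worker, args=(world_size,), nprocs=world_size)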

Data Loading#

If the single-machine version reads data with a PyTorch DataLoader, ray.train.torch.prepare_data_loader() can be used to adapt the existing PyTorch DataLoader. Alternatively, the data preprocessing methods provided by Ray Data can be used, as sketched below.
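
A minimal sketch of the Ray Data route; the dataset path and the column names x and y are hypothetical placeholders:

import ray
import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Hypothetical dataset path and columns, for illustration only
train_ds = ray.data.read_csv("s3://my-bucket/train.csv")

def train_loop():
    # Each worker receives an automatically sharded slice of the dataset
    shard = ray.train.get_dataset_shard("train")
    for epoch in range(2):
        # iter_torch_batches yields dicts mapping column names to torch.Tensor
        for batch in shard.iter_torch_batches(batch_size=128):
            features, labels = batch["x"], batch["y"]
            ...

trainer = TorchTrainer(
    train_loop_per_worker=train_loop,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_ds},  # handed to the workers via Ray Data
)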

ScalingConfig#

In ScalingConfig(num_workers=..., use_gpu=...), the num_workers argument controls the degree of parallelism and the use_gpu argument controls whether GPU resources are used. num_workers can be understood as the number of Ray Actors to launch, each of which runs the training task independently. If use_gpu=True, each Actor is assigned 1 GPU by default, and accordingly each Actor's CUDA_VISIBLE_DEVICES environment variable contains a single GPU. To give each Actor access to multiple GPUs, set the resources_per_worker argument: resources_per_worker={"GPU": n}.
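
A short sketch of the two configurations described above; the worker and GPU counts are arbitrary examples and assume the cluster actually has that many GPUs:

from ray.train import ScalingConfig

# 4 workers, each assigned 1 GPU by default
single_gpu_per_worker = ScalingConfig(num_workers=4, use_gpu=True)

# 2 workers, each assigned 2 GPUs via resources_per_worker
multi_gpu_per_worker = ScalingConfig(
    num_workers=2,
    use_gpu=True,
    resources_per_worker={"GPU": 2},
)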

Monitoring#

In distributed training, every Worker runs independently, but in most cases it is enough to monitor only the first process, the one whose rank is 0. By default, ray.train.report(metrics=...) collects the metrics from the rank-0 worker.
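
A minimal sketch of rank-aware reporting inside a training loop; the helper function below is hypothetical and mainly illustrates the context accessors:

import ray.train

def report_and_log(acc: float, epoch: int) -> None:
    # Hypothetical helper, meant to be called inside train_loop on every worker
    ctx = ray.train.get_context()
    metrics = {"acc": acc, "epoch": epoch}
    # report() should be called on every worker; Ray Train keeps the metrics
    # reported by the rank-0 worker as the trial's results by default
    ray.train.report(metrics)
    if ctx.get_world_rank() == 0:
        print(f"world_size={ctx.get_world_size()}, local_rank={ctx.get_local_rank()}", metrics)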

Checkpoint#

The checkpointing process is roughly as follows:

  1. The checkpoint is first written to a local directory; the model-saving interfaces provided by PyTorch, PyTorch Lightning, or TensorFlow can be used directly. For example, in the code above:

with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
    torch.save(
        model.module.state_dict(),
        os.path.join(temp_checkpoint_dir, "model.pt")
    )

  2. When ray.train.report(metrics=..., checkpoint=...) is called, the checkpoint just saved locally is uploaded to a persistent file system (for example, S3 or HDFS) that is accessible to all workers. The local checkpoint is only a cache; once it has been uploaded to the persistent file system, the local copy is deleted. The persistent file system directory is configured on the TorchTrainer:

TorchTrainer(
    train_loop,
    scaling_config=scaling_config,
    run_config=ray.train.RunConfig(
        storage_path=...,
        name="experiment_name",
    )
)

With data-parallel training, every rank holds a copy of the model weights, so the checkpoint saved locally and persisted to the file system is the same on every rank. With other parallel strategies such as pipeline-parallel training (section 12.3), each rank holds only part of the model, and each rank saves its own portion of the model weights. When generating checkpoint files, add file prefixes or suffixes to tell them apart.

with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
    rank = ray.train.get_context().get_world_rank()
    torch.save(
        ...,
        os.path.join(temp_checkpoint_dir, f"model-rank={rank}.pt"),
    )
    ray.train.report(
        metrics,
        checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
    )
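
After training, a persisted checkpoint can be read back to restore the model. A minimal sketch, assuming the data-parallel case above where the weights were saved with model.module.state_dict() to a file named model.pt:

import os
import torch
from torchvision.models import resnet18

checkpoint = result.checkpoint  # the latest Checkpoint returned by trainer.fit()
with checkpoint.as_directory() as checkpoint_dir:
    # Rebuild the same model architecture as in the training example
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"), map_location="cpu")
    model.load_state_dict(state_dict)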