训练大模型也不怕,轻量级TorchShard库减少GPU内存消耗,API与PyTorch相同

新闻 人工智能
训练大模型时,如何优雅地减少 GPU 内存消耗?你不妨试试这个 TorchShard 库,兼具模型并行与数据并行等特点,还具有与 PyTorch 相同的 API 设计。

[[413467]]

模型并行性能够促进视觉任务的性能。但是目前,还没有一个标准库可以让我们像采用混合精度等其他 SOTA 技术那样轻松地采用模型并行性。

最近,马里兰大学帕克分校计算机科学系的研究者 Kaiyu Yue 开源了一个工具TorchShard,这是一个轻量级的引擎,用于将 PyTorch 张量切片成并行的 shard。当模型拥有大量的线性层(例如 BERT、GPT)或者很多类(数百万)时,TorchShard 可以减少 GPU 内存并扩展训练规模,它具有与 PyTorch 相同的 API 设计。

项目地址:https://github.com/KaiyuYue/torchshard

BERT 和 GPT 等超大模型正在成为 NLP 领域应用中的趋势。然而训练这种大模型面临内存限制的问题,为了解决这个难题,研究者使用 Megatron-LM 和 PyTorch-Lightning 模型并行性扩大训练。其中,Megatron-LM 只专注于大规模训练语言模型,而 PyTorch-Lightning 仅基于 sharded 优化器状态和梯度,如 DeepSpeed。

在计算机视觉任务中,我们会在训练基于 Transformer、MLP 模型或在数百万个类中训练模型时遇到同样的问题。TorchShard 的目标是:

  • 建立一个标准的 PyTorch 扩展库,用于使用模型并行性进行扩展训练;
  • 以一种简单、自然的方式使用 PyTorch。

TorchShard 是对模型并行单元(mpu)的彻底重写,是 Megatron-LM 核心。最重要的是,TorchShard 具有与 PyTorch 相同的 API 设计,这意味着所有的子类和子函数都保持与 PyTorch 相同。例如,如果你想让原来的线性层 torch.nn. linear 是并行的,只需将 torch 变成 ts,并调用带有 dim 参数的子类 nn.ParallelLinear,如下所示:

  1. import torchshard as ts 
  2.  
  3. ts.init_process_group(group_size=2) # init parallel groups 
  4.  
  5. m = torch.nn.Sequential( 
  6.  
  7. torch.nn.Linear(2030, bias=True), 
  8.  
  9. ts.nn.ParallelLinear(3030, bias=True, dim=None), # equal to nn.Linear() 
  10.  
  11. ts.nn.ParallelLinear(3030, bias=True, dim=0), # parallel in row dimension 
  12.  
  13. ts.nn.ParallelLinear(3030, bias=True, dim=1), # parallel in column dimension 
  14.  
  15. ).cuda() 
  16.  
  17. x = m(x) # forward 
  18.  
  19. loss = ts.nn.functional.parallel_cross_entropy(x, y) # parallel loss function 
  20.  
  21. loss.backward() # backward 
  22.  
  23. torch.save( 
  24.  
  25. ts.collect_state_dict(m, m.state_dict()), 'm.pt') # save model state 

除此之外,TorchShard 还支持与 DDP 一起使用时的各种特性,保存和加载 shard checkpoints,初始化 shard 参数,以及跨多台机器和 GPU 处理张量。具体如下:

  • torchshard 包含必要的功能和操作,如 torch 包;
  • torchshard.nn 包含图形的基本构建块,如 torch.nn 包;
  • torchshard.nn.functional 包含 torchshard.nn 的相应功能操作,如 torch.nn.functional 包;
  • torchshard.distributed 包含处理分布式张量和组的基本功能,如 torch.distributed 包更容易使用。

如何开始 TorchShard?

安装要求:Python 版本 3.6 以上(含)以及 PyTorch 版本 1.9.0 以上(含)。通过 pip 安装 TorchShard 库:

  1. pip install torchshard 

这里以 ImageNet 上训练 ResNet-50 为例,展示仅需几行代码就能在项目中使用 TorchShard。通常 ResNet-50 设计范式包含两部分:卷积块和全连接层,如下图 1 所示。一般来说,由于大量的类依赖于数据集,最后的线性层比卷积块有更多的参数。所以我们切片最后一个线性层来检查其最大尺寸。

图 1:DDP 以及 DDP + TorchShard 前向训练流。

在上图 1 中,左边展示了传统的 DDP 训练范式。假设我们有两个等级,DDP 将强制每个等级有重复的模型参数。然而,TorchShard 会将层级参数切片到不同的等级,从而减少整个 GPU 内存。现在向 ImageNet 官方训练脚本添加一些代码,修改后的版本已经成为 TorchShard 项目的一部分。

首先将 torchshard import 进来:

  1. import torchshard as ts 

然后需要初始化模型并行的进程组,就像初始化 DDP 进程组的方法一样。只需要设置一个功能参数来告诉 torchshard 应该从目标层中切片出多少个 shard。

  1. ts.distributed.init_process_group(group_size=args.world_size) 

接下来将模型转换为并行版本,其中可以直接将整个模型输入到转换辅助函数中,无需特殊处理。

  1. import resnet 
  2.  
  3. model = resnet.__dict__[args.arch](pretrained=args.pretrained) 
  4.  
  5. ts.nn.ParallelLinear.convert_parallel_linear( 
  6.  
  7. model, dim=args.model_parallel_dim 
  8.  
  9.  
  10. print("=> paralleling model'{}'".format(args.arch)) 

此外,不要忘记损失函数 torchshard.nn.ParallelCrossEntropy ,该损失函数可以根据输入张量在原始 PyTorch 版本和并行版本之间切换运行模式。例如,如果输入张量是由 torchshard 并行层产生的,torchshard.nn.ParallelCrossEntropy 将以并行方式计算损失值。

  1. criterion = ts.nn.ParallelCrossEntropyLoss().cuda(args.gpu) 

当模型并行模式(TorchShard)和数据并行模式(DDP)一起工作时,我们需要处理并行层的输入。每个等级中的参数和训练数据都不同。因此,我们在 ResNet forward 中的并行线性层之前收集输入张量。

  1. x = ts.distributed.gather(x, dim=0) # gather input along the dim of batch size 
  2.  
  3. x = self.fc(x) 

同样地,我们在计算损失值之前收集目标张量。

  1. output = model(images) 
  2.  
  3. if args.enable_model_parallel: 
  4.  
  5. target = ts.distributed.gather(target, dim=0
  6.  
  7. loss = criterion(output, target) 

最后,使用 TorchShard 函数保存和加载 checkpoints 非常简单。TorchShard 提供了名为 torchshard.collect_state_dict 基本函数用于保存 checkpoints,torchshard.relocate_state_dict 用于加载 checkpoints。

保存检查点:

  1. state_dict = model.state_dict() 
  2.  
  3. # collect states across all ranks 
  4.  
  5. state_dict = ts.collect_state_dict(model, state_dict) 
  6.  
  7. if ts.distributed.get_rank() == 0
  8.  
  9. torch.save(state_dict, 'resnet50.pt') # save as before 

加载检查点:

  1. if ts.distributed.get_rank() == 0
  2.  
  3. state_dict = torch.load('resnet50.pt'
  4.  
  5. # relocate state_dict() for all ranks 
  6.  
  7. state_dict = ts.relocate_state_dict(model, state_dict) 
  8.  
  9. model.load_state_dict(state_dict) # load as before 

现在我们已经完成了在 ImageNet 上为 shard 训练添加代码, 然后可以通过增加类的数量来扩展它,即最后一个线性层的输出特征维度。训练脚本可以在 torchshard/project/imagenet 中找到。下图展示了在 8 个 NVIDIA TITAN-XP (12196 MiB) GPU 、类数 ≤ 1000000 上和 16 个 GPU 、类数为 2000000 上训练 ResNet-50 扩展能力。

图 2:在不同并行策略下使用标准 ResNet 训练设置(即输入大小 224 和批量大小 256)的 GPU 内存成本。

使用 AMP 与 ZeRO

TorchShard 以简单自然的 PyTorch 方式与其他技术(例如自动混合精度 AMP 以及 ZeRO)一起混合使用。

  1. # gradscaler 
  2.  
  3. scaler = torch.cuda.amp.GradScaler(enabled=args.enable_amp_mode) 
  4.  
  5.  
  6.  
  7. with torch.cuda.amp.autocast(enabled=args.enable_amp_mode): # compute output 
  8.  
  9. output = model(images) 
  10.  
  11.  
  12.  
  13. if args.enable_model_parallel: 
  14.  
  15. target = ts.distributed.gather(target, dim=0
  16.  
  17. loss = criterion(output, target) 
  18.  
  19.  
  20.  
  21. # compute gradient and do SGD step 
  22.  
  23. scaler.scale(loss).backward() 
  24.  
  25. scaler.step(optimizer) 
  26.  
  27. scaler.update() 
  28.  
  29. optimizer.zero_grad() 

图 3:在不同并行策略以及 AMP 下,使用标准的 ResNet 训练设置时(输入尺寸 224,batch 大小 256),使用 GPU 内存的成本。

ZeRO 是 DeepSpeed 的核心,与 PyTorch >= 1.9.0 一起使用。如果你想测试一个函数,请安装最新版本的脚本来运行,代码如下:

  1. from torch.distributed.optim import ZeroRedundancyOptimizer 
  2.  
  3.  
  4.  
  5. if args.enable_zero_optim: 
  6.  
  7. print('=> using ZeroRedundancyOptimizer'
  8.  
  9. optimizer = torch.distributed.optim.ZeroRedundancyOptimizer( 
  10.  
  11. model.parameters(), 
  12.  
  13. optimizer_class=torch.optim.SGD, 
  14.  
  15. lr=args.lr, 
  16.  
  17. momentum=args.momentum, 
  18.  
  19. weight_decay=args.weight_decay) 
  20.  
  21. else
  22.  
  23. optimizer = torch.optim.SGD(model.parameters(), args.lr, 
  24.  
  25. momentum=args.momentum, 
  26.  
  27. weight_decay=args.weight_decay) 

图 4:在不同的并行策略和 ZeRO 优化器下,在标准 ResNet 训练设置(输入大小 224 和批大小 256)的 GPU 内存成本。

此外,TorchShard 还提供了基本的 Python API 以及和相应的模板文件,以简化自定义并行层的实现。

研究者将持续开发 TorchShard,如 TorchShard 下一个特性是新的数据采样器 torchshard.utils.data.DistributedGroupSampler,它的命名遵循 torch.utils.data.DistributedSampler。该采样器旨在帮助用户构建 M-way 数据并行、N-way 模型并行,使得其就像 DDP 中的 DistributedSampler 一样简单。用户唯一要做的就是设置模型并行组号,然后 DistributedGroupSampler 来确保同一模型并行组中的模块具有相同的训练数据。

 

责任编辑:张燕妮 来源: 机器之心Pro
相关推荐

2023-07-12 10:04:20

模型训练

2011-07-06 09:11:40

MozillaFirefox

2024-01-08 13:38:00

AI模型

2022-08-10 12:21:07

PythonWebBottle

2013-05-15 10:20:16

Paas虚拟化

2022-05-19 14:43:58

PyTorch训练

2023-11-24 11:11:08

Python数据库

2010-11-10 10:57:43

T-SQL代码

2012-07-18 10:09:55

轻量级移动客户端开发类库

2021-03-25 15:19:33

深度学习Pytorch技巧

2020-09-11 10:48:49

微软机器学习开源AI

2023-11-16 16:37:02

2009-07-17 14:38:51

轻量级Swing组件

2009-07-14 18:05:28

轻量级Swing组件

2009-06-05 11:07:30

2010-01-11 10:48:15

2011-12-22 11:02:04

轻量级Linux

2023-06-27 12:56:23

微软AI

2023-08-01 14:28:00

OpenAI模型token

2022-02-21 10:14:15

数据中心电力
点赞
收藏

51CTO技术栈公众号