Training large models without fear: the lightweight TorchShard library cuts GPU memory consumption and keeps the same API as PyTorch

Source: 机器之心
How do you elegantly reduce GPU memory consumption when training large models? Consider the TorchShard library: it combines model parallelism with data parallelism while keeping the same API design as PyTorch.

TorchShard has two goals:

Build a standard PyTorch extension library for scaling up training with model parallelism;
Use PyTorch in a simple and natural way.
The snippet below gives a first taste of the API:

import torch
import torchshard as ts

ts.init_process_group(group_size=2)                     # init parallel groups

m = torch.nn.Sequential(
    torch.nn.Linear(20, 30, bias=True),
    ts.nn.ParallelLinear(30, 30, bias=True, dim=None),  # equal to nn.Linear()
    ts.nn.ParallelLinear(30, 30, bias=True, dim=0),     # parallel in row dimension
    ts.nn.ParallelLinear(30, 30, bias=True, dim=1),     # parallel in column dimension
).cuda()

x = m(x)                                                # forward pass on the input batch x
loss = ts.nn.functional.parallel_cross_entropy(x, y)    # parallel loss function on labels y
loss.backward()                                         # backward

torch.save(
    ts.collect_state_dict(m, m.state_dict()), 'm.pt')   # save model state
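As a rough illustration of what kind of splitting is behind the row and column comments above (plain PyTorch rather than TorchShard code, with the world size, shapes, and allclose checks made up purely for the example): splitting a linear layer's weight along its output features lets each rank produce a slice of the output that is concatenated back together, while splitting along its input features lets each rank produce a partial product that is summed.

import torch

torch.manual_seed(0)
in_features, out_features, world_size = 4, 6, 2

weight = torch.randn(out_features, in_features)   # nn.Linear stores weight as (out, in)
x = torch.randn(3, in_features)
full_out = x @ weight.t()

# Split along output features: each "rank" holds a slice of the output units,
# and concatenating the per-rank outputs recovers the full result.
out_shards = weight.chunk(world_size, dim=0)
col_out = torch.cat([x @ w.t() for w in out_shards], dim=1)

# Split along input features: each "rank" holds a slice of the input units,
# computes a partial product, and the partial products are summed
# (an all-reduce in a real distributed setting).
in_shards = weight.chunk(world_size, dim=1)
x_shards = x.chunk(world_size, dim=1)
row_out = sum(xi @ wi.t() for xi, wi in zip(x_shards, in_shards))

assert torch.allclose(col_out, full_out, atol=1e-5)
assert torch.allclose(row_out, full_out, atol=1e-5)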
TorchShard keeps the same package hierarchy as PyTorch:

torchshard contains the essential functions and operations, like the torch package;
torchshard.nn contains the basic building blocks for graphs, like the torch.nn package;
torchshard.nn.functional contains the functional counterparts of torchshard.nn, like the torch.nn.functional package;
torchshard.distributed contains the basic functions for handling distributed tensors and groups, and is easier to use than the torch.distributed package (see the sketch below).
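A minimal sketch of how the torchshard.distributed helpers that appear later in this article fit together, assuming the script is launched with two processes through a torch.distributed launcher and that gather concatenates each rank's tensor along the given dimension (the tensor shapes are made up for illustration):

import torch
import torchshard as ts

# set up the model-parallel group, mirroring torch.distributed.init_process_group
ts.distributed.init_process_group(group_size=2)

# each rank holds its own slice of the batch ...
local_x = torch.randn(16, 128).cuda()

# ... and gather(..., dim=0) stacks the slices along the batch dimension,
# the same call the ResNet example below applies before its parallel fc layer
full_x = ts.distributed.gather(local_x, dim=0)

# rank queries mirror torch.distributed.get_rank()
if ts.distributed.get_rank() == 0:
    print(full_x.shape)   # expect (2 * 16, 128) with two ranks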
TorchShard can be installed with pip:

pip install torchshard
Parallelizing a standard ImageNet ResNet training script takes only a few extra lines:

# initialize the model-parallel process group
import torchshard as ts
ts.distributed.init_process_group(group_size=args.world_size)

# build the model and convert its linear layer into a ParallelLinear layer
import resnet
model = resnet.__dict__[args.arch](pretrained=args.pretrained)
ts.nn.ParallelLinear.convert_parallel_linear(
    model, dim=args.model_parallel_dim)
print("=> paralleling model '{}'".format(args.arch))

# use the parallel version of the cross-entropy loss
criterion = ts.nn.ParallelCrossEntropyLoss().cuda(args.gpu)

# inside the ResNet forward(), before the parallelized fully connected layer
x = ts.distributed.gather(x, dim=0)   # gather input along the dim of batch size
x = self.fc(x)

# in the training loop, gather the targets the same way before the loss
output = model(images)
if args.enable_model_parallel:
    target = ts.distributed.gather(target, dim=0)
loss = criterion(output, target)
Saving and loading checkpoints works as before, with one extra call to collect or relocate the sharded states:

# saving: collect states across all ranks, then save on rank 0
state_dict = model.state_dict()
state_dict = ts.collect_state_dict(model, state_dict)
if ts.distributed.get_rank() == 0:
    torch.save(state_dict, 'resnet50.pt')               # save as before

# loading: relocate state_dict() for all ranks
if ts.distributed.get_rank() == 0:
    state_dict = torch.load('resnet50.pt')
state_dict = ts.relocate_state_dict(model, state_dict)
model.load_state_dict(state_dict)                       # load as before

TorchShard can also be mixed with PyTorch's automatic mixed precision (AMP):

# gradient scaler
scaler = torch.cuda.amp.GradScaler(enabled=args.enable_amp_mode)

with torch.cuda.amp.autocast(enabled=args.enable_amp_mode):
    # compute output
    output = model(images)
    if args.enable_model_parallel:
        target = ts.distributed.gather(target, dim=0)
    loss = criterion(output, target)

# compute gradient and do SGD step
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

It combines with the ZeroRedundancyOptimizer from torch.distributed.optim in the same way:

from torch.distributed.optim import ZeroRedundancyOptimizer

if args.enable_zero_optim:
    print('=> using ZeroRedundancyOptimizer')
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.SGD,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
else:
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

