如下所示:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#第一行代码 model.to(device)#第二行代码 |
mytensor = my_tensor.to(device)#第三行代码 |
if torch.cuda.device_count() > 1: model = nn.DataParallel(model) |