前言

在大三的时候已经使用PyTorch写过简单的DNN、CNN、预训练模型等,但当时只是被学分课(机器学习、计算机视觉)逼着写的,所以写完作业就基本不碰PyTorch了,也没有认真研究很多细节。现重新学习PyTorch,记录其很多重要但容易被忽略的细节,争取早日开始复现代码~

0 推荐学习资源

1 经验细节汇总

1.1 数据操作

1.1.1 内存中的tensor

由于python对于变量在内存中的特殊储存方式,基于python的PyTorch也会因此受到影响,具体有以下几种形式:

  1. 像Numpy一样,对一个tensor使用索引操作(如new_tensor=tensor[1:]),索引出的结果与这个tensor共享内存(即修改一个,另一个也会跟着修改)
  2. .view()改变tensor的形状,返回的新tensor与源tensor共享内存(顾名思义,.view()仅仅改变对该张量的观察角度,内部数据并未改变)。所以如果想返回一个真正副本,推荐使用.clone.view().reshape()
  3. 使用.numpy().from_numpy()将tensor与Numpy中的array相互转换时,产生的tensor和array共享内存。如果这个tensor需要一个新的内存,那么可以使用torch.tensor(),这将消耗更多的时间和空间。

1.1.2 tensor的contiguous

顾名思义,连续的。PyTorch中张量的底层实现是使用C中的一维数组(一段连续的内存空间),所以这里的连续是指在内存中是连续的。

使用.view()等方法时,必须先保证这个tensor是连续的(使用.is_contiguous()方法可以判断)。如果tensor在内存中不连续,则需要使用.contiguous()方法,他会重新开辟一块内存空间以保证连续。

.is_contiguous()的直观解释是tensor底层一维数组元素的存储顺序与tensor按行优先一维展开的元素顺序是否一致

PyTorch又提供了.reshape()方法,其实就等价于.contiguous().view()

下面是一个简单的示例:

a = torch.randn(16)
print(a.is_contiguous()) # True
b = a.view(-1, 4)
print(b.is_contiguous()) # True
c = b.transpose(0, 1)
print(c.is_contiguous()) # False

d = c.view(-1)
# Outputs:
RuntimeError: view size is not compatible with input tensor s size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

d = c.contiguous().view(-1)
print(d.shape)
# Outputs:
torch.Size([16])

了解更多,推荐阅读参考资料[1]

1.1.3 inplace操作

PyTorch操作inplace版本都有后缀_,代表就地修改,例如:

y.add_(x)
y.copy_(x)
x.grad.data.zero_()
x.requires_grad_()

1.1.4 tensor在不同设备上移动

使用方法.to(device)可以将tensor在cpu和gpu之间相互移动。

1.2 网络结构

1.2.1 定义网络的几种方法

  1. 继承nn.Module类,定义一些以及.forward()方法,返回值为输出
  2. 使用nn.Sequential(),按顺序地定义每一层

1.2.2 torch.nn的特性

  1. 可使用net.parameters()来查看模型所有的可学习参数,返回一个生成器
  2. torch.nn仅支持一个batch样本的输入(不支持单样本),如果只有单个样本,需要手动添加维度

1.2.3 train()与eval()

  1. 在验证和测试时,需使用model.eval()方法,它可以自动关闭训练时使用的DropoutBatch Norm

2 常用模板代码

2.1 模型的训练及验证

模型的训练

def train_epoch(epoch, model, optimizer, criterion, train_iter):
    model.train()
    for i, batch in enumerate(train_iter, start=1):
        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out, gold)
        loss.backward()
        optimizer.step()
        if i % 5 == 0:
            print('Epoch: {}, batch: [{}/{}], Loss: {:.5}'.format(epoch, i, len(train_iter), loss.item()))

模型的验证

def valid_epoch(epoch, model, optimizer, criterion, valid_iter):
    model.eval()
    with torch.no_grad():
        loss_list = []
        for _, batch in enumerate(valid_iter, start=1):
            out = model(inputs)
            loss = criterion(out, gold)
            loss_list.append(loss)
    return sum(loss_list) / len(valid_iter)

2.2 模型的保存和加载

模型的保存

torch.save(model, PATH) # 方法1(不推荐)
torch.save(model.state_dict(), PATH) # 方法2
torch.save({'epoch':epoch, 'model':model.state_dict(), ...}, PATH) # 方法3

模型的加载

# 对应方法1(不推荐)
model = torch.load(PATH)

# 对应方法2
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

# 对应方法3
model = TheModelClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model'])

注意: 在不同设备上保存或加载,需要添加torch.load(PATH, map_location=device)参数,且还需要使用model.to(device)。其中device是希望加载到的设备

参考资料

  1. PyTorch中的contiguous - 栩风
  2. [TorchText]使用 - VanJordan
  3. torchtext入门教程,轻松玩转文本数据处理 - Lee
  4. PyTorch 保存和加载模型 - 鑫鑫淼淼焱焱 (原文 - Matthew Inkawhich)