在深度学习中,灵活地构建模型、高效地管理参数、以及充分利用硬件资源是训练高性能模型的关键。
PyTorch
以其动态图机制和模块化设计,为开发者提供了极大的自由度。
本文通过四个核心主题带你深入理解
PyTorch
的模型构建机制,并辅以完整可运行的代码示例。
一、模型构建
PyTorch
中最基本的构建单元是nn.Module。
我们可以使用nn.Sequential快速搭建线性堆叠的网络:
importtorchfromtorchimportnnfromtorch.nnimportfunctionalasFnet=nn.Sequential(nn.Linear(20,256),nn.ReLU(),nn.Linear(256,10))X=torch.rand(2,20)print(net(X))#
输出形状:
10]
自定义模块
真实场景往往需要更复杂的逻辑。
为此,我们可以通过继承nn.Module自定义前向传播:
classMLP(nn.Module):def__init__(self):super().__init__()self.hidden=nn.Linear(20,256)self.out=nn.Linear(256,10)#`forward`
代码,包括控制流、张量操作甚至调用非可微分函数(只要不参与梯度计算)
defforward(self,X):returnself.out(F.relu(self.hidden(X)))net=MLP()print(net(X))进一步地,我们可以创建自己的顺序容器类,了解
PyTorch
如何管理子模块:
classMySequential(nn.Module):def__init__(self,*args):super().__init__()foridx,moduleinenumerate(args):self._modules[str(idx)]=module#注册到
_modules
defforward(self,X):forblockinself._modules.values():X=block(X)returnX嵌套模块
现实中


