{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 4.1 模型构造" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.2.0\n" ] } ], "source": [ "import torch\n", "from torch import nn\n", "\n", "print(torch.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.1.1 继承`Module`类来构造模型" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class MLP(nn.Module):\n", " # 声明带有模型参数的层,这里声明了两个全连接层\n", " def __init__(self, **kwargs):\n", " # 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数\n", " # 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数params\n", " super(MLP, self).__init__(**kwargs)\n", " self.hidden = nn.Linear(784, 256) # 隐藏层\n", " self.act = nn.ReLU()\n", " self.output = nn.Linear(256, 10) # 输出层\n", " \n", "\n", " # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出\n", " def forward(self, x):\n", " a = self.act(self.hidden(x))\n", " return self.output(a)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLP(\n", " (hidden): Linear(in_features=784, out_features=256, bias=True)\n", " (act): ReLU()\n", " (output): Linear(in_features=256, out_features=10, bias=True)\n", ")\n" ] }, { "data": { "text/plain": [ "tensor([[ 0.0234, -0.2646, -0.1168, -0.2127, 0.0884, -0.0456, 0.0811, 0.0297,\n", " 0.2032, 0.1364],\n", " [ 0.1479, -0.1545, -0.0265, -0.2119, -0.0543, -0.0086, 0.0902, -0.1017,\n", " 0.1504, 0.1144]], grad_fn=)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.rand(2, 784)\n", "net = MLP()\n", "print(net)\n", "net(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.1.2 `Module`的子类\n", "### 4.1.2.1 `Sequential`类" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class MySequential(nn.Module):\n", " from 
collections import OrderedDict\n", " def __init__(self, *args):\n", " super(MySequential, self).__init__()\n", " if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict\n", " for key, module in args[0].items():\n", " self.add_module(key, module) # add_module方法会将module添加进self._modules(一个OrderedDict)\n", " else: # 传入的是一些Module\n", " for idx, module in enumerate(args):\n", " self.add_module(str(idx), module)\n", " def forward(self, input):\n", " # self._modules返回一个 OrderedDict,保证会按照成员添加时的顺序遍历成\n", " for module in self._modules.values():\n", " input = module(input)\n", " return input" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MySequential(\n", " (0): Linear(in_features=784, out_features=256, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=256, out_features=10, bias=True)\n", ")\n" ] }, { "data": { "text/plain": [ "tensor([[ 0.1273, 0.1642, -0.1060, 0.1401, 0.0609, -0.0199, -0.0140, -0.0588,\n", " 0.1765, -0.1296],\n", " [ 0.0267, 0.1670, -0.0626, 0.0744, 0.0574, 0.0413, 0.1313, -0.1479,\n", " 0.0932, -0.0615]], grad_fn=)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net = MySequential(\n", " nn.Linear(784, 256),\n", " nn.ReLU(),\n", " nn.Linear(256, 10), \n", " )\n", "print(net)\n", "net(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.1.2.2 `ModuleList`类" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear(in_features=256, out_features=10, bias=True)\n", "ModuleList(\n", " (0): Linear(in_features=784, out_features=256, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=256, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])\n", "net.append(nn.Linear(256, 10)) # # 类似List的append操作\n", "print(net[-1]) # 类似List的索引访问\n", 
"print(net)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# net(torch.zeros(1, 784)) # 会报NotImplementedError" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class MyModule(nn.Module):\n", " def __init__(self):\n", " super(MyModule, self).__init__()\n", " self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])\n", "\n", " def forward(self, x):\n", " # ModuleList can act as an iterable, or be indexed using ints\n", " for i, l in enumerate(self.linears):\n", " x = self.linears[i // 2](x) + l(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "net1:\n", "torch.Size([10, 10])\n", "torch.Size([10])\n", "net2:\n" ] } ], "source": [ "class Module_ModuleList(nn.Module):\n", " def __init__(self):\n", " super(Module_ModuleList, self).__init__()\n", " self.linears = nn.ModuleList([nn.Linear(10, 10)])\n", " \n", "class Module_List(nn.Module):\n", " def __init__(self):\n", " super(Module_List, self).__init__()\n", " self.linears = [nn.Linear(10, 10)]\n", "\n", "net1 = Module_ModuleList()\n", "net2 = Module_List()\n", "\n", "print(\"net1:\")\n", "for p in net1.parameters():\n", " print(p.size())\n", "\n", "print(\"net2:\")\n", "for p in net2.parameters():\n", " print(p)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.1.2.3 `ModuleDict`类" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear(in_features=784, out_features=256, bias=True)\n", "Linear(in_features=256, out_features=10, bias=True)\n", "ModuleDict(\n", " (act): ReLU()\n", " (linear): Linear(in_features=784, out_features=256, bias=True)\n", " (output): Linear(in_features=256, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "net = nn.ModuleDict({\n", " 'linear': nn.Linear(784, 
256),\n",
"    'act': nn.ReLU(),\n",
"})\n",
"net['output'] = nn.Linear(256, 10)  # add an entry\n",
"print(net['linear'])  # access by key\n",
"print(net.output)\n",
"print(net)" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
"# net(torch.zeros(1, 784))  # would raise NotImplementedError" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 4.1.3 构造复杂的模型" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [
"class FancyMLP(nn.Module):\n",
"    \"\"\"Model with a constant weight, a reused linear layer and Python control flow.\n",
"\n",
"    forward() returns a scalar tensor (the sum of the final activations).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, **kwargs):\n",
"        super(FancyMLP, self).__init__(**kwargs)\n",
"\n",
"        # Non-trainable (constant) weight. It is registered as a buffer rather\n",
"        # than stored as a plain tensor attribute: a buffer stays out of\n",
"        # parameters(), but it is moved by .to()/.cuda() and saved in\n",
"        # state_dict() together with the rest of the module.\n",
"        self.register_buffer('rand_weight', torch.rand((20, 20)))\n",
"        self.linear = nn.Linear(20, 20)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.linear(x)\n",
"        # Use the constant weight together with relu from nn.functional and\n",
"        # torch.mm. The buffer does not require grad, so no .data is needed.\n",
"        x = nn.functional.relu(torch.mm(x, self.rand_weight) + 1)\n",
"\n",
"        # Reuse the fully connected layer; this is equivalent to two fully\n",
"        # connected layers that share their parameters.\n",
"        x = self.linear(x)\n",
"        # Control flow: item() returns a Python scalar that can be compared.\n",
"        while x.norm().item() > 1:\n",
"            x /= 2\n",
"        if x.norm().item() < 0.8:\n",
"            x *= 10\n",
"        return x.sum()" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FancyMLP(\n", " (linear): Linear(in_features=20, out_features=20, bias=True)\n", ")\n" ] }, { "data": { "text/plain": [ "tensor(0.8907, grad_fn=)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [
"X = torch.rand(2, 20)\n",
"net = FancyMLP()\n",
"print(net)\n",
"net(X)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): NestMLP(\n", " (net): Sequential(\n", " (0): Linear(in_features=40, out_features=30, bias=True)\n", " (1): ReLU()\n", " )\n", " )\n", " (1): Linear(in_features=30, out_features=20, bias=True)\n", " (2): FancyMLP(\n", " (linear): Linear(in_features=20, out_features=20, 
bias=True)\n", " )\n", ")\n" ] }, { "data": { "text/plain": [ "tensor(-0.4605, grad_fn=)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [
"class NestMLP(nn.Module):\n",
"    \"\"\"Wraps an nn.Sequential made of Linear(40, 30) followed by ReLU.\"\"\"\n",
"\n",
"    def __init__(self, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        layers = [nn.Linear(40, 30), nn.ReLU()]\n",
"        self.net = nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x):\n",
"        return self.net(x)\n",
"\n",
"\n",
"# Chain the nested block, a linear layer and FancyMLP.\n",
"net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())\n",
"\n",
"X = torch.rand(2, 40)\n",
"print(net)\n",
"net(X)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.2" } }, "nbformat": 4, "nbformat_minor": 2 }