{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4.1 模型构造"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.2.0\n"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.1.1 继承`Module`类来构造模型"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" # 声明带有模型参数的层,这里声明了两个全连接层\n",
" def __init__(self, **kwargs):\n",
" # 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数\n",
" # 参数如“模型参数的访问、初始化和共享”一节将介绍的模型参数params\n",
" super(MLP, self).__init__(**kwargs)\n",
" self.hidden = nn.Linear(784, 256) # 隐藏层\n",
" self.act = nn.ReLU()\n",
" self.output = nn.Linear(256, 10) # 输出层\n",
" \n",
"\n",
" # 定义模型的前向计算即如何根据输入x计算返回所需要的模型输出\n",
" def forward(self, x):\n",
" a = self.act(self.hidden(x))\n",
" return self.output(a)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLP(\n",
" (hidden): Linear(in_features=784, out_features=256, bias=True)\n",
" (act): ReLU()\n",
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
")\n"
]
},
{
"data": {
"text/plain": [
"tensor([[ 0.0234, -0.2646, -0.1168, -0.2127, 0.0884, -0.0456, 0.0811, 0.0297,\n",
" 0.2032, 0.1364],\n",
" [ 0.1479, -0.1545, -0.0265, -0.2119, -0.0543, -0.0086, 0.0902, -0.1017,\n",
" 0.1504, 0.1144]], grad_fn=<AddmmBackward>)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.rand(2, 784)\n",
"net = MLP()\n",
"print(net)\n",
"net(X)"
]
},
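{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick aside (not in the original text): `net(X)` does not call `forward` directly. It goes through `Module.__call__`, which runs any registered hooks and then dispatches to `forward`. For a plain model like this one, calling `forward` yourself gives the same result, as the sketch below checks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# net(X) dispatches through Module.__call__ to forward; with no hooks or\n",
"# randomness in the model, the two calls return identical tensors\n",
"Y1 = net(X)\n",
"Y2 = net.forward(X)\n",
"print(torch.equal(Y1, Y2))  # True"
]
},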
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.1.2 `Module`的子类\n",
"### 4.1.2.1 `Sequential`类"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class MySequential(nn.Module):\n",
" from collections import OrderedDict\n",
" def __init__(self, *args):\n",
" super(MySequential, self).__init__()\n",
" if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict\n",
" for key, module in args[0].items():\n",
" self.add_module(key, module) # add_module方法会将module添加进self._modules(一个OrderedDict)\n",
" else: # 传入的是一些Module\n",
" for idx, module in enumerate(args):\n",
" self.add_module(str(idx), module)\n",
" def forward(self, input):\n",
" # self._modules返回一个 OrderedDict保证会按照成员添加时的顺序遍历成\n",
" for module in self._modules.values():\n",
" input = module(input)\n",
" return input"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MySequential(\n",
" (0): Linear(in_features=784, out_features=256, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=256, out_features=10, bias=True)\n",
")\n"
]
},
{
"data": {
"text/plain": [
"tensor([[ 0.1273, 0.1642, -0.1060, 0.1401, 0.0609, -0.0199, -0.0140, -0.0588,\n",
" 0.1765, -0.1296],\n",
" [ 0.0267, 0.1670, -0.0626, 0.0744, 0.0574, 0.0413, 0.1313, -0.1479,\n",
" 0.0932, -0.0615]], grad_fn=<AddmmBackward>)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net = MySequential(\n",
" nn.Linear(784, 256),\n",
" nn.ReLU(),\n",
" nn.Linear(256, 10), \n",
" )\n",
"print(net)\n",
"net(X)"
]
},
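{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch (not from the original text): the built-in `nn.Sequential` also accepts an `OrderedDict`, which lets us give the sub-modules names instead of numeric indices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import OrderedDict\n",
"\n",
"# the same network as above, but with named sub-modules\n",
"named_net = nn.Sequential(OrderedDict([\n",
"    ('hidden', nn.Linear(784, 256)),\n",
"    ('act', nn.ReLU()),\n",
"    ('output', nn.Linear(256, 10)),\n",
"]))\n",
"print(named_net)\n",
"named_net(X).shape  # torch.Size([2, 10])"
]
},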
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.1.2.2 `ModuleList`类"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear(in_features=256, out_features=10, bias=True)\n",
"ModuleList(\n",
" (0): Linear(in_features=784, out_features=256, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=256, out_features=10, bias=True)\n",
")\n"
]
}
],
"source": [
"net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])\n",
"net.append(nn.Linear(256, 10)) # # 类似List的append操作\n",
"print(net[-1]) # 类似List的索引访问\n",
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# net(torch.zeros(1, 784)) # 会报NotImplementedError"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class MyModule(nn.Module):\n",
" def __init__(self):\n",
" super(MyModule, self).__init__()\n",
" self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])\n",
"\n",
" def forward(self, x):\n",
" # ModuleList can act as an iterable, or be indexed using ints\n",
" for i, l in enumerate(self.linears):\n",
" x = self.linears[i // 2](x) + l(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"net1:\n",
"torch.Size([10, 10])\n",
"torch.Size([10])\n",
"net2:\n"
]
}
],
"source": [
"class Module_ModuleList(nn.Module):\n",
" def __init__(self):\n",
" super(Module_ModuleList, self).__init__()\n",
" self.linears = nn.ModuleList([nn.Linear(10, 10)])\n",
" \n",
"class Module_List(nn.Module):\n",
" def __init__(self):\n",
" super(Module_List, self).__init__()\n",
" self.linears = [nn.Linear(10, 10)]\n",
"\n",
"net1 = Module_ModuleList()\n",
"net2 = Module_List()\n",
"\n",
"print(\"net1:\")\n",
"for p in net1.parameters():\n",
" print(p.size())\n",
"\n",
"print(\"net2:\")\n",
"for p in net2.parameters():\n",
" print(p)"
]
},
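{
"cell_type": "markdown",
"metadata": {},
"source": [
"`net2` prints nothing above because a plain Python list does not register its contents as sub-modules, so their parameters are invisible to `parameters()`. A quick check (ours, not from the original text):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the ModuleList registers its Linear layer; the plain list does not\n",
"print(len(list(net1.parameters())))  # 2: the Linear layer's weight and bias\n",
"print(len(list(net2.parameters())))  # 0: nothing was registered"
]
},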
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.1.2.3 `ModuleDict`类"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear(in_features=784, out_features=256, bias=True)\n",
"Linear(in_features=256, out_features=10, bias=True)\n",
"ModuleDict(\n",
" (act): ReLU()\n",
" (linear): Linear(in_features=784, out_features=256, bias=True)\n",
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
")\n"
]
}
],
"source": [
"net = nn.ModuleDict({\n",
" 'linear': nn.Linear(784, 256),\n",
" 'act': nn.ReLU(),\n",
"})\n",
"net['output'] = nn.Linear(256, 10) # 添加\n",
"print(net['linear']) # 访问\n",
"print(net.output)\n",
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# net(torch.zeros(1, 784)) # 会报NotImplementedError"
]
},
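{
"cell_type": "markdown",
"metadata": {},
"source": [
"Like `ModuleList`, `ModuleDict` defines no `forward`, so we must write one ourselves. A minimal sketch (the class `DictMLP` is ours, not from the original text) that wires the entries together by key:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DictMLP(nn.Module):\n",
"    def __init__(self):\n",
"        super(DictMLP, self).__init__()\n",
"        self.layers = nn.ModuleDict({\n",
"            'linear': nn.Linear(784, 256),\n",
"            'act': nn.ReLU(),\n",
"            'output': nn.Linear(256, 10),\n",
"        })\n",
"\n",
"    def forward(self, x):\n",
"        # look the sub-modules up by key and chain them explicitly\n",
"        return self.layers['output'](self.layers['act'](self.layers['linear'](x)))\n",
"\n",
"print(DictMLP()(torch.zeros(1, 784)).shape)  # torch.Size([1, 10])"
]
},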
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.1.3 构造复杂的模型"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class FancyMLP(nn.Module):\n",
" def __init__(self, **kwargs):\n",
" super(FancyMLP, self).__init__(**kwargs)\n",
" \n",
" self.rand_weight = torch.rand((20, 20), requires_grad=False) # 不可训练参数(常数参数)\n",
" self.linear = nn.Linear(20, 20)\n",
"\n",
" def forward(self, x):\n",
" x = self.linear(x)\n",
" # 使用创建的常数参数以及nn.functional中的relu函数和mm函数\n",
" x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)\n",
" \n",
" # 复用全连接层。等价于两个全连接层共享参数\n",
" x = self.linear(x)\n",
" # 控制流这里我们需要调用item函数来返回标量进行比较\n",
" while x.norm().item() > 1:\n",
" x /= 2\n",
" if x.norm().item() < 0.8:\n",
" x *= 10\n",
" return x.sum()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FancyMLP(\n",
" (linear): Linear(in_features=20, out_features=20, bias=True)\n",
")\n"
]
},
{
"data": {
"text/plain": [
"tensor(0.8907, grad_fn=<SumBackward0>)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.rand(2, 20)\n",
"net = FancyMLP()\n",
"print(net)\n",
"net(X)"
]
},
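{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `rand_weight` is a plain tensor rather than an `nn.Parameter`, so it does not appear in `parameters()`; only the shared `Linear` layer's weight and bias do. The check below is ours, not from the original text. If the constant should follow the module across devices (e.g. `net.cuda()`), registering it with `register_buffer` is the usual approach."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# only the Linear layer's parameters are registered; rand_weight is absent\n",
"for name, param in net.named_parameters():\n",
"    print(name, param.shape)\n",
"# expected: linear.weight torch.Size([20, 20]) and linear.bias torch.Size([20])"
]
},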
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): NestMLP(\n",
" (net): Sequential(\n",
" (0): Linear(in_features=40, out_features=30, bias=True)\n",
" (1): ReLU()\n",
" )\n",
" )\n",
" (1): Linear(in_features=30, out_features=20, bias=True)\n",
" (2): FancyMLP(\n",
" (linear): Linear(in_features=20, out_features=20, bias=True)\n",
" )\n",
")\n"
]
},
{
"data": {
"text/plain": [
"tensor(-0.4605, grad_fn=<SumBackward0>)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class NestMLP(nn.Module):\n",
" def __init__(self, **kwargs):\n",
" super(NestMLP, self).__init__(**kwargs)\n",
" self.net = nn.Sequential(nn.Linear(40, 30), nn.ReLU()) \n",
"\n",
" def forward(self, x):\n",
" return self.net(x)\n",
"\n",
"net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())\n",
"\n",
"X = torch.rand(2, 40)\n",
"print(net)\n",
"net(X)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}