You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

262 lines
7.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.11 残差网络ResNet"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.0\n",
"cuda\n"
]
}
],
"source": [
"import time\n",
"import torch\n",
"from torch import nn, optim\n",
"import torch.nn.functional as F\n",
"\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"\n",
"print(torch.__version__)\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.11.1 残差块"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class Residual(nn.Module):  # This class is also saved in the d2lzh_pytorch package for later use\n",
"    \"\"\"Residual block: two 3x3 convolutions with batch norm plus a skip connection.\n",
"\n",
"    Args:\n",
"        in_channels: number of channels of the input X.\n",
"        out_channels: number of channels produced by both 3x3 convolutions.\n",
"        use_1x1conv: if True, X is passed through a 1x1 convolution (using the\n",
"            same stride) before the addition, so that its channel count and\n",
"            spatial size match the main branch.\n",
"        stride: stride of the first 3x3 convolution and of the optional 1x1\n",
"            convolution.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):\n",
"        super(Residual, self).__init__()\n",
"        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)\n",
"        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)\n",
"        if use_1x1conv:\n",
"            # Projection used on the skip path so that X can be added to Y.\n",
"            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)\n",
"        else:\n",
"            self.conv3 = None\n",
"        self.bn1 = nn.BatchNorm2d(out_channels)\n",
"        self.bn2 = nn.BatchNorm2d(out_channels)\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Return relu(main_branch(X) + skip(X)).\"\"\"\n",
"        Y = F.relu(self.bn1(self.conv1(X)))\n",
"        Y = self.bn2(self.conv2(Y))\n",
"        # Compare against None explicitly instead of relying on the truthiness\n",
"        # of an nn.Module instance.\n",
"        if self.conv3 is not None:\n",
"            X = self.conv3(X)\n",
"        return F.relu(Y + X)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 3, 6, 6])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# With the default arguments the block keeps the input shape unchanged.\n",
"blk = Residual(in_channels=3, out_channels=3)\n",
"X = torch.rand(4, 3, 6, 6)\n",
"blk(X).shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 6, 3, 3])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# With use_1x1conv=True and stride=2 the block changes the channels from 3 to 6\n",
"# and the height/width from 6 to 3.\n",
"blk = Residual(in_channels=3, out_channels=6, use_1x1conv=True, stride=2)\n",
"blk(X).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.11.2 ResNet模型"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Stem of the network: 7x7 stride-2 convolution on a single-channel input,\n",
"# batch norm, ReLU, then 3x3 stride-2 max pooling.\n",
"net = nn.Sequential(\n",
"    nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),\n",
"    nn.BatchNorm2d(num_features=64),\n",
"    nn.ReLU(),\n",
"    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def resnet_block(in_channels, out_channels, num_residuals, first_block=False):\n",
"    \"\"\"Stack num_residuals Residual blocks into one nn.Sequential stage.\n",
"\n",
"    Unless first_block is True, the first Residual of the stage uses a 1x1\n",
"    convolution on the skip path and stride 2, mapping in_channels to\n",
"    out_channels. Every other Residual maps out_channels to out_channels\n",
"    with the default stride.\n",
"\n",
"    Args:\n",
"        in_channels: channels of the tensor entering the stage.\n",
"        out_channels: channels of the tensor leaving the stage.\n",
"        num_residuals: number of Residual blocks in the stage.\n",
"        first_block: True for the stage that directly follows the stem; it\n",
"            requires in_channels == out_channels.\n",
"\n",
"    Returns:\n",
"        An nn.Sequential containing the Residual blocks in order.\n",
"\n",
"    Raises:\n",
"        AssertionError: if first_block is True and in_channels != out_channels.\n",
"    \"\"\"\n",
"    if first_block:\n",
"        # The first stage keeps the channel count of its input.\n",
"        assert in_channels == out_channels, \"first_block requires in_channels == out_channels\"\n",
"    blk = []\n",
"    for i in range(num_residuals):\n",
"        if i == 0 and not first_block:\n",
"            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))\n",
"        else:\n",
"            blk.append(Residual(out_channels, out_channels))\n",
"    return nn.Sequential(*blk)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Four stages of two residual blocks each; every stage after the first doubles\n",
"# the number of channels.\n",
"net.add_module(\"resnet_block1\", resnet_block(in_channels=64, out_channels=64, num_residuals=2, first_block=True))\n",
"net.add_module(\"resnet_block2\", resnet_block(in_channels=64, out_channels=128, num_residuals=2))\n",
"net.add_module(\"resnet_block3\", resnet_block(in_channels=128, out_channels=256, num_residuals=2))\n",
"net.add_module(\"resnet_block4\", resnet_block(in_channels=256, out_channels=512, num_residuals=2))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Output of GlobalAvgPool2d: (Batch, 512, 1, 1)\n",
"net.add_module(\"global_avg_pool\", d2l.GlobalAvgPool2d())\n",
"# Flatten, then map the 512 features to 10 outputs.\n",
"net.add_module(\"fc\", nn.Sequential(d2l.FlattenLayer(), nn.Linear(in_features=512, out_features=10)))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 output shape:\t torch.Size([1, 64, 112, 112])\n",
"1 output shape:\t torch.Size([1, 64, 112, 112])\n",
"2 output shape:\t torch.Size([1, 64, 112, 112])\n",
"3 output shape:\t torch.Size([1, 64, 56, 56])\n",
"resnet_block1 output shape:\t torch.Size([1, 64, 56, 56])\n",
"resnet_block2 output shape:\t torch.Size([1, 128, 28, 28])\n",
"resnet_block3 output shape:\t torch.Size([1, 256, 14, 14])\n",
"resnet_block4 output shape:\t torch.Size([1, 512, 7, 7])\n",
"global_avg_pool output shape:\t torch.Size([1, 512, 1, 1])\n",
"fc output shape:\t torch.Size([1, 10])\n"
]
}
],
"source": [
"# Feed a random single-channel 224x224 input through the network and print\n",
"# the output shape after every top-level child module.\n",
"X = torch.rand(1, 1, 224, 224)\n",
"for name, layer in net.named_children():\n",
"    X = layer(X)\n",
"    print(name, ' output shape:\\t', X.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.11.3 获取数据和训练模型"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 0.0015, train acc 0.853, test acc 0.885, time 31.0 sec\n",
"epoch 2, loss 0.0010, train acc 0.910, test acc 0.899, time 31.8 sec\n",
"epoch 3, loss 0.0008, train acc 0.926, test acc 0.911, time 31.6 sec\n",
"epoch 4, loss 0.0007, train acc 0.936, test acc 0.916, time 31.8 sec\n",
"epoch 5, loss 0.0006, train acc 0.944, test acc 0.926, time 31.5 sec\n"
]
}
],
"source": [
"# If an out-of-memory error is reported, reduce batch_size or resize.\n",
"batch_size = 256\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
"\n",
"lr = 0.001\n",
"num_epochs = 5\n",
"optimizer = optim.Adam(net.parameters(), lr=lr)\n",
"d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}