{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.7 使用重复元素的网络VGG"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.0\n",
"cuda\n"
]
}
],
"source": [
"import time\n",
"import torch\n",
"from torch import nn, optim\n",
"\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"print(torch.__version__)\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.7.1 VGG块"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def vgg_block(num_convs, in_channels, out_channels):\n",
" blk = []\n",
" for i in range(num_convs):\n",
" if i == 0:\n",
" blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))\n",
" else:\n",
" blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))\n",
" blk.append(nn.ReLU())\n",
" blk.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
" return nn.Sequential(*blk)"
]
},
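{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch added here, not part of the original notebook): the 3×3 convolutions with padding 1 preserve the spatial size, so a `vgg_block` maps `in_channels` to `out_channels` and only the final 2×2 max pooling halves the height and width."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Feed a random input through a single block: 2 convs, 8 -> 16 channels.\n",
"blk = vgg_block(2, 8, 16)\n",
"X_blk = torch.rand(1, 8, 32, 32)\n",
"# Expect torch.Size([1, 16, 16, 16]): channels changed, height/width halved.\n",
"print(blk(X_blk).shape)"
]
},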
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.7.2 VGG网络"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"conv_arch = ((1, 1, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))\n",
"fc_features = 512 * 7 * 7 # 根据卷积层的输出算出来的\n",
"fc_hidden_units = 4096 # 任意"
]
},
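{
"cell_type": "markdown",
"metadata": {},
"source": [
"The value of `fc_features` follows directly from the architecture: each of the five VGG blocks halves the height and width, so a $224 \\times 224$ input leaves the last block at $224 / 2^5 = 7$, and with 512 output channels the flattened feature vector has $512 \\times 7 \\times 7 = 25088$ elements."
]
},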
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def vgg(conv_arch, fc_features, fc_hidden_units=4096):\n",
" net = nn.Sequential()\n",
" # 卷积层部分\n",
" for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):\n",
" net.add_module(\"vgg_block_\" + str(i+1), vgg_block(num_convs, in_channels, out_channels))\n",
" # 全连接层部分\n",
" net.add_module(\"fc\", nn.Sequential(d2l.FlattenLayer(),\n",
" nn.Linear(fc_features, fc_hidden_units),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.5),\n",
" nn.Linear(fc_hidden_units, fc_hidden_units),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.5),\n",
" nn.Linear(fc_hidden_units, 10)\n",
" ))\n",
" return net"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"vgg_block_1 output shape: torch.Size([1, 64, 112, 112])\n",
"vgg_block_2 output shape: torch.Size([1, 128, 56, 56])\n",
"vgg_block_3 output shape: torch.Size([1, 256, 28, 28])\n",
"vgg_block_4 output shape: torch.Size([1, 512, 14, 14])\n",
"vgg_block_5 output shape: torch.Size([1, 512, 7, 7])\n",
"fc output shape: torch.Size([1, 10])\n"
]
}
],
"source": [
"net = vgg(conv_arch, fc_features, fc_hidden_units)\n",
"X = torch.rand(1, 1, 224, 224)\n",
"\n",
"# named_children获取一级子模块及其名字(named_modules会返回所有子模块,包括子模块的子模块)\n",
"for name, blk in net.named_children(): \n",
" X = blk(X)\n",
" print(name, 'output shape: ', X.shape)"
]
},
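{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before shrinking the network with `ratio` below, it helps to see why the full-size model is so heavy. The following check (added here as an illustration, not part of the original run) counts parameters per top-level module; the fully connected part dominates by a wide margin."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count parameters in each top-level module of the full-size network.\n",
"for name, blk in net.named_children():\n",
"    n_params = sum(p.numel() for p in blk.parameters())\n",
"    print(name, '{:,} parameters'.format(n_params))\n",
"print('total: {:,} parameters'.format(sum(p.numel() for p in net.parameters())))"
]
},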
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (vgg_block_1): Sequential(\n",
" (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU()\n",
" (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (vgg_block_2): Sequential(\n",
" (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU()\n",
" (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (vgg_block_3): Sequential(\n",
" (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU()\n",
" (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (3): ReLU()\n",
" (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (vgg_block_4): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU()\n",
" (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (3): ReLU()\n",
" (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (vgg_block_5): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU()\n",
" (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (3): ReLU()\n",
" (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (fc): Sequential(\n",
" (0): FlattenLayer()\n",
" (1): Linear(in_features=3136, out_features=512, bias=True)\n",
" (2): ReLU()\n",
" (3): Dropout(p=0.5)\n",
" (4): Linear(in_features=512, out_features=512, bias=True)\n",
" (5): ReLU()\n",
" (6): Dropout(p=0.5)\n",
" (7): Linear(in_features=512, out_features=10, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"ratio = 8\n",
"small_conv_arch = [(1, 1, 64//ratio), (1, 64//ratio, 128//ratio), (2, 128//ratio, 256//ratio), \n",
" (2, 256//ratio, 512//ratio), (2, 512//ratio, 512//ratio)]\n",
"net = vgg(small_conv_arch, fc_features // ratio, fc_hidden_units // ratio)\n",
"print(net)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.7.3 获取数据和训练模型"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 0.0101, train acc 0.755, test acc 0.859, time 255.9 sec\n",
"epoch 2, loss 0.0051, train acc 0.882, test acc 0.902, time 238.1 sec\n",
"epoch 3, loss 0.0043, train acc 0.900, test acc 0.908, time 225.5 sec\n",
"epoch 4, loss 0.0038, train acc 0.913, test acc 0.914, time 230.3 sec\n",
"epoch 5, loss 0.0035, train acc 0.919, test acc 0.918, time 153.9 sec\n"
]
}
],
"source": [
"batch_size = 64\n",
"# 如出现“out of memory”的报错信息可减小batch_size或resize\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
"\n",
"lr, num_epochs = 0.001, 5\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
"d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
]
},
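{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final check (a hedged sketch added here; it assumes `d2l.train_ch5` has left `net` usable on `device`, which we enforce explicitly), we can take one batch from the test set and compare the trained network's predictions with the true labels using plain PyTorch calls."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = net.to(device)  # make sure the model sits on the evaluation device\n",
"net.eval()  # switch dropout to evaluation mode\n",
"with torch.no_grad():\n",
"    X_test, y_test = next(iter(test_iter))\n",
"    X_test, y_test = X_test.to(device), y_test.to(device)\n",
"    y_hat = net(X_test).argmax(dim=1)\n",
"    print('accuracy on one test batch:', (y_hat == y_test).float().mean().item())"
]
},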
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}