{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.12 稠密连接网络DenseNet"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.0\n",
"cuda\n"
]
}
],
"source": [
"import time\n",
"import torch\n",
"from torch import nn, optim\n",
"import torch.nn.functional as F\n",
"\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"print(torch.__version__)\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.12.1 稠密块"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def conv_block(in_channels, out_channels):\n",
" blk = nn.Sequential(nn.BatchNorm2d(in_channels), \n",
" nn.ReLU(),\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))\n",
" return blk"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class DenseBlock(nn.Module):\n",
" def __init__(self, num_convs, in_channels, out_channels):\n",
" super(DenseBlock, self).__init__()\n",
" net = []\n",
" for i in range(num_convs):\n",
" in_c = in_channels + i * out_channels\n",
" net.append(conv_block(in_c, out_channels))\n",
" self.net = nn.ModuleList(net)\n",
" self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数\n",
"\n",
" def forward(self, X):\n",
" for blk in self.net:\n",
" Y = blk(X)\n",
" X = torch.cat((X, Y), dim=1) # 在通道维上将输入和输出连结\n",
" return X"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 23, 8, 8])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"blk = DenseBlock(2, 3, 10)\n",
"X = torch.rand(4, 3, 8, 8)\n",
"Y = blk(X)\n",
"Y.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.12.2 过渡层"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def transition_block(in_channels, out_channels):\n",
" blk = nn.Sequential(\n",
" nn.BatchNorm2d(in_channels), \n",
" nn.ReLU(),\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1),\n",
" nn.AvgPool2d(kernel_size=2, stride=2))\n",
" return blk"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 10, 4, 4])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"blk = transition_block(23, 10)\n",
"blk(Y).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.12.3 DenseNet模型"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"net = nn.Sequential(\n",
" nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
" nn.BatchNorm2d(64), \n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"num_channels, growth_rate = 64, 32 # num_channels为当前的通道数\n",
"num_convs_in_dense_blocks = [4, 4, 4, 4]\n",
"\n",
"for i, num_convs in enumerate(num_convs_in_dense_blocks):\n",
" DB = DenseBlock(num_convs, num_channels, growth_rate)\n",
" net.add_module(\"DenseBlosk_%d\" % i, DB)\n",
" # 上一个稠密块的输出通道数\n",
" num_channels = DB.out_channels\n",
" # 在稠密块之间加入通道数减半的过渡层\n",
" if i != len(num_convs_in_dense_blocks) - 1:\n",
" net.add_module(\"transition_block_%d\" % i, transition_block(num_channels, num_channels // 2))\n",
" num_channels = num_channels // 2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"net.add_module(\"BN\", nn.BatchNorm2d(num_channels))\n",
"net.add_module(\"relu\", nn.ReLU())\n",
"net.add_module(\"global_avg_pool\", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)\n",
"net.add_module(\"fc\", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 output shape:\t torch.Size([1, 64, 48, 48])\n",
"1 output shape:\t torch.Size([1, 64, 48, 48])\n",
"2 output shape:\t torch.Size([1, 64, 48, 48])\n",
"3 output shape:\t torch.Size([1, 64, 24, 24])\n",
"DenseBlosk_0 output shape:\t torch.Size([1, 192, 24, 24])\n",
"transition_block_0 output shape:\t torch.Size([1, 96, 12, 12])\n",
"DenseBlosk_1 output shape:\t torch.Size([1, 224, 12, 12])\n",
"transition_block_1 output shape:\t torch.Size([1, 112, 6, 6])\n",
"DenseBlosk_2 output shape:\t torch.Size([1, 240, 6, 6])\n",
"transition_block_2 output shape:\t torch.Size([1, 120, 3, 3])\n",
"DenseBlosk_3 output shape:\t torch.Size([1, 248, 3, 3])\n",
"BN output shape:\t torch.Size([1, 248, 3, 3])\n",
"relu output shape:\t torch.Size([1, 248, 3, 3])\n",
"global_avg_pool output shape:\t torch.Size([1, 248, 1, 1])\n",
"fc output shape:\t torch.Size([1, 10])\n"
]
}
],
"source": [
"X = torch.rand((1, 1, 96, 96))\n",
"for name, layer in net.named_children():\n",
" X = layer(X)\n",
" print(name, ' output shape:\\t', X.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.12.4 获取数据并训练模型"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 0.0020, train acc 0.834, test acc 0.749, time 27.7 sec\n",
"epoch 2, loss 0.0011, train acc 0.900, test acc 0.824, time 25.5 sec\n",
"epoch 3, loss 0.0009, train acc 0.913, test acc 0.839, time 23.8 sec\n",
"epoch 4, loss 0.0008, train acc 0.921, test acc 0.889, time 24.9 sec\n",
"epoch 5, loss 0.0008, train acc 0.929, test acc 0.884, time 24.3 sec\n"
]
}
],
"source": [
"batch_size = 256\n",
"# 如出现“out of memory”的报错信息可减小batch_size或resize\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
"\n",
"lr, num_epochs = 0.001, 5\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
"d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}