You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

178 lines
4.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.8 网络中的网络NiN"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.0\n",
"cuda\n"
]
}
],
"source": [
"import time\n",
"import torch\n",
"from torch import nn, optim\n",
"\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"\n",
"# Echo the PyTorch version and the selected device alongside the results.\n",
"print(torch.__version__)\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.8.1 NiN块"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def nin_block(in_channels, out_channels, kernel_size, stride, padding):\n",
"    \"\"\"Build one NiN block as an ``nn.Sequential`` of six layers.\n",
"\n",
"    The block is a ``kernel_size`` convolution (using the given ``stride`` and\n",
"    ``padding``) that maps ``in_channels`` to ``out_channels``, followed by two\n",
"    1x1 convolutions that keep ``out_channels``. Every convolution is followed\n",
"    by a ReLU.\n",
"    \"\"\"\n",
"    layers = [\n",
"        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),\n",
"        nn.ReLU(),\n",
"    ]\n",
"    # Append the two (1x1 convolution, ReLU) pairs.\n",
"    for _ in range(2):\n",
"        layers.extend([nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU()])\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.8.2 NiN模型"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# NiN model: three NiN blocks, each followed by 3x3 max pooling with stride 2,\n",
"# then dropout, a final NiN block, global average pooling and a flatten step.\n",
"net = nn.Sequential(\n",
"    nin_block(1, 96, kernel_size=11, stride=4, padding=0),\n",
"    nn.MaxPool2d(kernel_size=3, stride=2),\n",
"    nin_block(96, 256, kernel_size=5, stride=1, padding=2),\n",
"    nn.MaxPool2d(kernel_size=3, stride=2),\n",
"    nin_block(256, 384, kernel_size=3, stride=1, padding=1),\n",
"    nn.MaxPool2d(kernel_size=3, stride=2),\n",
"    nn.Dropout(0.5),\n",
"    # The number of label classes is 10, so the last block emits 10 channels.\n",
"    nin_block(384, 10, kernel_size=3, stride=1, padding=1),\n",
"    # Global average pooling. Adaptive pooling to a 1x1 output averages over the\n",
"    # whole feature map whatever its height and width, so the model keeps working\n",
"    # when the input is resized (a fixed 5x5 window only fits 224x224 inputs).\n",
"    nn.AdaptiveAvgPool2d(1),\n",
"    # Convert the 4-D output to a 2-D output of shape (batch size, 10).\n",
"    d2l.FlattenLayer())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 output shape: torch.Size([1, 96, 54, 54])\n",
"1 output shape: torch.Size([1, 96, 26, 26])\n",
"2 output shape: torch.Size([1, 256, 26, 26])\n",
"3 output shape: torch.Size([1, 256, 12, 12])\n",
"4 output shape: torch.Size([1, 384, 12, 12])\n",
"5 output shape: torch.Size([1, 384, 5, 5])\n",
"6 output shape: torch.Size([1, 384, 5, 5])\n",
"7 output shape: torch.Size([1, 10, 5, 5])\n",
"8 output shape: torch.Size([1, 10, 1, 1])\n",
"9 output shape: torch.Size([1, 10])\n"
]
}
],
"source": [
"# Feed one random single-channel 224x224 input through the network child by\n",
"# child, printing the output shape after each one.\n",
"input_shape = (1, 1, 224, 224)\n",
"X = torch.rand(*input_shape)\n",
"\n",
"for layer_name, layer in net.named_children():\n",
"    X = layer(X)\n",
"    print(layer_name, 'output shape: ', X.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.8.3 获取数据和训练模型"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 0.0101, train acc 0.513, test acc 0.734, time 260.9 sec\n",
"epoch 2, loss 0.0050, train acc 0.763, test acc 0.754, time 175.1 sec\n",
"epoch 3, loss 0.0041, train acc 0.808, test acc 0.826, time 151.0 sec\n",
"epoch 4, loss 0.0037, train acc 0.828, test acc 0.827, time 151.0 sec\n",
"epoch 5, loss 0.0034, train acc 0.839, test acc 0.831, time 151.0 sec\n"
]
}
],
"source": [
"# Hyperparameters for this run.\n",
"batch_size, lr, num_epochs = 128, 0.002, 5\n",
"\n",
"# If an 'out of memory' error is reported, reduce batch_size or resize.\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
"\n",
"# Adam over every parameter of the network.\n",
"optimizer = optim.Adam(net.parameters(), lr=lr)\n",
"d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}