You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

291 lines
8.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.6 深度卷积神经网络AlexNet"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-03-19T07:36:45.657048Z",
"start_time": "2019-03-19T07:36:45.285668Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.0\n",
"0.2.1\n",
"cuda\n"
]
}
],
"source": [
"import sys\n",
"import time\n",
"\n",
"import torch\n",
"import torchvision\n",
"from torch import nn, optim\n",
"\n",
"# Make the d2lzh_pytorch helper package in the parent directory importable\n",
"sys.path.append(\"..\")\n",
"import d2lzh_pytorch as d2l\n",
"\n",
"# Run on the GPU when one is available, otherwise fall back to the CPU\n",
"if torch.cuda.is_available():\n",
"    device = torch.device('cuda')\n",
"else:\n",
"    device = torch.device('cpu')\n",
"\n",
"print(torch.__version__)\n",
"print(torchvision.__version__)\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.6.2 AlexNet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-03-19T07:36:45.703036Z",
"start_time": "2019-03-19T07:36:45.658231Z"
}
},
"outputs": [],
"source": [
"class AlexNet(nn.Module):\n",
"    \"\"\"AlexNet with its input and output sizes adapted for Fashion-MNIST.\n",
"\n",
"    Args:\n",
"        in_channels: number of channels of the input images (1 for grayscale).\n",
"        num_classes: number of output logits (10 for Fashion-MNIST rather than\n",
"            the 1000 classes used in the original paper).\n",
"\n",
"    The first fully connected layer hard-codes 256*5*5 input features, which is\n",
"    the size of the convolutional feature map for 224x224 inputs.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=1, num_classes=10):\n",
"        super(AlexNet, self).__init__()\n",
"        self.conv = nn.Sequential(\n",
"            # args: in_channels, out_channels, kernel_size, stride (padding defaults to 0)\n",
"            nn.Conv2d(in_channels, 96, 11, 4),\n",
"            nn.ReLU(),\n",
"            nn.MaxPool2d(3, 2),  # kernel_size, stride\n",
"            # Smaller kernel; padding 2 keeps height and width unchanged while the\n",
"            # number of output channels grows to 256\n",
"            nn.Conv2d(96, 256, 5, 1, 2),\n",
"            nn.ReLU(),\n",
"            nn.MaxPool2d(3, 2),\n",
"            # Three consecutive 3x3 convolutions (padding 1 keeps height and width).\n",
"            # The first two output 384 channels, the last returns to 256; there is\n",
"            # no pooling between them, so height and width are not reduced here\n",
"            nn.Conv2d(256, 384, 3, 1, 1),\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(384, 384, 3, 1, 1),\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(384, 256, 3, 1, 1),\n",
"            nn.ReLU(),\n",
"            nn.MaxPool2d(3, 2)\n",
"        )\n",
"        # The fully connected layers are several times wider than LeNet's;\n",
"        # dropout is used to mitigate overfitting\n",
"        self.fc = nn.Sequential(\n",
"            nn.Linear(256*5*5, 4096),\n",
"            nn.ReLU(),\n",
"            nn.Dropout(0.5),\n",
"            nn.Linear(4096, 4096),\n",
"            nn.ReLU(),\n",
"            nn.Dropout(0.5),\n",
"            # Output layer: one logit per class\n",
"            nn.Linear(4096, num_classes),\n",
"        )\n",
"\n",
"    def forward(self, img):\n",
"        \"\"\"Map a batch of images to class logits of shape (batch, num_classes).\"\"\"\n",
"        feature = self.conv(img)\n",
"        # Flatten each sample's feature map before the fully connected layers\n",
"        output = self.fc(feature.view(img.shape[0], -1))\n",
"        return output"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-03-19T07:36:46.053598Z",
"start_time": "2019-03-19T07:36:45.704356Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AlexNet(\n",
" (conv): Sequential(\n",
" (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))\n",
" (1): ReLU()\n",
" (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
" (4): ReLU()\n",
" (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (7): ReLU()\n",
" (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (9): ReLU()\n",
" (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (11): ReLU()\n",
" (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (fc): Sequential(\n",
" (0): Linear(in_features=6400, out_features=4096, bias=True)\n",
" (1): ReLU()\n",
" (2): Dropout(p=0.5)\n",
" (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
" (4): ReLU()\n",
" (5): Dropout(p=0.5)\n",
" (6): Linear(in_features=4096, out_features=10, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"# Build the network with its default Fashion-MNIST configuration and\n",
"# print its layer-by-layer structure\n",
"net = AlexNet()\n",
"print(net)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.6.3 读取数据"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-03-19T07:36:46.066761Z",
"start_time": "2019-03-19T07:36:46.054928Z"
}
},
"outputs": [],
"source": [
"# This function is also saved in the d2lzh_pytorch package for later reuse\n",
"def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST', num_workers=None):\n",
"    \"\"\"Download Fashion-MNIST if needed and return (train_iter, test_iter) DataLoaders.\n",
"\n",
"    Args:\n",
"        batch_size: number of samples per batch.\n",
"        resize: optional target size passed to transforms.Resize before ToTensor.\n",
"        root: directory the dataset is read from / downloaded to.\n",
"        num_workers: DataLoader worker processes; None picks 0 on Windows and 4 elsewhere.\n",
"    \"\"\"\n",
"    if num_workers is None:\n",
"        # DataLoader workers are spawned on Windows, which is unreliable when run\n",
"        # from a notebook, so load data in the main process there.\n",
"        num_workers = 0 if sys.platform.startswith('win') else 4\n",
"\n",
"    trans = []\n",
"    if resize:\n",
"        trans.append(torchvision.transforms.Resize(size=resize))\n",
"    trans.append(torchvision.transforms.ToTensor())\n",
"    transform = torchvision.transforms.Compose(trans)\n",
"\n",
"    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)\n",
"    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)\n",
"\n",
"    # Shuffle only the training set; evaluation order does not matter\n",
"    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n",
"    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
"\n",
"    return train_iter, test_iter"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-03-19T07:36:46.091524Z",
"start_time": "2019-03-19T07:36:46.067835Z"
}
},
"outputs": [],
"source": [
"batch_size = 128\n",
"# Images are resized to 224x224, the input size the AlexNet classifier assumes.\n",
"# If you hit an \"out of memory\" error, lower batch_size or resize.\n",
"train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size, resize=224)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.6.4 训练"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-03-19T07:36:47.850402Z",
"start_time": "2019-03-19T07:36:46.092485Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 0.0047, train acc 0.770, test acc 0.865, time 128.3 sec\n",
"epoch 2, loss 0.0025, train acc 0.879, test acc 0.889, time 128.8 sec\n",
"epoch 3, loss 0.0022, train acc 0.898, test acc 0.901, time 130.4 sec\n",
"epoch 4, loss 0.0019, train acc 0.908, test acc 0.900, time 131.4 sec\n",
"epoch 5, loss 0.0018, train acc 0.913, test acc 0.902, time 129.9 sec\n"
]
}
],
"source": [
"lr = 0.001\n",
"num_epochs = 5\n",
"# Adam over all AlexNet parameters, trained on `device` via the d2l helper\n",
"optimizer = optim.Adam(net.parameters(), lr=lr)\n",
"d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}