You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

307 lines
8.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.5 卷积神经网络(LeNet)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-29T13:57:37.383972Z",
"start_time": "2019-05-29T13:57:34.520559Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0.0\n",
"cuda\n"
]
}
],
"source": [
"import os\n",
"import time\n",
"import torch\n",
"from torch import nn, optim\n",
"\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"print(torch.__version__)\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.5.1 LeNet模型 "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-29T13:57:37.394997Z",
"start_time": "2019-05-29T13:57:37.386720Z"
}
},
"outputs": [],
"source": [
"class LeNet(nn.Module):\n",
" def __init__(self):\n",
" super(LeNet, self).__init__()\n",
" self.conv = nn.Sequential(\n",
" nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size\n",
" nn.Sigmoid(),\n",
" nn.MaxPool2d(2, 2), # kernel_size, stride\n",
" nn.Conv2d(6, 16, 5),\n",
" nn.Sigmoid(),\n",
" nn.MaxPool2d(2, 2)\n",
" )\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(16*4*4, 120),\n",
" nn.Sigmoid(),\n",
" nn.Linear(120, 84),\n",
" nn.Sigmoid(),\n",
" nn.Linear(84, 10)\n",
" )\n",
"\n",
" def forward(self, img):\n",
" feature = self.conv(img)\n",
" output = self.fc(feature.view(img.shape[0], -1))\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-29T13:57:37.450484Z",
"start_time": "2019-05-29T13:57:37.397357Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LeNet(\n",
" (conv): Sequential(\n",
" (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
" (1): Sigmoid()\n",
" (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
" (4): Sigmoid()\n",
" (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (fc): Sequential(\n",
" (0): Linear(in_features=256, out_features=120, bias=True)\n",
" (1): Sigmoid()\n",
" (2): Linear(in_features=120, out_features=84, bias=True)\n",
" (3): Sigmoid()\n",
" (4): Linear(in_features=84, out_features=10, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"net = LeNet()\n",
"print(net)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.5.2 获取数据和训练模型"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-29T13:57:38.432567Z",
"start_time": "2019-05-29T13:57:37.452521Z"
}
},
"outputs": [],
"source": [
"batch_size = 256\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-29T13:57:38.442887Z",
"start_time": "2019-05-29T13:57:38.435111Z"
}
},
"outputs": [],
"source": [
"# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进它的完整实现将在“图像增广”一节中描述\n",
"def evaluate_accuracy(data_iter, net, device=None):\n",
" if device is None and isinstance(net, torch.nn.Module):\n",
" # 如果没指定device就使用net的device\n",
" device = list(net.parameters())[0].device\n",
" acc_sum, n = 0.0, 0\n",
" with torch.no_grad():\n",
" for X, y in data_iter:\n",
" if isinstance(net, torch.nn.Module):\n",
" net.eval() # 评估模式, 这会关闭dropout\n",
" acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()\n",
" net.train() # 改回训练模式\n",
" else: # 自定义的模型, 3.13节之后不会用到, 不考虑GPU\n",
" if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数\n",
" # 将is_training设置成False\n",
" acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() \n",
" else:\n",
" acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() \n",
" n += y.shape[0]\n",
" return acc_sum / n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-29T13:57:38.453480Z",
"start_time": "2019-05-29T13:57:38.445655Z"
}
},
"outputs": [],
"source": [
"# 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
"def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):\n",
" net = net.to(device)\n",
" print(\"training on \", device)\n",
" loss = torch.nn.CrossEntropyLoss()\n",
" batch_count = 0\n",
" for epoch in range(num_epochs):\n",
" train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()\n",
" for X, y in train_iter:\n",
" X = X.to(device)\n",
" y = y.to(device)\n",
" y_hat = net(X)\n",
" l = loss(y_hat, y)\n",
" optimizer.zero_grad()\n",
" l.backward()\n",
" optimizer.step()\n",
" train_l_sum += l.cpu().item()\n",
" train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()\n",
" n += y.shape[0]\n",
" batch_count += 1\n",
" test_acc = evaluate_accuracy(test_iter, net)\n",
" print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\n",
" % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-29T13:58:00.333237Z",
"start_time": "2019-05-29T13:57:38.456012Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 1.7885, train acc 0.337, test acc 0.584, time 2.4 sec\n",
"epoch 2, loss 0.4793, train acc 0.614, test acc 0.666, time 2.3 sec\n",
"epoch 3, loss 0.2637, train acc 0.704, test acc 0.720, time 2.3 sec\n",
"epoch 4, loss 0.1747, train acc 0.734, test acc 0.740, time 2.2 sec\n",
"epoch 5, loss 0.1282, train acc 0.751, test acc 0.749, time 2.2 sec\n"
]
}
],
"source": [
"lr, num_epochs = 0.001, 5\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
"train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}