Dive-into-DL-PyTorch/code/chapter09_computer-vision/9.2_fine-tuning.ipynb

366 lines
191 KiB

3 years ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 9.2 微调"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:15.965701Z",
"start_time": "2019-06-03T14:11:00.668216Z"
}
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import random\n",
"import sys\n",
"\n",
"import torch\n",
"from torch import nn, optim\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision\n",
"from torchvision.datasets import ImageFolder\n",
"from torchvision import transforms\n",
"from torchvision import models\n",
"\n",
"sys.path.append('..')  # make the repo-local d2lzh_pytorch module importable\n",
"import d2lzh_pytorch as d2l\n",
"\n",
"# Optionally pin one GPU, e.g. GPU_ID = '1'. This must be set before CUDA is\n",
"# first queried below. The default (None) leaves all GPUs visible; forcing '1'\n",
"# would hide the only GPU on a single-GPU machine and silently train on the CPU.\n",
"GPU_ID = None\n",
"if GPU_ID is not None:\n",
"    os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# Seed Python's and PyTorch's RNGs (new fc-layer init, DataLoader shuffling,\n",
"# random augmentations) so reruns are comparable; cuDNN may still be nondeterministic.\n",
"SEED = 0\n",
"random.seed(SEED)\n",
"torch.manual_seed(SEED);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9.2.1 热狗识别\n",
"### 9.2.1.1 获取数据集"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:15.973612Z",
"start_time": "2019-06-03T14:11:15.967565Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"['train', 'test']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Directory that contains the unzipped `hotdog` dataset. Set the D2L_DATA_DIR\n",
"# environment variable to point elsewhere instead of editing this absolute path.\n",
"data_dir = os.environ.get('D2L_DATA_DIR', '/S1/CSCL/tangss/Datasets')\n",
"os.listdir(os.path.join(data_dir, 'hotdog'))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:16.061146Z",
"start_time": "2019-06-03T14:11:15.975384Z"
}
},
"outputs": [],
"source": [
"# ImageFolder labels each image by the name of the sub-directory it lives in.\n",
"train_imgs, test_imgs = [ImageFolder(os.path.join(data_dir, 'hotdog/' + split))\n",
"                         for split in ('train', 'test')]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:17.118164Z",
"start_time": "2019-06-03T14:11:16.062645Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoIAAACtCAYAAAAklR7fAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzs3Xm4ZVdZ4P/v2vM+87lj3VvzmEpVKjMECGEwCCKDIAE0CC2GnzaICI0IAhGZFGnF9ofi0I8DONCAIDYIgjIlkEBGKqlKzVW37jyeec97rfX741YVsfvpRysQw6+zP89Tzz13n31OnfOe/dzz7netd22htaZQKBQKhUKh8PhjPNYvoFAoFAqFQqHw2CgSwUKhUCgUCoXHqSIRLBQKhUKhUHicKhLBQqFQKBQKhcepIhEsFAqFQqFQeJwqEsFCoVAoFAqFx6kiESwUCoVCoVB4nCoSwUKhUCgUCoXHqSIRLBQKhUKhUHicsi5mZ9d1dL1WIRhEDI80aLV65DInz1JGR4fJsxTDMKhWa1iWTavVIooSXNcCBEEY0qjXCKME2zRxbIc8y4iTiHqlQhD0GKo7SKUxtYlCk2eSUr3K1NwqmB6uY6M1GIZJnuekaYjjulimBUKg8pz+IKDs+0xMTKK0RggBCNCgxfqVVPS57VprDMPg/BVWtIaVlWX6/Z74QQe7cPE8r64r5bFzn5FCIEBITNNECEGWpSAMbEswNl5BGzkIY30/crQ+fxuUVghhgFCgwZQGrdU+YZQipcI0LZRSCGP9/EgrhVQKy7ZBGGzYMESr1UZjUC57CGEic4lpGesHDpr1cyuNUhLTNNCcO74ArdbfkxAGQuj19yPE+mOEwjAEWmkQ6/sceejBVa316H981AuFQqHweHFRiWCtUuYlz38W1+y/lD/8008ja4KSazLacHnCNdcw6PW56sBljI7UuezAHoZGayidMTw8QRLHWIZgEHTpxzlVX5H2I0ZHqnzqr79FfuRTbH76T1Jzalg6YXHlu8i+onbgCZScreh0life+HJy7ZJmkk6vz+mps0ydmee++w9z6tQCJ6fmoVFFa02axrz7vb+BrewLr19Kicg1ea5wHIskSzCFgeu6KCmRUpJLxa2//ms/8EAXHpmyP8yLn/sBXNclzxXaSGgOlUjTGKkyMD081+Vttz6D1JkFbSOlxLIMsizDlS4ChQzANGKSShnXVZCbSCPDizyiGXj3rX8JjksQBLi2ySCJQFpU/Ap+zSMYJNgVC+mY/OiNN9BZiaiNwvjoFlItsS3NoNfBdW2EBr9cI0kCVldbDA3XEaZFmubkSUpztEavO0DqHADHcajWXOKkz/DQEJYAjeLyy3ecfWyjXygUCoX/24mLudbwSLOq3/L6n2Lr1iZPf9qzGGqOEIYJSZ4SDxLCMMa2TWx7/UvNtodJU4VwwJIav+KjtCYXEaVylYNfv4sHvvF3XHtVnUAqZBowNrEXK+0wPzPHyNbLsP0mmRFS88fIjBI7L7mGdreLEAKpclRqog1Jkg7wKzYT9RqOXUZJi2NHpliRglZ7QKZMhOGAbXLqxCyLy0tUKiWUzPjiF7/A6soSSuVoAd1ulyiKiorgD4Gh+jb9omf/Fpu3buKBhw7juILxsY1IFSEzRcmscvPPTzJ6mY+vJKm2cTybLMmxpIFqZYBDfOYQaZ4xctV1hEYIJpiuwHQs8m6IvzzM+975YVr9FCVs3LKNQZlBGOD7Hnmm8SsunUFIdcjkV99xC8cfHJDIDNvNcXwH41wlMQoTPMdheLjO0tIKrmfhl8ukcUyS5ZiWIs005WqNYNCiWvNwhaBUcRmEOWkcIoTmhqddfq/W+trH9hMoFAqFwv/NLqoiuGFiAz9x009iRQ5RVONstoStfZROwQC/WgFAYGPVc4Q3wktf+J/52t/9LiW3wSCeJs9zop7m79/3BraMpVw2PkS1sgPHauGnLt
bIKNPfPIIn+1iuBCxco85qe5XRsQ1883+8h6ufczO6uRk7cTCqBkk3wLYMXDFEuxehWMawDRob67hhzsZNPhVXYGcpce6xsTLMbXct8+EP/yWL3TUOXHE5m3ZsIw5WGWkMc/d9Bx+NWBceASEMQDE7O0tjuMkgDqmNOsg0Ixl4bN4VMnFlFWyTNEwwHINcKSzlIFspqZaU5QDr7AmmlvtYzgTuWBlztI7IMgJCynWfyAx530d+lbe89jdA+iBcltuz+I6HTFM8p04aJwgkI9UtvPfdnwR7gVf+9E+hIoc0yTGFIM9zPM9Da0m73SaOQ5rDQ0CC7Wik1Li2g2HkqDyhWvFoVHwsAVESo3MIB5qNm0Ye48gXCoVC4fHgoppFTMPEUAaJ1wenQ963kDJDYIK20FqgFBhWRBpojtz7AN+57eP87Dv/ljf99kf4iVs+zFc+dw//9F/fyrjoYboKWYXW2lHidopIywSrcziuoDo+Rj9KMRyX4ydPUm8MoYSB2HodUycPMjh2BxMbJ/FVij1s8pnb26hKlX4vJk8d4tCiF6z/H1Vs4tBgKXGIcsHE2DA3P/8JPHjH+whP/QO3/f17ed/bXsPuye3MzEdoXRQDf1hoNJ1Oh267TWetxcYxh7GhhLe88YVU3TVedcuN5HaAlUhsz0YLRZIkJK0YOhqh2pjLZ0nuPMiOUolGliNn59DdAbGRYikLGWaIsmTZOssH/+DX2Ld/kijuoZQCNGmakqU5mUwR0uLk6ZPkUYKdDfOxP/8Cl167A9tw0ErgeR4qixHCxLI8xsc3IrCJMwGmS5gogkRiagfTgDiICXoZ3UFAKi1s02Z0rEYYdR7r0BcKhULhceCiEsFcKca3Xs5lV70GyxumUvcQlgtYCGFiGBamaaPzJpZ22H/JODKKGKvm/PItN3LTUzfSv/sTjFdi6sMOpfoo1doQ0/eepXXmELkbEQ3WsE3JWpKwefNOkjRlYuMkmZIsLy+zqbGJME1YCOCbd36DOWHz3z7wZTb4x3nLL30AVRvl23cd5fN/9Y/c/w/fQnZSquOjdIIOXh5juQMaG0ZJXUkQX83U0iq9JcXu8WHe/86XcMdn38WenZsfpXAXLpZpmRimJrNtHMvjbe+8iZe96UoYmydIB3heir3sYy2H9Hop4cmA4NsrdL72XZpHpqjPT7H20B3kjs9sq8PC3CFac1NM33kP/oyBE9qYwkJGGr9UZlDv8eKf+xHe++7XUrNsTEtgWSauq0HmZHmIbRkM1z3yROJ6Zd7+1g+xa+cIpqHQRo5fqhEEfVprHaRM6PcDTAOSNKJZc6mXBaYVU/EtKpUSvrf+GkwtcVxBu91Gyn//lI1CoVAoFB6pi5ojWKvV9YFrbiKK1nBs+Ou/fD8yXMQR5xsyzrdF9kBXGJ6YIMwrBIvT1D2bP/nl17N7c4z0JSMbRtGlMWr1MoPVRTQpi3ed5ZLn38D0PXey+5k/wlrPotoYZfrsGcbHx7Asm1DFDFdGECpAGA3u/Kd/Ztuu2xnzd/OlLy4x7u7HMFLskoNRG8GNDc4uLbDpuit5yguew+FjPT73+QeZnj/Ir7zlJt7+jj/hY3/6ZlzTZjAoE+Zz/Mwt7+aBwyeKsuAPgZHmdn39099DOe1x628+j9ruDKEc4imTh/78S+xQazSHJwk2ZTj9IdKSR+nyEiNDJvOnZ0kOPog8eD/9Wh32XMOpmWMMTVbYu+8aer1VRrdfhzu5hcTrYzo2ERG2bZJFoBfht978e2SySlwyUP0c0/apNuqkuSaMAkqux8TmrUwtrOFZBm/41RtYXSyTCclYrUm7s0aGjWXkGKbAsUxMO0dpsByPte6AsUYJz3VBGBw58hCTG8ewbZPrnrSvmCNYKBQKhUfVRVUEN05Oog1Nd+Y4e/bt5Xk/8UpEHqK1BBRKx9iWgZYlNBVe+jPv5jk3vghpTnDrz72azeMBkVtiZGiYUnMCt65JRImxrbtgAF7D4cztd2EIh0
E7xHNsov4Cm8pnSJYfIpYGtmiQGD7RoubM3X/AxIZ/YaRc4xufX6WhL0XrCMt0EJmB0VohS9aoiZj46BH+4m23ctX
"text/plain": [
"<Figure size 806.4x201.6 with 16 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# NOTE(review): assumes ImageFolder's sorted class order puts hotdog images first\n",
"# and not-hotdog images last (check train_imgs.classes); take 8 from each end.\n",
"hotdogs = [train_imgs[k][0] for k in range(8)]\n",
"not_hotdogs = [train_imgs[-(k + 1)][0] for k in range(8)]\n",
"d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:17.123685Z",
"start_time": "2019-06-03T14:11:17.119781Z"
}
},
"outputs": [],
"source": [
"# Per-channel mean/std that torchvision's ImageNet-pretrained models expect.\n",
"normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"to_normalized_tensor = [transforms.ToTensor(), normalize]\n",
"\n",
"# Training: random-area crop resized to 224x224, then a random horizontal flip.\n",
"train_augs = transforms.Compose(\n",
"    [transforms.RandomResizedCrop(size=224), transforms.RandomHorizontalFlip()]\n",
"    + to_normalized_tensor)\n",
"\n",
"# Evaluation: deterministic resize (shorter side to 256), then a 224x224 center crop.\n",
"test_augs = transforms.Compose(\n",
"    [transforms.Resize(size=256), transforms.CenterCrop(size=224)]\n",
"    + to_normalized_tensor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 9.2.1.2 定义和初始化模型"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:17.620581Z",
"start_time": "2019-06-03T14:11:17.125242Z"
}
},
"outputs": [],
"source": [
"# torchvision >= 0.13 selects pretrained weights via `weights=` and deprecates\n",
"# `pretrained=True`; fall back to the old flag on releases without weight enums.\n",
"try:\n",
"    pretrained_net = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
"except AttributeError:\n",
"    pretrained_net = models.resnet18(pretrained=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:17.624672Z",
"start_time": "2019-06-03T14:11:17.622216Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear(in_features=512, out_features=1000, bias=True)\n"
]
}
],
"source": [
"# The pretrained final layer: 512 features in, 1000 class scores out (see output).\n",
"print(pretrained_net.fc)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:17.670998Z",
"start_time": "2019-06-03T14:11:17.626040Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear(in_features=512, out_features=2, bias=True)\n"
]
}
],
"source": [
"# Swap the 1000-way head for a freshly (randomly) initialized 2-way one\n",
"# (hotdog / not hotdog). Read in_features from the old layer instead of hard-coding 512.\n",
"pretrained_net.fc = nn.Linear(pretrained_net.fc.in_features, 2)\n",
"print(pretrained_net.fc)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:17.716500Z",
"start_time": "2019-06-03T14:11:17.672667Z"
}
},
"outputs": [],
"source": [
"# Split parameters into the new output layer and the pretrained feature extractor\n",
"# so the freshly initialized head can use a 10x larger learning rate.\n",
"output_params = {id(p) for p in pretrained_net.fc.parameters()}\n",
"# A list, not a one-shot `filter` iterator, so it can be inspected or reused safely.\n",
"feature_params = [p for p in pretrained_net.parameters() if id(p) not in output_params]\n",
"\n",
"lr = 0.01\n",
"optimizer = optim.SGD([{'params': feature_params},\n",
"                       {'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],\n",
"                      lr=lr, weight_decay=0.001)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 9.2.1.3 微调模型"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:11:17.761943Z",
"start_time": "2019-06-03T14:11:17.718154Z"
}
},
"outputs": [],
"source": [
"def train_fine_tuning(net, optimizer, batch_size=128, num_epochs=5):\n",
"    \"\"\"Train `net` on the hotdog dataset via `d2l.train`.\n",
"\n",
"    Builds shuffled training batches (`train_augs`) and unshuffled test\n",
"    batches (`test_augs`) of `batch_size` images from `data_dir`/hotdog, then\n",
"    runs `num_epochs` epochs with cross-entropy loss and the given `optimizer`\n",
"    on the global `device`.\n",
"    \"\"\"\n",
"    def make_loader(split, augs, shuffle):\n",
"        folder = ImageFolder(os.path.join(data_dir, 'hotdog/' + split), transform=augs)\n",
"        return DataLoader(folder, batch_size, shuffle=shuffle)\n",
"\n",
"    train_iter = make_loader('train', train_augs, shuffle=True)\n",
"    test_iter = make_loader('test', test_augs, shuffle=False)\n",
"    loss = nn.CrossEntropyLoss()\n",
"    d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:13:52.316406Z",
"start_time": "2019-06-03T14:11:17.763719Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 3.1183, train acc 0.731, test acc 0.932, time 41.4 sec\n",
"epoch 2, loss 0.6471, train acc 0.829, test acc 0.869, time 25.6 sec\n",
"epoch 3, loss 0.0964, train acc 0.920, test acc 0.910, time 24.9 sec\n",
"epoch 4, loss 0.0659, train acc 0.922, test acc 0.936, time 25.2 sec\n",
"epoch 5, loss 0.0668, train acc 0.913, test acc 0.929, time 25.0 sec\n"
]
}
],
"source": [
"# Fine-tune the pretrained ResNet-18 with the two-learning-rate optimizer.\n",
"train_fine_tuning(net=pretrained_net, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-03T14:15:57.891925Z",
"start_time": "2019-06-03T14:13:52.319140Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 2.6686, train acc 0.582, test acc 0.556, time 25.3 sec\n",
"epoch 2, loss 0.2434, train acc 0.797, test acc 0.776, time 25.3 sec\n",
"epoch 3, loss 0.1251, train acc 0.845, test acc 0.802, time 24.9 sec\n",
"epoch 4, loss 0.0958, train acc 0.833, test acc 0.810, time 25.0 sec\n",
"epoch 5, loss 0.0757, train acc 0.836, test acc 0.780, time 24.9 sec\n"
]
}
],
"source": [
"# Baseline: the same architecture trained from random initialization. Omitting\n",
"# `pretrained`/`weights` gives random weights on old and new torchvision alike\n",
"# (no deprecation warning). A uniform lr of 0.1 is used for every parameter, and\n",
"# separate names keep the fine-tuning cell's `lr`/`optimizer` intact for re-runs.\n",
"scratch_net = models.resnet18(num_classes=2)\n",
"scratch_lr = 0.1\n",
"scratch_optimizer = optim.SGD(scratch_net.parameters(), lr=scratch_lr, weight_decay=0.001)\n",
"train_fine_tuning(scratch_net, scratch_optimizer)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:py36]",
"language": "python",
"name": "conda-env-py36-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}