{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.12 权重衰减"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.1\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.12.2 高维线性回归实验"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"n_train, n_test, num_inputs = 20, 100, 200\n",
"true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05\n",
"\n",
"features = torch.randn((n_train + n_test, num_inputs))\n",
"labels = torch.matmul(features, true_w) + true_b\n",
"labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)\n",
"train_features, test_features = features[:n_train, :], features[n_train:, :]\n",
"train_labels, test_labels = labels[:n_train], labels[n_train:]"
]
},
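{
"cell_type": "markdown",
"metadata": {},
"source": [
"The synthetic data above follow the linear model\n",
"\n",
"$$y = 0.05 + \\sum_{i=1}^{200} 0.01\\, x_i + \\epsilon,$$\n",
"\n",
"where the noise term $\\epsilon$ is drawn from a normal distribution with mean 0 and standard deviation 0.01. With only 20 training examples but 200 input dimensions, this setup is deliberately easy to overfit."
]
},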
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.12.3 从零开始实现\n",
"### 3.12.3.1 初始化模型参数"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def init_params():\n",
" w = torch.randn((num_inputs, 1), requires_grad=True)\n",
" b = torch.zeros(1, requires_grad=True)\n",
" return [w, b]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.12.3.2 定义$L_2$范数惩罚项"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def l2_penalty(w):\n",
" return (w**2).sum() / 2"
]
},
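{
"cell_type": "markdown",
"metadata": {},
"source": [
"With this penalty, the objective minimized during training becomes\n",
"\n",
"$$\\ell(\\mathbf{w}, b) + \\frac{\\lambda}{2} \\|\\mathbf{w}\\|^2,$$\n",
"\n",
"where $\\ell(\\mathbf{w}, b)$ is the original squared loss and `l2_penalty` computes the $\\frac{1}{2}\\|\\mathbf{w}\\|^2$ term. Combined with the minibatch SGD step used below (see the sketch of `d2l.sgd` after the next code cell), the weight update works out to\n",
"\n",
"$$\\mathbf{w} \\leftarrow (1 - \\eta\\lambda)\\,\\mathbf{w} - \\frac{\\eta}{|\\mathcal{B}|} \\sum_{i \\in \\mathcal{B}} \\frac{\\partial \\ell^{(i)}(\\mathbf{w}, b)}{\\partial \\mathbf{w}},$$\n",
"\n",
"so each update first shrinks the weights by the factor $(1 - \\eta\\lambda)$, which is why $L_2$ regularization is also known as weight decay."
]
},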
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.12.3.3 定义训练和测试"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"batch_size, num_epochs, lr = 1, 100, 0.003\n",
"net, loss = d2l.linreg, d2l.squared_loss\n",
"\n",
"dataset = torch.utils.data.TensorDataset(train_features, train_labels)\n",
"train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)\n",
"\n",
"def fit_and_plot(lambd):\n",
" w, b = init_params()\n",
" train_ls, test_ls = [], []\n",
" for _ in range(num_epochs):\n",
" for X, y in train_iter:\n",
" # 添加了L2范数惩罚项\n",
" l = loss(net(X, w, b), y) + lambd * l2_penalty(w)\n",
" l = l.sum()\n",
" \n",
" if w.grad is not None:\n",
" w.grad.data.zero_()\n",
" b.grad.data.zero_()\n",
" l.backward()\n",
" d2l.sgd([w, b], lr, batch_size)\n",
" train_ls.append(loss(net(train_features, w, b), train_labels).mean().item())\n",
" test_ls.append(loss(net(test_features, w, b), test_labels).mean().item())\n",
" d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',\n",
" range(1, num_epochs + 1), test_ls, ['train', 'test'])\n",
" print('L2 norm of w:', w.norm().item())"
]
},
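{
"cell_type": "markdown",
"metadata": {},
"source": [
"`d2l.linreg`, `d2l.squared_loss`, `d2l.sgd` and `d2l.semilogy` come from the book's utility module `d2lzh_pytorch`, which is not shown in this notebook. For reference, the first three are assumed to behave roughly like the following sketch (consistent with how they are called above, not necessarily the module's actual code):\n",
"\n",
"```python\n",
"def linreg(X, w, b):\n",
"    # linear regression model: Xw + b\n",
"    return torch.mm(X, w) + b\n",
"\n",
"def squared_loss(y_hat, y):\n",
"    # per-example squared loss (no averaging), hence the .sum()/.mean() calls above\n",
"    return (y_hat - y.view(y_hat.size())) ** 2 / 2\n",
"\n",
"def sgd(params, lr, batch_size):\n",
"    # minibatch SGD: the gradient of the summed loss is divided by the batch size\n",
"    for param in params:\n",
"        param.data -= lr * param.grad / batch_size\n",
"```"
]
},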
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.12.3.4 观察过拟合"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"L2 norm of w: 15.114808082580566\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fit_and_plot(lambd=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.12.3.5 使用权重衰减"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"L2 norm of w: 0.035220853984355927\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fit_and_plot(lambd=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.12.4 简洁实现"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def fit_and_plot_pytorch(wd):\n",
" # 对权重参数衰减。权重名称一般是以weight结尾\n",
" net = nn.Linear(num_inputs, 1)\n",
" nn.init.normal_(net.weight, mean=0, std=1)\n",
" nn.init.normal_(net.bias, mean=0, std=1)\n",
" optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd) # 对权重参数衰减\n",
" optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr) # 不对偏差参数衰减\n",
" \n",
" train_ls, test_ls = [], []\n",
" for _ in range(num_epochs):\n",
" for X, y in train_iter:\n",
" l = loss(net(X), y).mean()\n",
" optimizer_w.zero_grad()\n",
" optimizer_b.zero_grad()\n",
" \n",
" l.backward()\n",
" \n",
" # 对两个optimizer实例分别调用step函数,从而分别更新权重和偏差\n",
" optimizer_w.step()\n",
" optimizer_b.step()\n",
" train_ls.append(loss(net(train_features), train_labels).mean().item())\n",
" test_ls.append(loss(net(test_features), test_labels).mean().item())\n",
" d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',\n",
" range(1, num_epochs + 1), test_ls, ['train', 'test'])\n",
" print('L2 norm of w:', net.weight.data.norm().item())"
]
},
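{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using two optimizer instances is one way to decay only the weights. An alternative sketch (not used above) is a single optimizer with two parameter groups, since `weight_decay` can be set per group; `net`, `wd` and `lr` here refer to the same objects as inside `fit_and_plot_pytorch`:\n",
"\n",
"```python\n",
"optimizer = torch.optim.SGD([\n",
"    {'params': [net.weight], 'weight_decay': wd},  # decay the weights\n",
"    {'params': [net.bias]}                         # bias group keeps the default weight_decay=0\n",
"], lr=lr)\n",
"```\n",
"\n",
"With a single optimizer, the training loop would only need one `zero_grad()` and one `step()` call per iteration."
]
},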
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"L2 norm of w: 12.86785888671875\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fit_and_plot_pytorch(0)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"L2 norm of w: 0.09631537646055222\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fit_and_plot_pytorch(3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}