{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 7.3 小批量随机梯度下降" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0.0\n" ] } ], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import time\n", "import torch\n", "from torch import nn, optim\n", "import sys\n", "sys.path.append(\"..\") \n", "import d2lzh_pytorch as d2l\n", "\n", "print(torch.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7.3.1 读取数据" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1500, 5])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def get_data_ch7(): # 本函数已保存在d2lzh_pytorch包中方便以后使用\n", " data = np.genfromtxt('../../data/airfoil_self_noise.dat', delimiter='\\t')\n", " data = (data - data.mean(axis=0)) / data.std(axis=0) # 标准化\n", " return torch.tensor(data[:1500, :-1], dtype=torch.float32), \\\n", " torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)\n", "\n", "features, labels = get_data_ch7()\n", "features.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7.3.2 从零开始实现" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sgd(params, states, hyperparams):\n", " for p in params:\n", " p.data -= hyperparams['lr'] * p.grad.data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# 本函数已保存在d2lzh_pytorch包中方便以后使用\n", "def train_ch7(optimizer_fn, states, hyperparams, features, labels,\n", " batch_size=10, num_epochs=2):\n", " # 初始化模型\n", " net, loss = d2l.linreg, d2l.squared_loss\n", " \n", " w = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),\n", " requires_grad=True)\n", " b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)\n", "\n", " def eval_loss():\n", " return loss(net(features, w, b), labels).mean().item()\n", "\n", " ls = [eval_loss()]\n", " data_iter = torch.utils.data.DataLoader(\n", " torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)\n", " \n", " for _ in range(num_epochs):\n", " start = time.time()\n", " for batch_i, (X, y) in enumerate(data_iter):\n", " l = loss(net(X, w, b), y).mean() # 使用平均损失\n", " \n", " # 梯度清零\n", " if w.grad is not None:\n", " w.grad.data.zero_()\n", " b.grad.data.zero_()\n", " \n", " l.backward()\n", " optimizer_fn([w, b], states, hyperparams) # 迭代模型参数\n", " if (batch_i + 1) * batch_size % 100 == 0:\n", " ls.append(eval_loss()) # 每100个样本记录下当前训练误差\n", " # 打印结果和作图\n", " print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))\n", " d2l.set_figsize()\n", " d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)\n", " d2l.plt.xlabel('epoch')\n", " d2l.plt.ylabel('loss')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.244678, 0.012216 sec per epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def train_sgd(lr, batch_size, num_epochs=2):\n", " train_ch7(sgd, None, {'lr': lr}, features, labels, batch_size, num_epochs)\n", "\n", "train_sgd(1, 1500, 6)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.246391, 0.334214 sec per epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_sgd(0.005, 1)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.245523, 0.050718 sec per epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_sgd(0.05, 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7.3.3 简洁实现" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字\n", "# 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={\"lr\": 0.05}\n", "def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,\n", " batch_size=10, num_epochs=2):\n", " # 初始化模型\n", " net = nn.Sequential(\n", " nn.Linear(features.shape[-1], 1)\n", " )\n", " loss = nn.MSELoss()\n", " optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)\n", "\n", " def eval_loss():\n", " return loss(net(features).view(-1), labels).item() / 2\n", "\n", " ls = [eval_loss()]\n", " data_iter = torch.utils.data.DataLoader(\n", " torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)\n", "\n", " for _ in range(num_epochs):\n", " start = time.time()\n", " for batch_i, (X, y) in enumerate(data_iter):\n", " # 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2\n", " l = loss(net(X).view(-1), y) / 2 \n", " \n", " optimizer.zero_grad()\n", " l.backward()\n", " optimizer.step()\n", " if (batch_i + 1) * batch_size % 100 == 0:\n", " ls.append(eval_loss())\n", " # 打印结果和作图\n", " print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))\n", " d2l.set_figsize()\n", " d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)\n", " d2l.plt.xlabel('epoch')\n", " d2l.plt.ylabel('loss')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.245491, 0.044150 sec per epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_pytorch_ch7(optim.SGD, {\"lr\": 0.05}, features, labels, 10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }