You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2654 lines
106 KiB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 7.3 小批量随机梯度下降"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0.0\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import time\n",
"import torch\n",
"from torch import nn, optim\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7.3.1 读取数据"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1500, 5])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_data_ch7():  # this function is saved in the d2lzh_pytorch package for later use\n",
"    \"\"\"Load the airfoil self-noise data and return standardized tensors.\n",
"\n",
"    Every column of the file (features and label alike) is standardized to\n",
"    zero mean and unit standard deviation using statistics of the whole\n",
"    file; only the first 1500 rows are then kept.\n",
"\n",
"    Returns:\n",
"        A ``(features, labels)`` pair of float32 tensors: ``features`` holds\n",
"        every column but the last, ``labels`` holds the last column.\n",
"    \"\"\"\n",
"    raw = np.genfromtxt('../../data/airfoil_self_noise.dat', delimiter='\\t')\n",
"    normalized = (raw - raw.mean(axis=0)) / raw.std(axis=0)  # standardize column-wise\n",
"    head = normalized[:1500]  # first 1500 samples (5 features each)\n",
"    feature_tensor = torch.tensor(head[:, :-1], dtype=torch.float32)\n",
"    label_tensor = torch.tensor(head[:, -1], dtype=torch.float32)\n",
"    return feature_tensor, label_tensor\n",
"\n",
"features, labels = get_data_ch7()\n",
"features.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7.3.2 从零开始实现"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def sgd(params, states, hyperparams):\n",
"    \"\"\"Apply one step of minibatch stochastic gradient descent in place.\n",
"\n",
"    Args:\n",
"        params: iterable of tensors whose ``.grad`` was populated by a\n",
"            preceding ``backward()`` call.\n",
"        states: unused here; accepted so that every optimizer function\n",
"            handed to ``train_ch7`` can be called with the same arguments.\n",
"        hyperparams: dict holding the learning rate under the key ``'lr'``.\n",
"    \"\"\"\n",
"    lr = hyperparams['lr']\n",
"    # The parameter update itself must not be recorded by autograd.\n",
"    with torch.no_grad():\n",
"        for p in params:\n",
"            if p.grad is None:  # no gradient was computed for this parameter\n",
"                continue\n",
"            p -= lr * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# This function is saved in the d2lzh_pytorch package for later use.\n",
"def train_ch7(optimizer_fn, states, hyperparams, features, labels,\n",
"              batch_size=10, num_epochs=2):\n",
"    \"\"\"Train a linear-regression model with a custom optimizer and plot the loss.\n",
"\n",
"    Args:\n",
"        optimizer_fn: callable invoked as ``optimizer_fn([w, b], states,\n",
"            hyperparams)`` after every backward pass; it is expected to\n",
"            update the parameters in place.\n",
"        states: passed through to ``optimizer_fn`` unchanged.\n",
"        hyperparams: passed through to ``optimizer_fn`` unchanged.\n",
"        features: tensor of inputs; ``features.shape[1]`` sets the number of\n",
"            weights.\n",
"        labels: tensor of targets paired with ``features``.\n",
"        batch_size: minibatch size handed to the ``DataLoader``.\n",
"        num_epochs: number of passes over the data.\n",
"\n",
"    Prints the last recorded loss together with the average wall-clock time\n",
"    per epoch, then draws the recorded losses against the epoch axis.\n",
"    \"\"\"\n",
"    # Initialize the model.\n",
"    net, loss = d2l.linreg, d2l.squared_loss\n",
"\n",
"    w = torch.nn.Parameter(\n",
"        torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)),\n",
"                     dtype=torch.float32),\n",
"        requires_grad=True)\n",
"    b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)\n",
"\n",
"    def eval_loss():\n",
"        # Evaluation only: no autograd graph is needed for the full data set.\n",
"        with torch.no_grad():\n",
"            return loss(net(features, w, b), labels).mean().item()\n",
"\n",
"    ls = [eval_loss()]\n",
"    data_iter = torch.utils.data.DataLoader(\n",
"        torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)\n",
"\n",
"    # Time the whole run so the reported figure is an average over all epochs\n",
"    # (and is defined even when num_epochs is 0).\n",
"    start = time.time()\n",
"    for _ in range(num_epochs):\n",
"        for batch_i, (X, y) in enumerate(data_iter):\n",
"            l = loss(net(X, w, b), y).mean()  # use the mean loss\n",
"\n",
"            # Clear the gradients left over from the previous step.\n",
"            for p in (w, b):\n",
"                if p.grad is not None:\n",
"                    p.grad.zero_()\n",
"\n",
"            l.backward()\n",
"            optimizer_fn([w, b], states, hyperparams)  # update the model parameters\n",
"            # Record the training loss whenever the number of samples seen in\n",
"            # this epoch is a multiple of 100.\n",
"            if (batch_i + 1) * batch_size % 100 == 0:\n",
"                ls.append(eval_loss())\n",
"    # Print the result and plot the loss curve.\n",
"    secs_per_epoch = (time.time() - start) / max(num_epochs, 1)\n",
"    print('loss: %f, %f sec per epoch' % (ls[-1], secs_per_epoch))\n",
"    d2l.set_figsize()\n",
"    d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)\n",
"    d2l.plt.xlabel('epoch')\n",
"    d2l.plt.ylabel('loss')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 0.244678, 0.012216 sec per epoch\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Created with matplotlib (http://matplotlib.org/) -->\n",
"<svg height=\"184pt\" version=\"1.1\" viewBox=\"0 0 256 184\" width=\"256pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
" <defs>\n",
" <style type=\"text/css\">\n",
"*{stroke-linecap:butt;stroke-linejoin:round;}\n",
" </style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 184.15625 \n",
"L 256.14375 184.15625 \n",
"L 256.14375 -0 \n",
"L 0 -0 \n",
"z\n",
"\" style=\"fill:none;\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 245.44375 146.6 \n",
"L 245.44375 10.7 \n",
"L 50.14375 10.7 \n",
"z\n",
"\" style=\"fill:#ffffff;\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" id=\"mac40f26caa\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" </defs>\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"59.021023\" xlink:href=\"#mac40f26caa\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- 0 -->\n",
" <defs>\n",
" <path d=\"M 31.78125 66.40625 \n",
"Q 24.171875 66.40625 20.328125 58.90625 \n",
"Q 16.5 51.421875 16.5 36.375 \n",
"Q 16.5 21.390625 20.328125 13.890625 \n",
"Q 24.171875 6.390625 31.78125 6.390625 \n",
"Q 39.453125 6.390625 43.28125 13.890625 \n",
"Q 47.125 21.390625 47.125 36.375 \n",
"Q 47.125 51.421875 43.28125 58.90625 \n",
"Q 39.453125 66.40625 31.78125 66.40625 \n",
"z\n",
"M 31.78125 74.21875 \n",
"Q 44.046875 74.21875 50.515625 64.515625 \n",
"Q 56.984375 54.828125 56.984375 36.375 \n",
"Q 56.984375 17.96875 50.515625 8.265625 \n",
"Q 44.046875 -1.421875 31.78125 -1.421875 \n",
"Q 19.53125 -1.421875 13.0625 8.265625 \n",
"Q 6.59375 17.96875 6.59375 36.375 \n",
"Q 6.59375 54.828125 13.0625 64.515625 \n",
"Q 19.53125 74.21875 31.78125 74.21875 \n",
"z\n",
"\" id=\"DejaVuSans-30\"/>\n",
" </defs>\n",
" <g transform=\"translate(55.839773 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"118.202841\" xlink:href=\"#mac40f26caa\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- 2 -->\n",
" <defs>\n",
" <path d=\"M 19.1875 8.296875 \n",
"L 53.609375 8.296875 \n",
"L 53.609375 0 \n",
"L 7.328125 0 \n",
"L 7.328125 8.296875 \n",
"Q 12.9375 14.109375 22.625 23.890625 \n",
"Q 32.328125 33.6875 34.8125 36.53125 \n",
"Q 39.546875 41.84375 41.421875 45.53125 \n",
"Q 43.3125 49.21875 43.3125 52.78125 \n",
"Q 43.3125 58.59375 39.234375 62.25 \n",
"Q 35.15625 65.921875 28.609375 65.921875 \n",
"Q 23.96875 65.921875 18.8125 64.3125 \n",
"Q 13.671875 62.703125 7.8125 59.421875 \n",
"L 7.8125 69.390625 \n",
"Q 13.765625 71.78125 18.9375 73 \n",
"Q 24.125 74.21875 28.421875 74.21875 \n",
"Q 39.75 74.21875 46.484375 68.546875 \n",
"Q 53.21875 62.890625 53.21875 53.421875 \n",
"Q 53.21875 48.921875 51.53125 44.890625 \n",
"Q 49.859375 40.875 45.40625 35.40625 \n",
"Q 44.1875 33.984375 37.640625 27.21875 \n",
"Q 31.109375 20.453125 19.1875 8.296875 \n",
"z\n",
"\" id=\"DejaVuSans-32\"/>\n",
" </defs>\n",
" <g transform=\"translate(115.021591 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"177.384659\" xlink:href=\"#mac40f26caa\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 4 -->\n",
" <defs>\n",
" <path d=\"M 37.796875 64.3125 \n",
"L 12.890625 25.390625 \n",
"L 37.796875 25.390625 \n",
"z\n",
"M 35.203125 72.90625 \n",
"L 47.609375 72.90625 \n",
"L 47.609375 25.390625 \n",
"L 58.015625 25.390625 \n",
"L 58.015625 17.1875 \n",
"L 47.609375 17.1875 \n",
"L 47.609375 0 \n",
"L 37.796875 0 \n",
"L 37.796875 17.1875 \n",
"L 4.890625 17.1875 \n",
"L 4.890625 26.703125 \n",
"z\n",
"\" id=\"DejaVuSans-34\"/>\n",
" </defs>\n",
" <g transform=\"translate(174.203409 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"236.566477\" xlink:href=\"#mac40f26caa\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 6 -->\n",
" <defs>\n",
" <path d=\"M 33.015625 40.375 \n",
"Q 26.375 40.375 22.484375 35.828125 \n",
"Q 18.609375 31.296875 18.609375 23.390625 \n",
"Q 18.609375 15.53125 22.484375 10.953125 \n",
"Q 26.375 6.390625 33.015625 6.390625 \n",
"Q 39.65625 6.390625 43.53125 10.953125 \n",
"Q 47.40625 15.53125 47.40625 23.390625 \n",
"Q 47.40625 31.296875 43.53125 35.828125 \n",
"Q 39.65625 40.375 33.015625 40.375 \n",
"z\n",
"M 52.59375 71.296875 \n",
"L 52.59375 62.3125 \n",
"Q 48.875 64.0625 45.09375 64.984375 \n",
"Q 41.3125 65.921875 37.59375 65.921875 \n",
"Q 27.828125 65.921875 22.671875 59.328125 \n",
"Q 17.53125 52.734375 16.796875 39.40625 \n",
"Q 19.671875 43.65625 24.015625 45.921875 \n",
"Q 28.375 48.1875 33.59375 48.1875 \n",
"Q 44.578125 48.1875 50.953125 41.515625 \n",
"Q 57.328125 34.859375 57.328125 23.390625 \n",
"Q 57.328125 12.15625 50.6875 5.359375 \n",
"Q 44.046875 -1.421875 33.015625 -1.421875 \n",
"Q 20.359375 -1.421875 13.671875 8.265625 \n",
"Q 6.984375 17.96875 6.984375 36.375 \n",
"Q 6.984375 53.65625 15.1875 63.9375 \n",
"Q 23.390625 74.21875 37.203125 74.21875 \n",
"Q 40.921875 74.21875 44.703125 73.484375 \n",
"Q 48.484375 72.75 52.59375 71.296875 \n",
"z\n",
"\" id=\"DejaVuSans-36\"/>\n",
" </defs>\n",
" <g transform=\"translate(233.385227 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-36\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- epoch -->\n",
" <defs>\n",
" <path d=\"M 56.203125 29.59375 \n",
"L 56.203125 25.203125 \n",
"L 14.890625 25.203125 \n",
"Q 15.484375 15.921875 20.484375 11.0625 \n",
"Q 25.484375 6.203125 34.421875 6.203125 \n",
"Q 39.59375 6.203125 44.453125 7.46875 \n",
"Q 49.3125 8.734375 54.109375 11.28125 \n",
"L 54.109375 2.78125 \n",
"Q 49.265625 0.734375 44.1875 -0.34375 \n",
"Q 39.109375 -1.421875 33.890625 -1.421875 \n",
"Q 20.796875 -1.421875 13.15625 6.1875 \n",
"Q 5.515625 13.8125 5.515625 26.8125 \n",
"Q 5.515625 40.234375 12.765625 48.109375 \n",
"Q 20.015625 56 32.328125 56 \n",
"Q 43.359375 56 49.78125 48.890625 \n",
"Q 56.203125 41.796875 56.203125 29.59375 \n",
"z\n",
"M 47.21875 32.234375 \n",
"Q 47.125 39.59375 43.09375 43.984375 \n",
"Q 39.0625 48.390625 32.421875 48.390625 \n",
"Q 24.90625 48.390625 20.390625 44.140625 \n",
"Q 15.875 39.890625 15.1875 32.171875 \n",
"z\n",
"\" id=\"DejaVuSans-65\"/>\n",
" <path d=\"M 18.109375 8.203125 \n",
"L 18.109375 -20.796875 \n",
"L 9.078125 -20.796875 \n",
"L 9.078125 54.6875 \n",
"L 18.109375 54.6875 \n",
"L 18.109375 46.390625 \n",
"Q 20.953125 51.265625 25.265625 53.625 \n",
"Q 29.59375 56 35.59375 56 \n",
"Q 45.5625 56 51.78125 48.09375 \n",
"Q 58.015625 40.1875 58.015625 27.296875 \n",
"Q 58.015625 14.40625 51.78125 6.484375 \n",
"Q 45.5625 -1.421875 35.59375 -1.421875 \n",
"Q 29.59375 -1.421875 25.265625 0.953125 \n",
"Q 20.953125 3.328125 18.109375 8.203125 \n",
"z\n",
"M 48.6875 27.296875 \n",
"Q 48.6875 37.203125 44.609375 42.84375 \n",
"Q 40.53125 48.484375 33.40625 48.484375 \n",
"Q 26.265625 48.484375 22.1875 42.84375 \n",
"Q 18.109375 37.203125 18.109375 27.296875 \n",
"Q 18.109375 17.390625 22.1875 11.75 \n",
"Q 26.265625 6.109375 33.40625 6.109375 \n",
"Q 40.53125 6.109375 44.609375 11.75 \n",
"Q 48.6875 17.390625 48.6875 27.296875 \n",
"z\n",
"\" id=\"DejaVuSans-70\"/>\n",
" <path d=\"M 30.609375 48.390625 \n",
"Q 23.390625 48.390625 19.1875 42.75 \n",
"Q 14.984375 37.109375 14.984375 27.296875 \n",
"Q 14.984375 17.484375 19.15625 11.84375 \n",
"Q 23.34375 6.203125 30.609375 6.203125 \n",
"Q 37.796875 6.203125 41.984375 11.859375 \n",
"Q 46.1875 17.53125 46.1875 27.296875 \n",
"Q 46.1875 37.015625 41.984375 42.703125 \n",
"Q 37.796875 48.390625 30.609375 48.390625 \n",
"z\n",
"M 30.609375 56 \n",
"Q 42.328125 56 49.015625 48.375 \n",
"Q 55.71875 40.765625 55.71875 27.296875 \n",
"Q 55.71875 13.875 49.015625 6.21875 \n",
"Q 42.328125 -1.421875 30.609375 -1.421875 \n",
"Q 18.84375 -1.421875 12.171875 6.21875 \n",
"Q 5.515625 13.875 5.515625 27.296875 \n",
"Q 5.515625 40.765625 12.171875 48.375 \n",
"Q 18.84375 56 30.609375 56 \n",
"z\n",
"\" id=\"DejaVuSans-6f\"/>\n",
" <path d=\"M 48.78125 52.59375 \n",
"L 48.78125 44.1875 \n",
"Q 44.96875 46.296875 41.140625 47.34375 \n",
"Q 37.3125 48.390625 33.40625 48.390625 \n",
"Q 24.65625 48.390625 19.8125 42.84375 \n",
"Q 14.984375 37.3125 14.984375 27.296875 \n",
"Q 14.984375 17.28125 19.8125 11.734375 \n",
"Q 24.65625 6.203125 33.40625 6.203125 \n",
"Q 37.3125 6.203125 41.140625 7.25 \n",
"Q 44.96875 8.296875 48.78125 10.40625 \n",
"L 48.78125 2.09375 \n",
"Q 45.015625 0.34375 40.984375 -0.53125 \n",
"Q 36.96875 -1.421875 32.421875 -1.421875 \n",
"Q 20.0625 -1.421875 12.78125 6.34375 \n",
"Q 5.515625 14.109375 5.515625 27.296875 \n",
"Q 5.515625 40.671875 12.859375 48.328125 \n",
"Q 20.21875 56 33.015625 56 \n",
"Q 37.15625 56 41.109375 55.140625 \n",
"Q 45.0625 54.296875 48.78125 52.59375 \n",
"z\n",
"\" id=\"DejaVuSans-63\"/>\n",
" <path d=\"M 54.890625 33.015625 \n",
"L 54.890625 0 \n",
"L 45.90625 0 \n",
"L 45.90625 32.71875 \n",
"Q 45.90625 40.484375 42.875 44.328125 \n",
"Q 39.84375 48.1875 33.796875 48.1875 \n",
"Q 26.515625 48.1875 22.3125 43.546875 \n",
"Q 18.109375 38.921875 18.109375 30.90625 \n",
"L 18.109375 0 \n",
"L 9.078125 0 \n",
"L 9.078125 75.984375 \n",
"L 18.109375 75.984375 \n",
"L 18.109375 46.1875 \n",
"Q 21.34375 51.125 25.703125 53.5625 \n",
"Q 30.078125 56 35.796875 56 \n",
"Q 45.21875 56 50.046875 50.171875 \n",
"Q 54.890625 44.34375 54.890625 33.015625 \n",
"z\n",
"\" id=\"DejaVuSans-68\"/>\n",
" </defs>\n",
" <g transform=\"translate(132.565625 174.876562)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
" <use x=\"61.523438\" xlink:href=\"#DejaVuSans-70\"/>\n",
" <use x=\"125\" xlink:href=\"#DejaVuSans-6f\"/>\n",
" <use x=\"186.181641\" xlink:href=\"#DejaVuSans-63\"/>\n",
" <use x=\"241.162109\" xlink:href=\"#DejaVuSans-68\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_5\">\n",
" <defs>\n",
" <path d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" id=\"m1689761f7d\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" </defs>\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1689761f7d\" y=\"137.789971\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- 0.25 -->\n",
" <defs>\n",
" <path d=\"M 10.6875 12.40625 \n",
"L 21 12.40625 \n",
"L 21 0 \n",
"L 10.6875 0 \n",
"z\n",
"\" id=\"DejaVuSans-2e\"/>\n",
" <path d=\"M 10.796875 72.90625 \n",
"L 49.515625 72.90625 \n",
"L 49.515625 64.59375 \n",
"L 19.828125 64.59375 \n",
"L 19.828125 46.734375 \n",
"Q 21.96875 47.46875 24.109375 47.828125 \n",
"Q 26.265625 48.1875 28.421875 48.1875 \n",
"Q 40.625 48.1875 47.75 41.5 \n",
"Q 54.890625 34.8125 54.890625 23.390625 \n",
"Q 54.890625 11.625 47.5625 5.09375 \n",
"Q 40.234375 -1.421875 26.90625 -1.421875 \n",
"Q 22.3125 -1.421875 17.546875 -0.640625 \n",
"Q 12.796875 0.140625 7.71875 1.703125 \n",
"L 7.71875 11.625 \n",
"Q 12.109375 9.234375 16.796875 8.0625 \n",
"Q 21.484375 6.890625 26.703125 6.890625 \n",
"Q 35.15625 6.890625 40.078125 11.328125 \n",
"Q 45.015625 15.765625 45.015625 23.390625 \n",
"Q 45.015625 31 40.078125 35.4375 \n",
"Q 35.15625 39.890625 26.703125 39.890625 \n",
"Q 22.75 39.890625 18.8125 39.015625 \n",
"Q 14.890625 38.140625 10.796875 36.28125 \n",
"z\n",
"\" id=\"DejaVuSans-35\"/>\n",
" </defs>\n",
" <g transform=\"translate(20.878125 141.58919)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-32\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_6\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1689761f7d\" y=\"113.054114\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0.30 -->\n",
" <defs>\n",
" <path d=\"M 40.578125 39.3125 \n",
"Q 47.65625 37.796875 51.625 33 \n",
"Q 55.609375 28.21875 55.609375 21.1875 \n",
"Q 55.609375 10.40625 48.1875 4.484375 \n",
"Q 40.765625 -1.421875 27.09375 -1.421875 \n",
"Q 22.515625 -1.421875 17.65625 -0.515625 \n",
"Q 12.796875 0.390625 7.625 2.203125 \n",
"L 7.625 11.71875 \n",
"Q 11.71875 9.328125 16.59375 8.109375 \n",
"Q 21.484375 6.890625 26.8125 6.890625 \n",
"Q 36.078125 6.890625 40.9375 10.546875 \n",
"Q 45.796875 14.203125 45.796875 21.1875 \n",
"Q 45.796875 27.640625 41.28125 31.265625 \n",
"Q 36.765625 34.90625 28.71875 34.90625 \n",
"L 20.21875 34.90625 \n",
"L 20.21875 43.015625 \n",
"L 29.109375 43.015625 \n",
"Q 36.375 43.015625 40.234375 45.921875 \n",
"Q 44.09375 48.828125 44.09375 54.296875 \n",
"Q 44.09375 59.90625 40.109375 62.90625 \n",
"Q 36.140625 65.921875 28.71875 65.921875 \n",
"Q 24.65625 65.921875 20.015625 65.03125 \n",
"Q 15.375 64.15625 9.8125 62.3125 \n",
"L 9.8125 71.09375 \n",
"Q 15.4375 72.65625 20.34375 73.4375 \n",
"Q 25.25 74.21875 29.59375 74.21875 \n",
"Q 40.828125 74.21875 47.359375 69.109375 \n",
"Q 53.90625 64.015625 53.90625 55.328125 \n",
"Q 53.90625 49.265625 50.4375 45.09375 \n",
"Q 46.96875 40.921875 40.578125 39.3125 \n",
"z\n",
"\" id=\"DejaVuSans-33\"/>\n",
" </defs>\n",
" <g transform=\"translate(20.878125 116.853333)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1689761f7d\" y=\"88.318257\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 0.35 -->\n",
" <g transform=\"translate(20.878125 92.117476)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1689761f7d\" y=\"63.582401\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 0.40 -->\n",
" <g transform=\"translate(20.878125 67.38162)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1689761f7d\" y=\"38.846544\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 0.45 -->\n",
" <g transform=\"translate(20.878125 42.645763)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1689761f7d\" y=\"14.110687\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 0.50 -->\n",
" <g transform=\"translate(20.878125 17.909906)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- loss -->\n",
" <defs>\n",
" <path d=\"M 9.421875 75.984375 \n",
"L 18.40625 75.984375 \n",
"L 18.40625 0 \n",
"L 9.421875 0 \n",
"z\n",
"\" id=\"DejaVuSans-6c\"/>\n",
" <path d=\"M 44.28125 53.078125 \n",
"L 44.28125 44.578125 \n",
"Q 40.484375 46.53125 36.375 47.5 \n",
"Q 32.28125 48.484375 27.875 48.484375 \n",
"Q 21.1875 48.484375 17.84375 46.4375 \n",
"Q 14.5 44.390625 14.5 40.28125 \n",
"Q 14.5 37.15625 16.890625 35.375 \n",
"Q 19.28125 33.59375 26.515625 31.984375 \n",
"L 29.59375 31.296875 \n",
"Q 39.15625 29.25 43.1875 25.515625 \n",
"Q 47.21875 21.78125 47.21875 15.09375 \n",
"Q 47.21875 7.46875 41.1875 3.015625 \n",
"Q 35.15625 -1.421875 24.609375 -1.421875 \n",
"Q 20.21875 -1.421875 15.453125 -0.5625 \n",
"Q 10.6875 0.296875 5.421875 2 \n",
"L 5.421875 11.28125 \n",
"Q 10.40625 8.6875 15.234375 7.390625 \n",
"Q 20.0625 6.109375 24.8125 6.109375 \n",
"Q 31.15625 6.109375 34.5625 8.28125 \n",
"Q 37.984375 10.453125 37.984375 14.40625 \n",
"Q 37.984375 18.0625 35.515625 20.015625 \n",
"Q 33.0625 21.96875 24.703125 23.78125 \n",
"L 21.578125 24.515625 \n",
"Q 13.234375 26.265625 9.515625 29.90625 \n",
"Q 5.8125 33.546875 5.8125 39.890625 \n",
"Q 5.8125 47.609375 11.28125 51.796875 \n",
"Q 16.75 56 26.8125 56 \n",
"Q 31.78125 56 36.171875 55.265625 \n",
"Q 40.578125 54.546875 44.28125 53.078125 \n",
"z\n",
"\" id=\"DejaVuSans-73\"/>\n",
" </defs>\n",
" <g transform=\"translate(14.798437 88.307812)rotate(-90)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use x=\"27.783203\" xlink:href=\"#DejaVuSans-6f\"/>\n",
" <use x=\"88.964844\" xlink:href=\"#DejaVuSans-73\"/>\n",
" <use x=\"141.064453\" xlink:href=\"#DejaVuSans-73\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_11\">\n",
" <path clip-path=\"url(#p6649500de0)\" d=\"M 59.021023 16.877273 \n",
"L 88.611932 124.908236 \n",
"L 118.202841 137.639953 \n",
"L 147.79375 139.589663 \n",
"L 177.384659 140.12685 \n",
"L 206.975568 140.351051 \n",
"L 236.566477 140.422727 \n",
"\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 50.14375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 245.44375 146.6 \n",
"L 245.44375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 245.44375 146.6 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 50.14375 10.7 \n",
"L 245.44375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p6649500de0\">\n",
" <rect height=\"135.9\" width=\"195.3\" x=\"50.14375\" y=\"10.7\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<matplotlib.figure.Figure at 0x10f399dd8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def train_sgd(lr, batch_size, num_epochs=2):\n",
"    \"\"\"Run ``train_ch7`` with the hand-written ``sgd`` optimizer.\n",
"\n",
"    ``None`` is passed as the optimizer state, and the learning rate is\n",
"    forwarded under the ``'lr'`` key of the hyperparameter dict. Training\n",
"    uses the notebook-level ``features`` and ``labels``.\n",
"    \"\"\"\n",
"    hyperparams = {'lr': lr}\n",
"    train_ch7(sgd, None, hyperparams, features, labels, batch_size, num_epochs)\n",
"\n",
"# Batch size 1500 with learning rate 1, for 6 epochs.\n",
"train_sgd(1, 1500, 6)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 0.246391, 0.334214 sec per epoch\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Created with matplotlib (http://matplotlib.org/) -->\n",
"<svg height=\"184pt\" version=\"1.1\" viewBox=\"0 0 256 184\" width=\"256pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
" <defs>\n",
" <style type=\"text/css\">\n",
"*{stroke-linecap:butt;stroke-linejoin:round;}\n",
" </style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 184.15625 \n",
"L 256.14375 184.15625 \n",
"L 256.14375 -0 \n",
"L 0 -0 \n",
"z\n",
"\" style=\"fill:none;\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 245.44375 146.6 \n",
"L 245.44375 10.7 \n",
"L 50.14375 10.7 \n",
"z\n",
"\" style=\"fill:#ffffff;\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" id=\"m442991208f\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" </defs>\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"59.021023\" xlink:href=\"#m442991208f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- 0.0 -->\n",
" <defs>\n",
" <path d=\"M 31.78125 66.40625 \n",
"Q 24.171875 66.40625 20.328125 58.90625 \n",
"Q 16.5 51.421875 16.5 36.375 \n",
"Q 16.5 21.390625 20.328125 13.890625 \n",
"Q 24.171875 6.390625 31.78125 6.390625 \n",
"Q 39.453125 6.390625 43.28125 13.890625 \n",
"Q 47.125 21.390625 47.125 36.375 \n",
"Q 47.125 51.421875 43.28125 58.90625 \n",
"Q 39.453125 66.40625 31.78125 66.40625 \n",
"z\n",
"M 31.78125 74.21875 \n",
"Q 44.046875 74.21875 50.515625 64.515625 \n",
"Q 56.984375 54.828125 56.984375 36.375 \n",
"Q 56.984375 17.96875 50.515625 8.265625 \n",
"Q 44.046875 -1.421875 31.78125 -1.421875 \n",
"Q 19.53125 -1.421875 13.0625 8.265625 \n",
"Q 6.59375 17.96875 6.59375 36.375 \n",
"Q 6.59375 54.828125 13.0625 64.515625 \n",
"Q 19.53125 74.21875 31.78125 74.21875 \n",
"z\n",
"\" id=\"DejaVuSans-30\"/>\n",
" <path d=\"M 10.6875 12.40625 \n",
"L 21 12.40625 \n",
"L 21 0 \n",
"L 10.6875 0 \n",
"z\n",
"\" id=\"DejaVuSans-2e\"/>\n",
" </defs>\n",
" <g transform=\"translate(51.06946 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"103.407386\" xlink:href=\"#m442991208f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- 0.5 -->\n",
" <defs>\n",
" <path d=\"M 10.796875 72.90625 \n",
"L 49.515625 72.90625 \n",
"L 49.515625 64.59375 \n",
"L 19.828125 64.59375 \n",
"L 19.828125 46.734375 \n",
"Q 21.96875 47.46875 24.109375 47.828125 \n",
"Q 26.265625 48.1875 28.421875 48.1875 \n",
"Q 40.625 48.1875 47.75 41.5 \n",
"Q 54.890625 34.8125 54.890625 23.390625 \n",
"Q 54.890625 11.625 47.5625 5.09375 \n",
"Q 40.234375 -1.421875 26.90625 -1.421875 \n",
"Q 22.3125 -1.421875 17.546875 -0.640625 \n",
"Q 12.796875 0.140625 7.71875 1.703125 \n",
"L 7.71875 11.625 \n",
"Q 12.109375 9.234375 16.796875 8.0625 \n",
"Q 21.484375 6.890625 26.703125 6.890625 \n",
"Q 35.15625 6.890625 40.078125 11.328125 \n",
"Q 45.015625 15.765625 45.015625 23.390625 \n",
"Q 45.015625 31 40.078125 35.4375 \n",
"Q 35.15625 39.890625 26.703125 39.890625 \n",
"Q 22.75 39.890625 18.8125 39.015625 \n",
"Q 14.890625 38.140625 10.796875 36.28125 \n",
"z\n",
"\" id=\"DejaVuSans-35\"/>\n",
" </defs>\n",
" <g transform=\"translate(95.455824 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"147.79375\" xlink:href=\"#m442991208f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 1.0 -->\n",
" <defs>\n",
" <path d=\"M 12.40625 8.296875 \n",
"L 28.515625 8.296875 \n",
"L 28.515625 63.921875 \n",
"L 10.984375 60.40625 \n",
"L 10.984375 69.390625 \n",
"L 28.421875 72.90625 \n",
"L 38.28125 72.90625 \n",
"L 38.28125 8.296875 \n",
"L 54.390625 8.296875 \n",
"L 54.390625 0 \n",
"L 12.40625 0 \n",
"z\n",
"\" id=\"DejaVuSans-31\"/>\n",
" </defs>\n",
" <g transform=\"translate(139.842187 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"192.180114\" xlink:href=\"#m442991208f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 1.5 -->\n",
" <g transform=\"translate(184.228551 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"236.566477\" xlink:href=\"#m442991208f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 2.0 -->\n",
" <defs>\n",
" <path d=\"M 19.1875 8.296875 \n",
"L 53.609375 8.296875 \n",
"L 53.609375 0 \n",
"L 7.328125 0 \n",
"L 7.328125 8.296875 \n",
"Q 12.9375 14.109375 22.625 23.890625 \n",
"Q 32.328125 33.6875 34.8125 36.53125 \n",
"Q 39.546875 41.84375 41.421875 45.53125 \n",
"Q 43.3125 49.21875 43.3125 52.78125 \n",
"Q 43.3125 58.59375 39.234375 62.25 \n",
"Q 35.15625 65.921875 28.609375 65.921875 \n",
"Q 23.96875 65.921875 18.8125 64.3125 \n",
"Q 13.671875 62.703125 7.8125 59.421875 \n",
"L 7.8125 69.390625 \n",
"Q 13.765625 71.78125 18.9375 73 \n",
"Q 24.125 74.21875 28.421875 74.21875 \n",
"Q 39.75 74.21875 46.484375 68.546875 \n",
"Q 53.21875 62.890625 53.21875 53.421875 \n",
"Q 53.21875 48.921875 51.53125 44.890625 \n",
"Q 49.859375 40.875 45.40625 35.40625 \n",
"Q 44.1875 33.984375 37.640625 27.21875 \n",
"Q 31.109375 20.453125 19.1875 8.296875 \n",
"z\n",
"\" id=\"DejaVuSans-32\"/>\n",
" </defs>\n",
" <g transform=\"translate(228.614915 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- epoch -->\n",
" <defs>\n",
" <path d=\"M 56.203125 29.59375 \n",
"L 56.203125 25.203125 \n",
"L 14.890625 25.203125 \n",
"Q 15.484375 15.921875 20.484375 11.0625 \n",
"Q 25.484375 6.203125 34.421875 6.203125 \n",
"Q 39.59375 6.203125 44.453125 7.46875 \n",
"Q 49.3125 8.734375 54.109375 11.28125 \n",
"L 54.109375 2.78125 \n",
"Q 49.265625 0.734375 44.1875 -0.34375 \n",
"Q 39.109375 -1.421875 33.890625 -1.421875 \n",
"Q 20.796875 -1.421875 13.15625 6.1875 \n",
"Q 5.515625 13.8125 5.515625 26.8125 \n",
"Q 5.515625 40.234375 12.765625 48.109375 \n",
"Q 20.015625 56 32.328125 56 \n",
"Q 43.359375 56 49.78125 48.890625 \n",
"Q 56.203125 41.796875 56.203125 29.59375 \n",
"z\n",
"M 47.21875 32.234375 \n",
"Q 47.125 39.59375 43.09375 43.984375 \n",
"Q 39.0625 48.390625 32.421875 48.390625 \n",
"Q 24.90625 48.390625 20.390625 44.140625 \n",
"Q 15.875 39.890625 15.1875 32.171875 \n",
"z\n",
"\" id=\"DejaVuSans-65\"/>\n",
" <path d=\"M 18.109375 8.203125 \n",
"L 18.109375 -20.796875 \n",
"L 9.078125 -20.796875 \n",
"L 9.078125 54.6875 \n",
"L 18.109375 54.6875 \n",
"L 18.109375 46.390625 \n",
"Q 20.953125 51.265625 25.265625 53.625 \n",
"Q 29.59375 56 35.59375 56 \n",
"Q 45.5625 56 51.78125 48.09375 \n",
"Q 58.015625 40.1875 58.015625 27.296875 \n",
"Q 58.015625 14.40625 51.78125 6.484375 \n",
"Q 45.5625 -1.421875 35.59375 -1.421875 \n",
"Q 29.59375 -1.421875 25.265625 0.953125 \n",
"Q 20.953125 3.328125 18.109375 8.203125 \n",
"z\n",
"M 48.6875 27.296875 \n",
"Q 48.6875 37.203125 44.609375 42.84375 \n",
"Q 40.53125 48.484375 33.40625 48.484375 \n",
"Q 26.265625 48.484375 22.1875 42.84375 \n",
"Q 18.109375 37.203125 18.109375 27.296875 \n",
"Q 18.109375 17.390625 22.1875 11.75 \n",
"Q 26.265625 6.109375 33.40625 6.109375 \n",
"Q 40.53125 6.109375 44.609375 11.75 \n",
"Q 48.6875 17.390625 48.6875 27.296875 \n",
"z\n",
"\" id=\"DejaVuSans-70\"/>\n",
" <path d=\"M 30.609375 48.390625 \n",
"Q 23.390625 48.390625 19.1875 42.75 \n",
"Q 14.984375 37.109375 14.984375 27.296875 \n",
"Q 14.984375 17.484375 19.15625 11.84375 \n",
"Q 23.34375 6.203125 30.609375 6.203125 \n",
"Q 37.796875 6.203125 41.984375 11.859375 \n",
"Q 46.1875 17.53125 46.1875 27.296875 \n",
"Q 46.1875 37.015625 41.984375 42.703125 \n",
"Q 37.796875 48.390625 30.609375 48.390625 \n",
"z\n",
"M 30.609375 56 \n",
"Q 42.328125 56 49.015625 48.375 \n",
"Q 55.71875 40.765625 55.71875 27.296875 \n",
"Q 55.71875 13.875 49.015625 6.21875 \n",
"Q 42.328125 -1.421875 30.609375 -1.421875 \n",
"Q 18.84375 -1.421875 12.171875 6.21875 \n",
"Q 5.515625 13.875 5.515625 27.296875 \n",
"Q 5.515625 40.765625 12.171875 48.375 \n",
"Q 18.84375 56 30.609375 56 \n",
"z\n",
"\" id=\"DejaVuSans-6f\"/>\n",
" <path d=\"M 48.78125 52.59375 \n",
"L 48.78125 44.1875 \n",
"Q 44.96875 46.296875 41.140625 47.34375 \n",
"Q 37.3125 48.390625 33.40625 48.390625 \n",
"Q 24.65625 48.390625 19.8125 42.84375 \n",
"Q 14.984375 37.3125 14.984375 27.296875 \n",
"Q 14.984375 17.28125 19.8125 11.734375 \n",
"Q 24.65625 6.203125 33.40625 6.203125 \n",
"Q 37.3125 6.203125 41.140625 7.25 \n",
"Q 44.96875 8.296875 48.78125 10.40625 \n",
"L 48.78125 2.09375 \n",
"Q 45.015625 0.34375 40.984375 -0.53125 \n",
"Q 36.96875 -1.421875 32.421875 -1.421875 \n",
"Q 20.0625 -1.421875 12.78125 6.34375 \n",
"Q 5.515625 14.109375 5.515625 27.296875 \n",
"Q 5.515625 40.671875 12.859375 48.328125 \n",
"Q 20.21875 56 33.015625 56 \n",
"Q 37.15625 56 41.109375 55.140625 \n",
"Q 45.0625 54.296875 48.78125 52.59375 \n",
"z\n",
"\" id=\"DejaVuSans-63\"/>\n",
" <path d=\"M 54.890625 33.015625 \n",
"L 54.890625 0 \n",
"L 45.90625 0 \n",
"L 45.90625 32.71875 \n",
"Q 45.90625 40.484375 42.875 44.328125 \n",
"Q 39.84375 48.1875 33.796875 48.1875 \n",
"Q 26.515625 48.1875 22.3125 43.546875 \n",
"Q 18.109375 38.921875 18.109375 30.90625 \n",
"L 18.109375 0 \n",
"L 9.078125 0 \n",
"L 9.078125 75.984375 \n",
"L 18.109375 75.984375 \n",
"L 18.109375 46.1875 \n",
"Q 21.34375 51.125 25.703125 53.5625 \n",
"Q 30.078125 56 35.796875 56 \n",
"Q 45.21875 56 50.046875 50.171875 \n",
"Q 54.890625 44.34375 54.890625 33.015625 \n",
"z\n",
"\" id=\"DejaVuSans-68\"/>\n",
" </defs>\n",
" <g transform=\"translate(132.565625 174.876562)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
" <use x=\"61.523438\" xlink:href=\"#DejaVuSans-70\"/>\n",
" <use x=\"125\" xlink:href=\"#DejaVuSans-6f\"/>\n",
" <use x=\"186.181641\" xlink:href=\"#DejaVuSans-63\"/>\n",
" <use x=\"241.162109\" xlink:href=\"#DejaVuSans-68\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" id=\"mde597fe460\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" </defs>\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#mde597fe460\" y=\"136.660519\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0.25 -->\n",
" <g transform=\"translate(20.878125 140.459738)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-32\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#mde597fe460\" y=\"110.924297\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 0.30 -->\n",
" <defs>\n",
" <path d=\"M 40.578125 39.3125 \n",
"Q 47.65625 37.796875 51.625 33 \n",
"Q 55.609375 28.21875 55.609375 21.1875 \n",
"Q 55.609375 10.40625 48.1875 4.484375 \n",
"Q 40.765625 -1.421875 27.09375 -1.421875 \n",
"Q 22.515625 -1.421875 17.65625 -0.515625 \n",
"Q 12.796875 0.390625 7.625 2.203125 \n",
"L 7.625 11.71875 \n",
"Q 11.71875 9.328125 16.59375 8.109375 \n",
"Q 21.484375 6.890625 26.8125 6.890625 \n",
"Q 36.078125 6.890625 40.9375 10.546875 \n",
"Q 45.796875 14.203125 45.796875 21.1875 \n",
"Q 45.796875 27.640625 41.28125 31.265625 \n",
"Q 36.765625 34.90625 28.71875 34.90625 \n",
"L 20.21875 34.90625 \n",
"L 20.21875 43.015625 \n",
"L 29.109375 43.015625 \n",
"Q 36.375 43.015625 40.234375 45.921875 \n",
"Q 44.09375 48.828125 44.09375 54.296875 \n",
"Q 44.09375 59.90625 40.109375 62.90625 \n",
"Q 36.140625 65.921875 28.71875 65.921875 \n",
"Q 24.65625 65.921875 20.015625 65.03125 \n",
"Q 15.375 64.15625 9.8125 62.3125 \n",
"L 9.8125 71.09375 \n",
"Q 15.4375 72.65625 20.34375 73.4375 \n",
"Q 25.25 74.21875 29.59375 74.21875 \n",
"Q 40.828125 74.21875 47.359375 69.109375 \n",
"Q 53.90625 64.015625 53.90625 55.328125 \n",
"Q 53.90625 49.265625 50.4375 45.09375 \n",
"Q 46.96875 40.921875 40.578125 39.3125 \n",
"z\n",
"\" id=\"DejaVuSans-33\"/>\n",
" </defs>\n",
" <g transform=\"translate(20.878125 114.723516)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#mde597fe460\" y=\"85.188075\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 0.35 -->\n",
" <g transform=\"translate(20.878125 88.987293)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#mde597fe460\" y=\"59.451852\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 0.40 -->\n",
" <defs>\n",
" <path d=\"M 37.796875 64.3125 \n",
"L 12.890625 25.390625 \n",
"L 37.796875 25.390625 \n",
"z\n",
"M 35.203125 72.90625 \n",
"L 47.609375 72.90625 \n",
"L 47.609375 25.390625 \n",
"L 58.015625 25.390625 \n",
"L 58.015625 17.1875 \n",
"L 47.609375 17.1875 \n",
"L 47.609375 0 \n",
"L 37.796875 0 \n",
"L 37.796875 17.1875 \n",
"L 4.890625 17.1875 \n",
"L 4.890625 26.703125 \n",
"z\n",
"\" id=\"DejaVuSans-34\"/>\n",
" </defs>\n",
" <g transform=\"translate(20.878125 63.251071)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#mde597fe460\" y=\"33.71563\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 0.45 -->\n",
" <g transform=\"translate(20.878125 37.514848)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- loss -->\n",
" <defs>\n",
" <path d=\"M 9.421875 75.984375 \n",
"L 18.40625 75.984375 \n",
"L 18.40625 0 \n",
"L 9.421875 0 \n",
"z\n",
"\" id=\"DejaVuSans-6c\"/>\n",
" <path d=\"M 44.28125 53.078125 \n",
"L 44.28125 44.578125 \n",
"Q 40.484375 46.53125 36.375 47.5 \n",
"Q 32.28125 48.484375 27.875 48.484375 \n",
"Q 21.1875 48.484375 17.84375 46.4375 \n",
"Q 14.5 44.390625 14.5 40.28125 \n",
"Q 14.5 37.15625 16.890625 35.375 \n",
"Q 19.28125 33.59375 26.515625 31.984375 \n",
"L 29.59375 31.296875 \n",
"Q 39.15625 29.25 43.1875 25.515625 \n",
"Q 47.21875 21.78125 47.21875 15.09375 \n",
"Q 47.21875 7.46875 41.1875 3.015625 \n",
"Q 35.15625 -1.421875 24.609375 -1.421875 \n",
"Q 20.21875 -1.421875 15.453125 -0.5625 \n",
"Q 10.6875 0.296875 5.421875 2 \n",
"L 5.421875 11.28125 \n",
"Q 10.40625 8.6875 15.234375 7.390625 \n",
"Q 20.0625 6.109375 24.8125 6.109375 \n",
"Q 31.15625 6.109375 34.5625 8.28125 \n",
"Q 37.984375 10.453125 37.984375 14.40625 \n",
"Q 37.984375 18.0625 35.515625 20.015625 \n",
"Q 33.0625 21.96875 24.703125 23.78125 \n",
"L 21.578125 24.515625 \n",
"Q 13.234375 26.265625 9.515625 29.90625 \n",
"Q 5.8125 33.546875 5.8125 39.890625 \n",
"Q 5.8125 47.609375 11.28125 51.796875 \n",
"Q 16.75 56 26.8125 56 \n",
"Q 31.78125 56 36.171875 55.265625 \n",
"Q 40.578125 54.546875 44.28125 53.078125 \n",
"z\n",
"\" id=\"DejaVuSans-73\"/>\n",
" </defs>\n",
" <g transform=\"translate(14.798437 88.307812)rotate(-90)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use x=\"27.783203\" xlink:href=\"#DejaVuSans-6f\"/>\n",
" <use x=\"88.964844\" xlink:href=\"#DejaVuSans-73\"/>\n",
" <use x=\"141.064453\" xlink:href=\"#DejaVuSans-73\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_11\">\n",
" <path clip-path=\"url(#p4d103a3dcb)\" d=\"M 59.021023 16.877273 \n",
"L 64.939205 82.993478 \n",
"L 70.857386 105.607222 \n",
"L 76.775568 119.81194 \n",
"L 82.69375 129.692301 \n",
"L 88.611932 132.856694 \n",
"L 94.530114 133.288162 \n",
"L 100.448295 134.78806 \n",
"L 106.366477 138.827699 \n",
"L 112.284659 138.328651 \n",
"L 118.202841 139.246925 \n",
"L 124.121023 140.045732 \n",
"L 130.039205 139.235129 \n",
"L 135.957386 137.352399 \n",
"L 141.875568 138.736717 \n",
"L 147.79375 140.422727 \n",
"L 153.711932 139.574948 \n",
"L 159.630114 139.087037 \n",
"L 165.548295 140.202361 \n",
"L 171.466477 139.241088 \n",
"L 177.384659 140.386364 \n",
"L 183.302841 138.703276 \n",
"L 189.221023 136.974015 \n",
"L 195.139205 137.584999 \n",
"L 201.057386 140.262808 \n",
"L 206.975568 138.665494 \n",
"L 212.89375 139.906084 \n",
"L 218.811932 139.739055 \n",
"L 224.730114 138.754312 \n",
"L 230.648295 138.587682 \n",
"L 236.566477 138.51836 \n",
"\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 50.14375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 245.44375 146.6 \n",
"L 245.44375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 245.44375 146.6 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 50.14375 10.7 \n",
"L 245.44375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p4d103a3dcb\">\n",
" <rect height=\"135.9\" width=\"195.3\" x=\"50.14375\" y=\"10.7\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<matplotlib.figure.Figure at 0x11c61a400>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"train_sgd(0.005, 1)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 0.245523, 0.050718 sec per epoch\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Created with matplotlib (http://matplotlib.org/) -->\n",
"<svg height=\"184pt\" version=\"1.1\" viewBox=\"0 0 256 184\" width=\"256pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
" <defs>\n",
" <style type=\"text/css\">\n",
"*{stroke-linecap:butt;stroke-linejoin:round;}\n",
" </style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 184.15625 \n",
"L 256.14375 184.15625 \n",
"L 256.14375 -0 \n",
"L 0 -0 \n",
"z\n",
"\" style=\"fill:none;\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 245.44375 146.6 \n",
"L 245.44375 10.7 \n",
"L 50.14375 10.7 \n",
"z\n",
"\" style=\"fill:#ffffff;\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" id=\"mf9413430b5\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" </defs>\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"59.021023\" xlink:href=\"#mf9413430b5\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- 0.0 -->\n",
" <defs>\n",
" <path d=\"M 31.78125 66.40625 \n",
"Q 24.171875 66.40625 20.328125 58.90625 \n",
"Q 16.5 51.421875 16.5 36.375 \n",
"Q 16.5 21.390625 20.328125 13.890625 \n",
"Q 24.171875 6.390625 31.78125 6.390625 \n",
"Q 39.453125 6.390625 43.28125 13.890625 \n",
"Q 47.125 21.390625 47.125 36.375 \n",
"Q 47.125 51.421875 43.28125 58.90625 \n",
"Q 39.453125 66.40625 31.78125 66.40625 \n",
"z\n",
"M 31.78125 74.21875 \n",
"Q 44.046875 74.21875 50.515625 64.515625 \n",
"Q 56.984375 54.828125 56.984375 36.375 \n",
"Q 56.984375 17.96875 50.515625 8.265625 \n",
"Q 44.046875 -1.421875 31.78125 -1.421875 \n",
"Q 19.53125 -1.421875 13.0625 8.265625 \n",
"Q 6.59375 17.96875 6.59375 36.375 \n",
"Q 6.59375 54.828125 13.0625 64.515625 \n",
"Q 19.53125 74.21875 31.78125 74.21875 \n",
"z\n",
"\" id=\"DejaVuSans-30\"/>\n",
" <path d=\"M 10.6875 12.40625 \n",
"L 21 12.40625 \n",
"L 21 0 \n",
"L 10.6875 0 \n",
"z\n",
"\" id=\"DejaVuSans-2e\"/>\n",
" </defs>\n",
" <g transform=\"translate(51.06946 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"103.407386\" xlink:href=\"#mf9413430b5\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- 0.5 -->\n",
" <defs>\n",
" <path d=\"M 10.796875 72.90625 \n",
"L 49.515625 72.90625 \n",
"L 49.515625 64.59375 \n",
"L 19.828125 64.59375 \n",
"L 19.828125 46.734375 \n",
"Q 21.96875 47.46875 24.109375 47.828125 \n",
"Q 26.265625 48.1875 28.421875 48.1875 \n",
"Q 40.625 48.1875 47.75 41.5 \n",
"Q 54.890625 34.8125 54.890625 23.390625 \n",
"Q 54.890625 11.625 47.5625 5.09375 \n",
"Q 40.234375 -1.421875 26.90625 -1.421875 \n",
"Q 22.3125 -1.421875 17.546875 -0.640625 \n",
"Q 12.796875 0.140625 7.71875 1.703125 \n",
"L 7.71875 11.625 \n",
"Q 12.109375 9.234375 16.796875 8.0625 \n",
"Q 21.484375 6.890625 26.703125 6.890625 \n",
"Q 35.15625 6.890625 40.078125 11.328125 \n",
"Q 45.015625 15.765625 45.015625 23.390625 \n",
"Q 45.015625 31 40.078125 35.4375 \n",
"Q 35.15625 39.890625 26.703125 39.890625 \n",
"Q 22.75 39.890625 18.8125 39.015625 \n",
"Q 14.890625 38.140625 10.796875 36.28125 \n",
"z\n",
"\" id=\"DejaVuSans-35\"/>\n",
" </defs>\n",
" <g transform=\"translate(95.455824 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"147.79375\" xlink:href=\"#mf9413430b5\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 1.0 -->\n",
" <defs>\n",
" <path d=\"M 12.40625 8.296875 \n",
"L 28.515625 8.296875 \n",
"L 28.515625 63.921875 \n",
"L 10.984375 60.40625 \n",
"L 10.984375 69.390625 \n",
"L 28.421875 72.90625 \n",
"L 38.28125 72.90625 \n",
"L 38.28125 8.296875 \n",
"L 54.390625 8.296875 \n",
"L 54.390625 0 \n",
"L 12.40625 0 \n",
"z\n",
"\" id=\"DejaVuSans-31\"/>\n",
" </defs>\n",
" <g transform=\"translate(139.842187 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"192.180114\" xlink:href=\"#mf9413430b5\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 1.5 -->\n",
" <g transform=\"translate(184.228551 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"236.566477\" xlink:href=\"#mf9413430b5\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 2.0 -->\n",
" <defs>\n",
" <path d=\"M 19.1875 8.296875 \n",
"L 53.609375 8.296875 \n",
"L 53.609375 0 \n",
"L 7.328125 0 \n",
"L 7.328125 8.296875 \n",
"Q 12.9375 14.109375 22.625 23.890625 \n",
"Q 32.328125 33.6875 34.8125 36.53125 \n",
"Q 39.546875 41.84375 41.421875 45.53125 \n",
"Q 43.3125 49.21875 43.3125 52.78125 \n",
"Q 43.3125 58.59375 39.234375 62.25 \n",
"Q 35.15625 65.921875 28.609375 65.921875 \n",
"Q 23.96875 65.921875 18.8125 64.3125 \n",
"Q 13.671875 62.703125 7.8125 59.421875 \n",
"L 7.8125 69.390625 \n",
"Q 13.765625 71.78125 18.9375 73 \n",
"Q 24.125 74.21875 28.421875 74.21875 \n",
"Q 39.75 74.21875 46.484375 68.546875 \n",
"Q 53.21875 62.890625 53.21875 53.421875 \n",
"Q 53.21875 48.921875 51.53125 44.890625 \n",
"Q 49.859375 40.875 45.40625 35.40625 \n",
"Q 44.1875 33.984375 37.640625 27.21875 \n",
"Q 31.109375 20.453125 19.1875 8.296875 \n",
"z\n",
"\" id=\"DejaVuSans-32\"/>\n",
" </defs>\n",
" <g transform=\"translate(228.614915 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- epoch -->\n",
" <defs>\n",
" <path d=\"M 56.203125 29.59375 \n",
"L 56.203125 25.203125 \n",
"L 14.890625 25.203125 \n",
"Q 15.484375 15.921875 20.484375 11.0625 \n",
"Q 25.484375 6.203125 34.421875 6.203125 \n",
"Q 39.59375 6.203125 44.453125 7.46875 \n",
"Q 49.3125 8.734375 54.109375 11.28125 \n",
"L 54.109375 2.78125 \n",
"Q 49.265625 0.734375 44.1875 -0.34375 \n",
"Q 39.109375 -1.421875 33.890625 -1.421875 \n",
"Q 20.796875 -1.421875 13.15625 6.1875 \n",
"Q 5.515625 13.8125 5.515625 26.8125 \n",
"Q 5.515625 40.234375 12.765625 48.109375 \n",
"Q 20.015625 56 32.328125 56 \n",
"Q 43.359375 56 49.78125 48.890625 \n",
"Q 56.203125 41.796875 56.203125 29.59375 \n",
"z\n",
"M 47.21875 32.234375 \n",
"Q 47.125 39.59375 43.09375 43.984375 \n",
"Q 39.0625 48.390625 32.421875 48.390625 \n",
"Q 24.90625 48.390625 20.390625 44.140625 \n",
"Q 15.875 39.890625 15.1875 32.171875 \n",
"z\n",
"\" id=\"DejaVuSans-65\"/>\n",
" <path d=\"M 18.109375 8.203125 \n",
"L 18.109375 -20.796875 \n",
"L 9.078125 -20.796875 \n",
"L 9.078125 54.6875 \n",
"L 18.109375 54.6875 \n",
"L 18.109375 46.390625 \n",
"Q 20.953125 51.265625 25.265625 53.625 \n",
"Q 29.59375 56 35.59375 56 \n",
"Q 45.5625 56 51.78125 48.09375 \n",
"Q 58.015625 40.1875 58.015625 27.296875 \n",
"Q 58.015625 14.40625 51.78125 6.484375 \n",
"Q 45.5625 -1.421875 35.59375 -1.421875 \n",
"Q 29.59375 -1.421875 25.265625 0.953125 \n",
"Q 20.953125 3.328125 18.109375 8.203125 \n",
"z\n",
"M 48.6875 27.296875 \n",
"Q 48.6875 37.203125 44.609375 42.84375 \n",
"Q 40.53125 48.484375 33.40625 48.484375 \n",
"Q 26.265625 48.484375 22.1875 42.84375 \n",
"Q 18.109375 37.203125 18.109375 27.296875 \n",
"Q 18.109375 17.390625 22.1875 11.75 \n",
"Q 26.265625 6.109375 33.40625 6.109375 \n",
"Q 40.53125 6.109375 44.609375 11.75 \n",
"Q 48.6875 17.390625 48.6875 27.296875 \n",
"z\n",
"\" id=\"DejaVuSans-70\"/>\n",
" <path d=\"M 30.609375 48.390625 \n",
"Q 23.390625 48.390625 19.1875 42.75 \n",
"Q 14.984375 37.109375 14.984375 27.296875 \n",
"Q 14.984375 17.484375 19.15625 11.84375 \n",
"Q 23.34375 6.203125 30.609375 6.203125 \n",
"Q 37.796875 6.203125 41.984375 11.859375 \n",
"Q 46.1875 17.53125 46.1875 27.296875 \n",
"Q 46.1875 37.015625 41.984375 42.703125 \n",
"Q 37.796875 48.390625 30.609375 48.390625 \n",
"z\n",
"M 30.609375 56 \n",
"Q 42.328125 56 49.015625 48.375 \n",
"Q 55.71875 40.765625 55.71875 27.296875 \n",
"Q 55.71875 13.875 49.015625 6.21875 \n",
"Q 42.328125 -1.421875 30.609375 -1.421875 \n",
"Q 18.84375 -1.421875 12.171875 6.21875 \n",
"Q 5.515625 13.875 5.515625 27.296875 \n",
"Q 5.515625 40.765625 12.171875 48.375 \n",
"Q 18.84375 56 30.609375 56 \n",
"z\n",
"\" id=\"DejaVuSans-6f\"/>\n",
" <path d=\"M 48.78125 52.59375 \n",
"L 48.78125 44.1875 \n",
"Q 44.96875 46.296875 41.140625 47.34375 \n",
"Q 37.3125 48.390625 33.40625 48.390625 \n",
"Q 24.65625 48.390625 19.8125 42.84375 \n",
"Q 14.984375 37.3125 14.984375 27.296875 \n",
"Q 14.984375 17.28125 19.8125 11.734375 \n",
"Q 24.65625 6.203125 33.40625 6.203125 \n",
"Q 37.3125 6.203125 41.140625 7.25 \n",
"Q 44.96875 8.296875 48.78125 10.40625 \n",
"L 48.78125 2.09375 \n",
"Q 45.015625 0.34375 40.984375 -0.53125 \n",
"Q 36.96875 -1.421875 32.421875 -1.421875 \n",
"Q 20.0625 -1.421875 12.78125 6.34375 \n",
"Q 5.515625 14.109375 5.515625 27.296875 \n",
"Q 5.515625 40.671875 12.859375 48.328125 \n",
"Q 20.21875 56 33.015625 56 \n",
"Q 37.15625 56 41.109375 55.140625 \n",
"Q 45.0625 54.296875 48.78125 52.59375 \n",
"z\n",
"\" id=\"DejaVuSans-63\"/>\n",
" <path d=\"M 54.890625 33.015625 \n",
"L 54.890625 0 \n",
"L 45.90625 0 \n",
"L 45.90625 32.71875 \n",
"Q 45.90625 40.484375 42.875 44.328125 \n",
"Q 39.84375 48.1875 33.796875 48.1875 \n",
"Q 26.515625 48.1875 22.3125 43.546875 \n",
"Q 18.109375 38.921875 18.109375 30.90625 \n",
"L 18.109375 0 \n",
"L 9.078125 0 \n",
"L 9.078125 75.984375 \n",
"L 18.109375 75.984375 \n",
"L 18.109375 46.1875 \n",
"Q 21.34375 51.125 25.703125 53.5625 \n",
"Q 30.078125 56 35.796875 56 \n",
"Q 45.21875 56 50.046875 50.171875 \n",
"Q 54.890625 44.34375 54.890625 33.015625 \n",
"z\n",
"\" id=\"DejaVuSans-68\"/>\n",
" </defs>\n",
" <g transform=\"translate(132.565625 174.876562)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
" <use x=\"61.523438\" xlink:href=\"#DejaVuSans-70\"/>\n",
" <use x=\"125\" xlink:href=\"#DejaVuSans-6f\"/>\n",
" <use x=\"186.181641\" xlink:href=\"#DejaVuSans-63\"/>\n",
" <use x=\"241.162109\" xlink:href=\"#DejaVuSans-68\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" id=\"m1e54e839d6\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" </defs>\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1e54e839d6\" y=\"137.068233\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0.25 -->\n",
" <g transform=\"translate(20.878125 140.867452)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-32\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1e54e839d6\" y=\"112.984363\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 0.30 -->\n",
" <defs>\n",
" <path d=\"M 40.578125 39.3125 \n",
"Q 47.65625 37.796875 51.625 33 \n",
"Q 55.609375 28.21875 55.609375 21.1875 \n",
"Q 55.609375 10.40625 48.1875 4.484375 \n",
"Q 40.765625 -1.421875 27.09375 -1.421875 \n",
"Q 22.515625 -1.421875 17.65625 -0.515625 \n",
"Q 12.796875 0.390625 7.625 2.203125 \n",
"L 7.625 11.71875 \n",
"Q 11.71875 9.328125 16.59375 8.109375 \n",
"Q 21.484375 6.890625 26.8125 6.890625 \n",
"Q 36.078125 6.890625 40.9375 10.546875 \n",
"Q 45.796875 14.203125 45.796875 21.1875 \n",
"Q 45.796875 27.640625 41.28125 31.265625 \n",
"Q 36.765625 34.90625 28.71875 34.90625 \n",
"L 20.21875 34.90625 \n",
"L 20.21875 43.015625 \n",
"L 29.109375 43.015625 \n",
"Q 36.375 43.015625 40.234375 45.921875 \n",
"Q 44.09375 48.828125 44.09375 54.296875 \n",
"Q 44.09375 59.90625 40.109375 62.90625 \n",
"Q 36.140625 65.921875 28.71875 65.921875 \n",
"Q 24.65625 65.921875 20.015625 65.03125 \n",
"Q 15.375 64.15625 9.8125 62.3125 \n",
"L 9.8125 71.09375 \n",
"Q 15.4375 72.65625 20.34375 73.4375 \n",
"Q 25.25 74.21875 29.59375 74.21875 \n",
"Q 40.828125 74.21875 47.359375 69.109375 \n",
"Q 53.90625 64.015625 53.90625 55.328125 \n",
"Q 53.90625 49.265625 50.4375 45.09375 \n",
"Q 46.96875 40.921875 40.578125 39.3125 \n",
"z\n",
"\" id=\"DejaVuSans-33\"/>\n",
" </defs>\n",
" <g transform=\"translate(20.878125 116.783582)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1e54e839d6\" y=\"88.900493\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 0.35 -->\n",
" <g transform=\"translate(20.878125 92.699712)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1e54e839d6\" y=\"64.816623\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 0.40 -->\n",
" <defs>\n",
" <path d=\"M 37.796875 64.3125 \n",
"L 12.890625 25.390625 \n",
"L 37.796875 25.390625 \n",
"z\n",
"M 35.203125 72.90625 \n",
"L 47.609375 72.90625 \n",
"L 47.609375 25.390625 \n",
"L 58.015625 25.390625 \n",
"L 58.015625 17.1875 \n",
"L 47.609375 17.1875 \n",
"L 47.609375 0 \n",
"L 37.796875 0 \n",
"L 37.796875 17.1875 \n",
"L 4.890625 17.1875 \n",
"L 4.890625 26.703125 \n",
"z\n",
"\" id=\"DejaVuSans-34\"/>\n",
" </defs>\n",
" <g transform=\"translate(20.878125 68.615842)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1e54e839d6\" y=\"40.732753\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 0.45 -->\n",
" <g transform=\"translate(20.878125 44.531972)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_11\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m1e54e839d6\" y=\"16.648883\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- 0.50 -->\n",
" <g transform=\"translate(20.878125 20.448102)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_13\">\n",
" <!-- loss -->\n",
" <defs>\n",
" <path d=\"M 9.421875 75.984375 \n",
"L 18.40625 75.984375 \n",
"L 18.40625 0 \n",
"L 9.421875 0 \n",
"z\n",
"\" id=\"DejaVuSans-6c\"/>\n",
" <path d=\"M 44.28125 53.078125 \n",
"L 44.28125 44.578125 \n",
"Q 40.484375 46.53125 36.375 47.5 \n",
"Q 32.28125 48.484375 27.875 48.484375 \n",
"Q 21.1875 48.484375 17.84375 46.4375 \n",
"Q 14.5 44.390625 14.5 40.28125 \n",
"Q 14.5 37.15625 16.890625 35.375 \n",
"Q 19.28125 33.59375 26.515625 31.984375 \n",
"L 29.59375 31.296875 \n",
"Q 39.15625 29.25 43.1875 25.515625 \n",
"Q 47.21875 21.78125 47.21875 15.09375 \n",
"Q 47.21875 7.46875 41.1875 3.015625 \n",
"Q 35.15625 -1.421875 24.609375 -1.421875 \n",
"Q 20.21875 -1.421875 15.453125 -0.5625 \n",
"Q 10.6875 0.296875 5.421875 2 \n",
"L 5.421875 11.28125 \n",
"Q 10.40625 8.6875 15.234375 7.390625 \n",
"Q 20.0625 6.109375 24.8125 6.109375 \n",
"Q 31.15625 6.109375 34.5625 8.28125 \n",
"Q 37.984375 10.453125 37.984375 14.40625 \n",
"Q 37.984375 18.0625 35.515625 20.015625 \n",
"Q 33.0625 21.96875 24.703125 23.78125 \n",
"L 21.578125 24.515625 \n",
"Q 13.234375 26.265625 9.515625 29.90625 \n",
"Q 5.8125 33.546875 5.8125 39.890625 \n",
"Q 5.8125 47.609375 11.28125 51.796875 \n",
"Q 16.75 56 26.8125 56 \n",
"Q 31.78125 56 36.171875 55.265625 \n",
"Q 40.578125 54.546875 44.28125 53.078125 \n",
"z\n",
"\" id=\"DejaVuSans-73\"/>\n",
" </defs>\n",
" <g transform=\"translate(14.798437 88.307812)rotate(-90)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use x=\"27.783203\" xlink:href=\"#DejaVuSans-6f\"/>\n",
" <use x=\"88.964844\" xlink:href=\"#DejaVuSans-73\"/>\n",
" <use x=\"141.064453\" xlink:href=\"#DejaVuSans-73\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_12\">\n",
" <path clip-path=\"url(#p1c2d904669)\" d=\"M 59.021023 16.877273 \n",
"L 64.939205 72.39613 \n",
"L 70.857386 105.399637 \n",
"L 76.775568 112.357183 \n",
"L 82.69375 128.746823 \n",
"L 88.611932 134.148721 \n",
"L 94.530114 137.576849 \n",
"L 100.448295 134.393131 \n",
"L 106.366477 137.548972 \n",
"L 112.284659 136.972571 \n",
"L 118.202841 138.561767 \n",
"L 124.121023 138.862686 \n",
"L 130.039205 138.954702 \n",
"L 135.957386 139.270701 \n",
"L 141.875568 137.165604 \n",
"L 147.79375 139.317765 \n",
"L 153.711932 138.621197 \n",
"L 159.630114 140.316708 \n",
"L 165.548295 139.056222 \n",
"L 171.466477 138.647589 \n",
"L 177.384659 139.636922 \n",
"L 183.302841 140.000996 \n",
"L 189.221023 140.422727 \n",
"L 195.139205 138.215953 \n",
"L 201.057386 139.598536 \n",
"L 206.975568 140.049703 \n",
"L 212.89375 138.673874 \n",
"L 218.811932 138.438816 \n",
"L 224.730114 139.571706 \n",
"L 230.648295 138.771531 \n",
"L 236.566477 139.224901 \n",
"\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 50.14375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 245.44375 146.6 \n",
"L 245.44375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 245.44375 146.6 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 50.14375 10.7 \n",
"L 245.44375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p1c2d904669\">\n",
" <rect height=\"135.9\" width=\"195.3\" x=\"50.14375\" y=\"10.7\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<matplotlib.figure.Figure at 0x11beca518>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"train_sgd(0.05, 10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7.3.3 简洁实现"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字\n",
"# 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={\"lr\": 0.05}\n",
"def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,\n",
" batch_size=10, num_epochs=2):\n",
" # 初始化模型\n",
" net = nn.Sequential(\n",
" nn.Linear(features.shape[-1], 1)\n",
" )\n",
" loss = nn.MSELoss()\n",
" optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)\n",
"\n",
" def eval_loss():\n",
" return loss(net(features).view(-1), labels).item() / 2\n",
"\n",
" ls = [eval_loss()]\n",
" data_iter = torch.utils.data.DataLoader(\n",
" torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)\n",
"\n",
" for _ in range(num_epochs):\n",
" start = time.time()\n",
" for batch_i, (X, y) in enumerate(data_iter):\n",
" # 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2\n",
" l = loss(net(X).view(-1), y) / 2 \n",
" \n",
" optimizer.zero_grad()\n",
" l.backward()\n",
" optimizer.step()\n",
" if (batch_i + 1) * batch_size % 100 == 0:\n",
" ls.append(eval_loss())\n",
" # 打印结果和作图\n",
" print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))\n",
" d2l.set_figsize()\n",
" d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)\n",
" d2l.plt.xlabel('epoch')\n",
" d2l.plt.ylabel('loss')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 0.245491, 0.044150 sec per epoch\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Created with matplotlib (http://matplotlib.org/) -->\n",
"<svg height=\"184pt\" version=\"1.1\" viewBox=\"0 0 256 184\" width=\"256pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
" <defs>\n",
" <style type=\"text/css\">\n",
"*{stroke-linecap:butt;stroke-linejoin:round;}\n",
" </style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 184.15625 \n",
"L 256.14375 184.15625 \n",
"L 256.14375 -0 \n",
"L 0 -0 \n",
"z\n",
"\" style=\"fill:none;\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 245.44375 146.6 \n",
"L 245.44375 10.7 \n",
"L 50.14375 10.7 \n",
"z\n",
"\" style=\"fill:#ffffff;\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" id=\"m4dfe75d77f\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" </defs>\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"59.021023\" xlink:href=\"#m4dfe75d77f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- 0.0 -->\n",
" <defs>\n",
" <path d=\"M 31.78125 66.40625 \n",
"Q 24.171875 66.40625 20.328125 58.90625 \n",
"Q 16.5 51.421875 16.5 36.375 \n",
"Q 16.5 21.390625 20.328125 13.890625 \n",
"Q 24.171875 6.390625 31.78125 6.390625 \n",
"Q 39.453125 6.390625 43.28125 13.890625 \n",
"Q 47.125 21.390625 47.125 36.375 \n",
"Q 47.125 51.421875 43.28125 58.90625 \n",
"Q 39.453125 66.40625 31.78125 66.40625 \n",
"z\n",
"M 31.78125 74.21875 \n",
"Q 44.046875 74.21875 50.515625 64.515625 \n",
"Q 56.984375 54.828125 56.984375 36.375 \n",
"Q 56.984375 17.96875 50.515625 8.265625 \n",
"Q 44.046875 -1.421875 31.78125 -1.421875 \n",
"Q 19.53125 -1.421875 13.0625 8.265625 \n",
"Q 6.59375 17.96875 6.59375 36.375 \n",
"Q 6.59375 54.828125 13.0625 64.515625 \n",
"Q 19.53125 74.21875 31.78125 74.21875 \n",
"z\n",
"\" id=\"DejaVuSans-30\"/>\n",
" <path d=\"M 10.6875 12.40625 \n",
"L 21 12.40625 \n",
"L 21 0 \n",
"L 10.6875 0 \n",
"z\n",
"\" id=\"DejaVuSans-2e\"/>\n",
" </defs>\n",
" <g transform=\"translate(51.06946 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"103.407386\" xlink:href=\"#m4dfe75d77f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- 0.5 -->\n",
" <defs>\n",
" <path d=\"M 10.796875 72.90625 \n",
"L 49.515625 72.90625 \n",
"L 49.515625 64.59375 \n",
"L 19.828125 64.59375 \n",
"L 19.828125 46.734375 \n",
"Q 21.96875 47.46875 24.109375 47.828125 \n",
"Q 26.265625 48.1875 28.421875 48.1875 \n",
"Q 40.625 48.1875 47.75 41.5 \n",
"Q 54.890625 34.8125 54.890625 23.390625 \n",
"Q 54.890625 11.625 47.5625 5.09375 \n",
"Q 40.234375 -1.421875 26.90625 -1.421875 \n",
"Q 22.3125 -1.421875 17.546875 -0.640625 \n",
"Q 12.796875 0.140625 7.71875 1.703125 \n",
"L 7.71875 11.625 \n",
"Q 12.109375 9.234375 16.796875 8.0625 \n",
"Q 21.484375 6.890625 26.703125 6.890625 \n",
"Q 35.15625 6.890625 40.078125 11.328125 \n",
"Q 45.015625 15.765625 45.015625 23.390625 \n",
"Q 45.015625 31 40.078125 35.4375 \n",
"Q 35.15625 39.890625 26.703125 39.890625 \n",
"Q 22.75 39.890625 18.8125 39.015625 \n",
"Q 14.890625 38.140625 10.796875 36.28125 \n",
"z\n",
"\" id=\"DejaVuSans-35\"/>\n",
" </defs>\n",
" <g transform=\"translate(95.455824 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"147.79375\" xlink:href=\"#m4dfe75d77f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 1.0 -->\n",
" <defs>\n",
" <path d=\"M 12.40625 8.296875 \n",
"L 28.515625 8.296875 \n",
"L 28.515625 63.921875 \n",
"L 10.984375 60.40625 \n",
"L 10.984375 69.390625 \n",
"L 28.421875 72.90625 \n",
"L 38.28125 72.90625 \n",
"L 38.28125 8.296875 \n",
"L 54.390625 8.296875 \n",
"L 54.390625 0 \n",
"L 12.40625 0 \n",
"z\n",
"\" id=\"DejaVuSans-31\"/>\n",
" </defs>\n",
" <g transform=\"translate(139.842187 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"192.180114\" xlink:href=\"#m4dfe75d77f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 1.5 -->\n",
" <g transform=\"translate(184.228551 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"236.566477\" xlink:href=\"#m4dfe75d77f\" y=\"146.6\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 2.0 -->\n",
" <defs>\n",
" <path d=\"M 19.1875 8.296875 \n",
"L 53.609375 8.296875 \n",
"L 53.609375 0 \n",
"L 7.328125 0 \n",
"L 7.328125 8.296875 \n",
"Q 12.9375 14.109375 22.625 23.890625 \n",
"Q 32.328125 33.6875 34.8125 36.53125 \n",
"Q 39.546875 41.84375 41.421875 45.53125 \n",
"Q 43.3125 49.21875 43.3125 52.78125 \n",
"Q 43.3125 58.59375 39.234375 62.25 \n",
"Q 35.15625 65.921875 28.609375 65.921875 \n",
"Q 23.96875 65.921875 18.8125 64.3125 \n",
"Q 13.671875 62.703125 7.8125 59.421875 \n",
"L 7.8125 69.390625 \n",
"Q 13.765625 71.78125 18.9375 73 \n",
"Q 24.125 74.21875 28.421875 74.21875 \n",
"Q 39.75 74.21875 46.484375 68.546875 \n",
"Q 53.21875 62.890625 53.21875 53.421875 \n",
"Q 53.21875 48.921875 51.53125 44.890625 \n",
"Q 49.859375 40.875 45.40625 35.40625 \n",
"Q 44.1875 33.984375 37.640625 27.21875 \n",
"Q 31.109375 20.453125 19.1875 8.296875 \n",
"z\n",
"\" id=\"DejaVuSans-32\"/>\n",
" </defs>\n",
" <g transform=\"translate(228.614915 161.198437)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- epoch -->\n",
" <defs>\n",
" <path d=\"M 56.203125 29.59375 \n",
"L 56.203125 25.203125 \n",
"L 14.890625 25.203125 \n",
"Q 15.484375 15.921875 20.484375 11.0625 \n",
"Q 25.484375 6.203125 34.421875 6.203125 \n",
"Q 39.59375 6.203125 44.453125 7.46875 \n",
"Q 49.3125 8.734375 54.109375 11.28125 \n",
"L 54.109375 2.78125 \n",
"Q 49.265625 0.734375 44.1875 -0.34375 \n",
"Q 39.109375 -1.421875 33.890625 -1.421875 \n",
"Q 20.796875 -1.421875 13.15625 6.1875 \n",
"Q 5.515625 13.8125 5.515625 26.8125 \n",
"Q 5.515625 40.234375 12.765625 48.109375 \n",
"Q 20.015625 56 32.328125 56 \n",
"Q 43.359375 56 49.78125 48.890625 \n",
"Q 56.203125 41.796875 56.203125 29.59375 \n",
"z\n",
"M 47.21875 32.234375 \n",
"Q 47.125 39.59375 43.09375 43.984375 \n",
"Q 39.0625 48.390625 32.421875 48.390625 \n",
"Q 24.90625 48.390625 20.390625 44.140625 \n",
"Q 15.875 39.890625 15.1875 32.171875 \n",
"z\n",
"\" id=\"DejaVuSans-65\"/>\n",
" <path d=\"M 18.109375 8.203125 \n",
"L 18.109375 -20.796875 \n",
"L 9.078125 -20.796875 \n",
"L 9.078125 54.6875 \n",
"L 18.109375 54.6875 \n",
"L 18.109375 46.390625 \n",
"Q 20.953125 51.265625 25.265625 53.625 \n",
"Q 29.59375 56 35.59375 56 \n",
"Q 45.5625 56 51.78125 48.09375 \n",
"Q 58.015625 40.1875 58.015625 27.296875 \n",
"Q 58.015625 14.40625 51.78125 6.484375 \n",
"Q 45.5625 -1.421875 35.59375 -1.421875 \n",
"Q 29.59375 -1.421875 25.265625 0.953125 \n",
"Q 20.953125 3.328125 18.109375 8.203125 \n",
"z\n",
"M 48.6875 27.296875 \n",
"Q 48.6875 37.203125 44.609375 42.84375 \n",
"Q 40.53125 48.484375 33.40625 48.484375 \n",
"Q 26.265625 48.484375 22.1875 42.84375 \n",
"Q 18.109375 37.203125 18.109375 27.296875 \n",
"Q 18.109375 17.390625 22.1875 11.75 \n",
"Q 26.265625 6.109375 33.40625 6.109375 \n",
"Q 40.53125 6.109375 44.609375 11.75 \n",
"Q 48.6875 17.390625 48.6875 27.296875 \n",
"z\n",
"\" id=\"DejaVuSans-70\"/>\n",
" <path d=\"M 30.609375 48.390625 \n",
"Q 23.390625 48.390625 19.1875 42.75 \n",
"Q 14.984375 37.109375 14.984375 27.296875 \n",
"Q 14.984375 17.484375 19.15625 11.84375 \n",
"Q 23.34375 6.203125 30.609375 6.203125 \n",
"Q 37.796875 6.203125 41.984375 11.859375 \n",
"Q 46.1875 17.53125 46.1875 27.296875 \n",
"Q 46.1875 37.015625 41.984375 42.703125 \n",
"Q 37.796875 48.390625 30.609375 48.390625 \n",
"z\n",
"M 30.609375 56 \n",
"Q 42.328125 56 49.015625 48.375 \n",
"Q 55.71875 40.765625 55.71875 27.296875 \n",
"Q 55.71875 13.875 49.015625 6.21875 \n",
"Q 42.328125 -1.421875 30.609375 -1.421875 \n",
"Q 18.84375 -1.421875 12.171875 6.21875 \n",
"Q 5.515625 13.875 5.515625 27.296875 \n",
"Q 5.515625 40.765625 12.171875 48.375 \n",
"Q 18.84375 56 30.609375 56 \n",
"z\n",
"\" id=\"DejaVuSans-6f\"/>\n",
" <path d=\"M 48.78125 52.59375 \n",
"L 48.78125 44.1875 \n",
"Q 44.96875 46.296875 41.140625 47.34375 \n",
"Q 37.3125 48.390625 33.40625 48.390625 \n",
"Q 24.65625 48.390625 19.8125 42.84375 \n",
"Q 14.984375 37.3125 14.984375 27.296875 \n",
"Q 14.984375 17.28125 19.8125 11.734375 \n",
"Q 24.65625 6.203125 33.40625 6.203125 \n",
"Q 37.3125 6.203125 41.140625 7.25 \n",
"Q 44.96875 8.296875 48.78125 10.40625 \n",
"L 48.78125 2.09375 \n",
"Q 45.015625 0.34375 40.984375 -0.53125 \n",
"Q 36.96875 -1.421875 32.421875 -1.421875 \n",
"Q 20.0625 -1.421875 12.78125 6.34375 \n",
"Q 5.515625 14.109375 5.515625 27.296875 \n",
"Q 5.515625 40.671875 12.859375 48.328125 \n",
"Q 20.21875 56 33.015625 56 \n",
"Q 37.15625 56 41.109375 55.140625 \n",
"Q 45.0625 54.296875 48.78125 52.59375 \n",
"z\n",
"\" id=\"DejaVuSans-63\"/>\n",
" <path d=\"M 54.890625 33.015625 \n",
"L 54.890625 0 \n",
"L 45.90625 0 \n",
"L 45.90625 32.71875 \n",
"Q 45.90625 40.484375 42.875 44.328125 \n",
"Q 39.84375 48.1875 33.796875 48.1875 \n",
"Q 26.515625 48.1875 22.3125 43.546875 \n",
"Q 18.109375 38.921875 18.109375 30.90625 \n",
"L 18.109375 0 \n",
"L 9.078125 0 \n",
"L 9.078125 75.984375 \n",
"L 18.109375 75.984375 \n",
"L 18.109375 46.1875 \n",
"Q 21.34375 51.125 25.703125 53.5625 \n",
"Q 30.078125 56 35.796875 56 \n",
"Q 45.21875 56 50.046875 50.171875 \n",
"Q 54.890625 44.34375 54.890625 33.015625 \n",
"z\n",
"\" id=\"DejaVuSans-68\"/>\n",
" </defs>\n",
" <g transform=\"translate(132.565625 174.876562)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
" <use x=\"61.523438\" xlink:href=\"#DejaVuSans-70\"/>\n",
" <use x=\"125\" xlink:href=\"#DejaVuSans-6f\"/>\n",
" <use x=\"186.181641\" xlink:href=\"#DejaVuSans-63\"/>\n",
" <use x=\"241.162109\" xlink:href=\"#DejaVuSans-68\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" id=\"m6af6a8b8a6\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
" </defs>\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m6af6a8b8a6\" y=\"136.514517\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0.25 -->\n",
" <g transform=\"translate(20.878125 140.313736)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-32\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m6af6a8b8a6\" y=\"108.678722\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 0.30 -->\n",
" <defs>\n",
" <path d=\"M 40.578125 39.3125 \n",
"Q 47.65625 37.796875 51.625 33 \n",
"Q 55.609375 28.21875 55.609375 21.1875 \n",
"Q 55.609375 10.40625 48.1875 4.484375 \n",
"Q 40.765625 -1.421875 27.09375 -1.421875 \n",
"Q 22.515625 -1.421875 17.65625 -0.515625 \n",
"Q 12.796875 0.390625 7.625 2.203125 \n",
"L 7.625 11.71875 \n",
"Q 11.71875 9.328125 16.59375 8.109375 \n",
"Q 21.484375 6.890625 26.8125 6.890625 \n",
"Q 36.078125 6.890625 40.9375 10.546875 \n",
"Q 45.796875 14.203125 45.796875 21.1875 \n",
"Q 45.796875 27.640625 41.28125 31.265625 \n",
"Q 36.765625 34.90625 28.71875 34.90625 \n",
"L 20.21875 34.90625 \n",
"L 20.21875 43.015625 \n",
"L 29.109375 43.015625 \n",
"Q 36.375 43.015625 40.234375 45.921875 \n",
"Q 44.09375 48.828125 44.09375 54.296875 \n",
"Q 44.09375 59.90625 40.109375 62.90625 \n",
"Q 36.140625 65.921875 28.71875 65.921875 \n",
"Q 24.65625 65.921875 20.015625 65.03125 \n",
"Q 15.375 64.15625 9.8125 62.3125 \n",
"L 9.8125 71.09375 \n",
"Q 15.4375 72.65625 20.34375 73.4375 \n",
"Q 25.25 74.21875 29.59375 74.21875 \n",
"Q 40.828125 74.21875 47.359375 69.109375 \n",
"Q 53.90625 64.015625 53.90625 55.328125 \n",
"Q 53.90625 49.265625 50.4375 45.09375 \n",
"Q 46.96875 40.921875 40.578125 39.3125 \n",
"z\n",
"\" id=\"DejaVuSans-33\"/>\n",
" </defs>\n",
" <g transform=\"translate(20.878125 112.477941)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m6af6a8b8a6\" y=\"80.842927\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 0.35 -->\n",
" <g transform=\"translate(20.878125 84.642146)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m6af6a8b8a6\" y=\"53.007132\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 0.40 -->\n",
" <defs>\n",
" <path d=\"M 37.796875 64.3125 \n",
"L 12.890625 25.390625 \n",
"L 37.796875 25.390625 \n",
"z\n",
"M 35.203125 72.90625 \n",
"L 47.609375 72.90625 \n",
"L 47.609375 25.390625 \n",
"L 58.015625 25.390625 \n",
"L 58.015625 17.1875 \n",
"L 47.609375 17.1875 \n",
"L 47.609375 0 \n",
"L 37.796875 0 \n",
"L 37.796875 17.1875 \n",
"L 4.890625 17.1875 \n",
"L 4.890625 26.703125 \n",
"z\n",
"\" id=\"DejaVuSans-34\"/>\n",
" </defs>\n",
" <g transform=\"translate(20.878125 56.806351)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m6af6a8b8a6\" y=\"25.171337\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 0.45 -->\n",
" <g transform=\"translate(20.878125 28.970556)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
" <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
" <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- loss -->\n",
" <defs>\n",
" <path d=\"M 9.421875 75.984375 \n",
"L 18.40625 75.984375 \n",
"L 18.40625 0 \n",
"L 9.421875 0 \n",
"z\n",
"\" id=\"DejaVuSans-6c\"/>\n",
" <path d=\"M 44.28125 53.078125 \n",
"L 44.28125 44.578125 \n",
"Q 40.484375 46.53125 36.375 47.5 \n",
"Q 32.28125 48.484375 27.875 48.484375 \n",
"Q 21.1875 48.484375 17.84375 46.4375 \n",
"Q 14.5 44.390625 14.5 40.28125 \n",
"Q 14.5 37.15625 16.890625 35.375 \n",
"Q 19.28125 33.59375 26.515625 31.984375 \n",
"L 29.59375 31.296875 \n",
"Q 39.15625 29.25 43.1875 25.515625 \n",
"Q 47.21875 21.78125 47.21875 15.09375 \n",
"Q 47.21875 7.46875 41.1875 3.015625 \n",
"Q 35.15625 -1.421875 24.609375 -1.421875 \n",
"Q 20.21875 -1.421875 15.453125 -0.5625 \n",
"Q 10.6875 0.296875 5.421875 2 \n",
"L 5.421875 11.28125 \n",
"Q 10.40625 8.6875 15.234375 7.390625 \n",
"Q 20.0625 6.109375 24.8125 6.109375 \n",
"Q 31.15625 6.109375 34.5625 8.28125 \n",
"Q 37.984375 10.453125 37.984375 14.40625 \n",
"Q 37.984375 18.0625 35.515625 20.015625 \n",
"Q 33.0625 21.96875 24.703125 23.78125 \n",
"L 21.578125 24.515625 \n",
"Q 13.234375 26.265625 9.515625 29.90625 \n",
"Q 5.8125 33.546875 5.8125 39.890625 \n",
"Q 5.8125 47.609375 11.28125 51.796875 \n",
"Q 16.75 56 26.8125 56 \n",
"Q 31.78125 56 36.171875 55.265625 \n",
"Q 40.578125 54.546875 44.28125 53.078125 \n",
"z\n",
"\" id=\"DejaVuSans-73\"/>\n",
" </defs>\n",
" <g transform=\"translate(14.798437 88.307812)rotate(-90)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use x=\"27.783203\" xlink:href=\"#DejaVuSans-6f\"/>\n",
" <use x=\"88.964844\" xlink:href=\"#DejaVuSans-73\"/>\n",
" <use x=\"141.064453\" xlink:href=\"#DejaVuSans-73\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_11\">\n",
" <path clip-path=\"url(#p6d5caf865e)\" d=\"M 59.021023 16.877273 \n",
"L 64.939205 83.085146 \n",
"L 70.857386 118.704118 \n",
"L 76.775568 126.613417 \n",
"L 82.69375 135.745057 \n",
"L 88.611932 136.593733 \n",
"L 94.530114 137.946963 \n",
"L 100.448295 137.820088 \n",
"L 106.366477 138.782831 \n",
"L 112.284659 137.524777 \n",
"L 118.202841 139.030897 \n",
"L 124.121023 139.480699 \n",
"L 130.039205 139.509626 \n",
"L 135.957386 139.839174 \n",
"L 141.875568 139.912815 \n",
"L 147.79375 139.55293 \n",
"L 153.711932 140.025305 \n",
"L 159.630114 137.388869 \n",
"L 165.548295 139.865521 \n",
"L 171.466477 140.392274 \n",
"L 177.384659 136.315901 \n",
"L 183.302841 138.804292 \n",
"L 189.221023 138.965768 \n",
"L 195.139205 137.946706 \n",
"L 201.057386 137.485663 \n",
"L 206.975568 137.255789 \n",
"L 212.89375 135.824032 \n",
"L 218.811932 139.832902 \n",
"L 224.730114 140.422727 \n",
"L 230.648295 139.62424 \n",
"L 236.566477 139.02451 \n",
"\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 50.14375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 245.44375 146.6 \n",
"L 245.44375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 50.14375 146.6 \n",
"L 245.44375 146.6 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 50.14375 10.7 \n",
"L 245.44375 10.7 \n",
"\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p6d5caf865e\">\n",
" <rect height=\"135.9\" width=\"195.3\" x=\"50.14375\" y=\"10.7\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<matplotlib.figure.Figure at 0x11fc0d358>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"train_pytorch_ch7(optim.SGD, {\"lr\": 0.05}, features, labels, 10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}