{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 7.1 优化与深度学习\n",
"## 7.1.2 优化在深度学习中的挑战"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"from mpl_toolkits import mplot3d # 三维画图\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
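"### 7.1.2.1 局部最小值\n",
"\n",
"函数 $f(x) = x \\cos(\\pi x)$ 在 $x \\approx -0.3$ 附近有一个局部最小值，在 $x \\approx 1.1$ 附近有（所画区间 $[-1, 2]$ 内的）全局最小值。在局部最小值附近梯度接近零，优化算法可能停留在那里，而找不到全局最小值。"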
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def f(x):\n",
"    \"\"\"Return x * cos(pi * x); works element-wise on NumPy arrays.\"\"\"\n",
"    return x * np.cos(np.pi * x)\n",
"\n",
"d2l.set_figsize((4.5, 2.5))\n",
"x = np.arange(-1.0, 2.0, 0.1)\n",
"# plt.plot returns a list of Line2D objects; unpack the single curve.\n",
"line, = d2l.plt.plot(x, f(x))\n",
"ax = line.axes\n",
"# Put each arrow tip on the curve itself via f(), instead of hard-coded\n",
"# y-values that left the arrowheads floating off the plotted line.\n",
"ax.annotate('local minimum', xy=(-0.3, f(-0.3)), xytext=(-0.77, -1.0),\n",
"            arrowprops=dict(arrowstyle='->'))\n",
"ax.annotate('global minimum', xy=(1.1, f(1.1)), xytext=(0.6, 0.8),\n",
"            arrowprops=dict(arrowstyle='->'))\n",
"ax.set_xlabel('x')\n",
"ax.set_ylabel('f(x)');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
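"### 7.1.2.2 鞍点\n",
"\n",
"梯度为零的点不一定是极值点。例如 $f(x) = x^3$ 在 $x = 0$ 处梯度为零，但该点既不是极小值点也不是极大值点。二维函数 $f(x, y) = x^2 - y^2$ 在 $(0, 0)$ 处是鞍点：沿 $x$ 方向是极小值，沿 $y$ 方向是极大值。"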
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# f(x) = x^3 has zero gradient at x = 0, yet x = 0 is neither a minimum nor a maximum.\n",
"x = np.arange(-2.0, 2.0, 0.1)\n",
"cubic_line, = d2l.plt.plot(x, x**3)  # plot() returns a one-element list of Line2D\n",
"cubic_axes = cubic_line.axes\n",
"cubic_axes.annotate('saddle point', xy=(0, -0.2), xytext=(-0.52, -5.0),\n",
"                    arrowprops=dict(arrowstyle='->'))\n",
"cubic_axes.set_xlabel('x')\n",
"cubic_axes.set_ylabel('f(x)');"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# f(x, y) = x^2 - y^2: the origin is a minimum along x but a maximum along y.\n",
"x, y = np.mgrid[-1: 1: 31j, -1: 1: 31j]\n",
"z = x**2 - y**2\n",
"\n",
"# projection='3d' relies on the mplot3d import in the first code cell.\n",
"ax = d2l.plt.figure().add_subplot(111, projection='3d')\n",
"ax.plot_wireframe(x, y, z, rstride=2, cstride=2)\n",
"ax.plot([0], [0], [0], 'rx')  # mark the saddle point at the origin\n",
"ticks = [-1, 0, 1]\n",
"ax.set_xticks(ticks)\n",
"ax.set_yticks(ticks)\n",
"ax.set_zticks(ticks)\n",
"ax.set_xlabel('x')\n",
"ax.set_ylabel('y')\n",
"ax.set_zlabel('f(x, y)');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}