{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3.10 多层感知机的简洁实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.1\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import init\n",
    "import numpy as np\n",
    "import sys\n",
    "# Put the parent directory on the import path before importing the local helper\n",
    "# package (presumably that is where d2lzh_pytorch lives -- confirm for your checkout).\n",
    "sys.path.append(\"..\")\n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "# Record the PyTorch version this notebook was executed with.\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.10.1 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Seed the global PyTorch RNG so the random parameter initialisation below is\n",
    "# repeatable after Restart and Run All. Batch shuffling is presumably covered as\n",
    "# well, but that depends on how d2l.load_data_fashion_mnist builds its loaders.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Layer widths: 784 inputs, 256 hidden units, 10 outputs.\n",
    "num_inputs, num_outputs, num_hiddens = 784, 10, 256\n",
    "\n",
    "# d2l.FlattenLayer comes from the helper package; presumably it reshapes each\n",
    "# sample to a flat vector of length num_inputs -- confirm against d2lzh_pytorch.\n",
    "net = nn.Sequential(\n",
    "    d2l.FlattenLayer(),\n",
    "    nn.Linear(num_inputs, num_hiddens),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(num_hiddens, num_outputs),\n",
    ")\n",
    "\n",
    "# Re-initialise every parameter tensor of the network (weights and biases alike)\n",
    "# in place from a normal distribution with mean 0 and standard deviation 0.01.\n",
    "for param in net.parameters():\n",
    "    init.normal_(param, mean=0, std=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.10.2 读取数据并训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.0031, train acc 0.703, test acc 0.757\n",
      "epoch 2, loss 0.0019, train acc 0.824, test acc 0.822\n",
      "epoch 3, loss 0.0016, train acc 0.845, test acc 0.825\n",
      "epoch 4, loss 0.0015, train acc 0.855, test acc 0.811\n",
      "epoch 5, loss 0.0014, train acc 0.865, test acc 0.846\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "loss = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Plain SGD over all parameters of net.\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.5)\n",
    "\n",
    "num_epochs = 5\n",
    "# The two None arguments are positional placeholders; presumably they stand for the\n",
    "# hand-written params / lr path that the optimizer replaces -- confirm against the\n",
    "# signature of d2l.train_ch3.\n",
    "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
"language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }