You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

130 lines
2.8 KiB

3 years ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.10 多层感知机的简洁实现"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.1\n"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn import init\n",
"import numpy as np\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.10.1 定义模型"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Layer widths used to size the two Linear layers below.\n",
"num_inputs = 784\n",
"num_outputs = 10\n",
"num_hiddens = 256\n",
"\n",
"# Stack: d2l.FlattenLayer -> Linear -> ReLU -> Linear.\n",
"net = nn.Sequential(\n",
"    d2l.FlattenLayer(),\n",
"    nn.Linear(num_inputs, num_hiddens),\n",
"    nn.ReLU(),\n",
"    nn.Linear(num_hiddens, num_outputs),\n",
")\n",
"\n",
"# Re-draw every tensor yielded by net.parameters() from a normal\n",
"# distribution with mean 0 and standard deviation 0.01.\n",
"for params in net.parameters():\n",
"    init.normal_(params, mean=0, std=0.01)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.10.2 读取数据并训练模型"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, loss 0.0031, train acc 0.703, test acc 0.757\n",
"epoch 2, loss 0.0019, train acc 0.824, test acc 0.822\n",
"epoch 3, loss 0.0016, train acc 0.845, test acc 0.825\n",
"epoch 4, loss 0.0015, train acc 0.855, test acc 0.811\n",
"epoch 5, loss 0.0014, train acc 0.865, test acc 0.846\n"
]
}
],
"source": [
"# Run settings: mini-batch size and epoch count.\n",
"batch_size, num_epochs = 256, 5\n",
"\n",
"# Train/test iterators come from the d2l helper package\n",
"# (presumably Fashion-MNIST, going by the helper's name).\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
"\n",
"# Cross-entropy objective, minimised with plain SGD over net's parameters.\n",
"loss = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.SGD(net.parameters(), lr=0.5)\n",
"\n",
"# The seventh and eighth positional arguments are passed as None;\n",
"# the optimizer built above is supplied as the ninth.\n",
"d2l.train_ch3(\n",
"    net, train_iter, test_iter, loss, num_epochs,\n",
"    batch_size, None, None, optimizer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}