{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.2.0\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def dropout(X, drop_prob):\n",
" X = X.float()\n",
" assert 0 <= drop_prob <= 1\n",
" keep_prob = 1 - drop_prob\n",
" # 这种情况下把全部元素都丢弃\n",
" if keep_prob == 0:\n",
" return torch.zeros_like(X)\n",
" mask = (torch.rand(X.shape) < keep_prob).float()\n",
" \n",
" return mask * X / keep_prob"
]
},
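{
"cell_type": "markdown",
"metadata": {},
"source": [
"A brief aside on the scaling in `dropout` above (an added note; the formula just restates what the code does): it uses the \"inverted dropout\" convention. With drop probability $p$, each element $h$ is zeroed with probability $p$ and otherwise divided by $1-p$, so its expectation is unchanged,\n",
"\n",
"$$E[h'] = (1-p)\\cdot\\frac{h}{1-p} + p\\cdot 0 = h,$$\n",
"\n",
"which is why no additional rescaling is needed at evaluation time."
]
},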
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 1., 2., 3., 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11., 12., 13., 14., 15.]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.arange(16).view(2, 8)\n",
"dropout(X, 0)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 0., 4., 6., 0., 0., 12., 14.],\n",
" [ 0., 18., 20., 22., 0., 0., 28., 0.]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dropout(X, 0.5)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dropout(X, 1.0)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256\n",
"\n",
"W1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens1)), dtype=torch.float, requires_grad=True)\n",
"b1 = torch.zeros(num_hiddens1, requires_grad=True)\n",
"W2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens1, num_hiddens2)), dtype=torch.float, requires_grad=True)\n",
"b2 = torch.zeros(num_hiddens2, requires_grad=True)\n",
"W3 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens2, num_outputs)), dtype=torch.float, requires_grad=True)\n",
"b3 = torch.zeros(num_outputs, requires_grad=True)\n",
"\n",
"params = [W1, b1, W2, b2, W3, b3]"
]
},
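{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick optional check (not part of the original flow), the three layers above hold $784\\times 256 + 256 + 256\\times 256 + 256 + 256\\times 10 + 10 = 269322$ parameters, which we can confirm directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# total number of trainable parameters; should print 269322\n",
"sum(p.numel() for p in params)"
]
},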
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"drop_prob1, drop_prob2 = 0.2, 0.5\n",
"\n",
"def net(X, is_training=True):\n",
" X = X.view(-1, num_inputs)\n",
" H1 = (torch.matmul(X, W1) + b1).relu()\n",
" if is_training: # 只在训练模型时使用丢弃法\n",
" H1 = dropout(H1, drop_prob1) # 在第一层全连接后添加丢弃层\n",
" H2 = (torch.matmul(H1, W2) + b2).relu()\n",
" if is_training:\n",
" H2 = dropout(H2, drop_prob2) # 在第二层全连接后添加丢弃层\n",
" return torch.matmul(H2, W3) + b3"
]
},
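{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small optional sanity check (an addition, not part of the original training flow), we can pass a dummy batch through `net` in both modes and confirm that the output shape is `(batch_size, num_outputs)` either way:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_check = torch.zeros(2, num_inputs)  # a dummy batch of 2 flattened images\n",
"print(net(X_check).shape)                     # training mode: dropout active\n",
"print(net(X_check, is_training=False).shape)  # evaluation mode: dropout skipped"
]
},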
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# def evaluate_accuracy(data_iter, net):\n",
"# acc_sum, n = 0.0, 0\n",
"# for X, y in data_iter:\n",
"# if isinstance(net, torch.nn.Module):\n",
"# net.eval() # 评估模式, 这会关闭dropout\n",
"# acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()\n",
"# net.train() # 改回训练模式\n",
"# else: # 自定义的模型\n",
"# if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数\n",
"# # 将is_training设置成False\n",
"# acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() \n",
"# else:\n",
"# acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() \n",
"# n += y.shape[0]\n",
"# return acc_sum / n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, loss 0.0045, train acc 0.561, test acc 0.662\n",
"epoch 2, loss 0.0023, train acc 0.783, test acc 0.786\n",
"epoch 3, loss 0.0019, train acc 0.823, test acc 0.773\n",
"epoch 4, loss 0.0017, train acc 0.838, test acc 0.847\n",
"epoch 5, loss 0.0016, train acc 0.848, test acc 0.809\n"
]
}
],
"source": [
"num_epochs, lr, batch_size = 5, 100.0, 256 # 这里的学习率设置的很大原因同3.9.6节。\n",
"loss = torch.nn.CrossEntropyLoss()\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
"d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"net = nn.Sequential(\n",
" d2l.FlattenLayer(),\n",
" nn.Linear(num_inputs, num_hiddens1),\n",
" nn.ReLU(),\n",
" nn.Dropout(drop_prob1),\n",
" nn.Linear(num_hiddens1, num_hiddens2), \n",
" nn.ReLU(),\n",
" nn.Dropout(drop_prob2),\n",
" nn.Linear(num_hiddens2, 10)\n",
" )\n",
"\n",
"for param in net.parameters():\n",
" nn.init.normal_(param, mean=0, std=0.01)"
]
},
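{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unlike our hand-written `net`, the `nn.Dropout` layers here are toggled automatically: `net.eval()` disables them and `net.train()` re-enables them, which is the standard `nn.Module` behaviour that the `evaluate_accuracy` helper shown above relies on. A minimal illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net.eval()           # dropout layers act as identity mappings now\n",
"print(net.training)  # False\n",
"net.train()          # switch back so the training run below behaves as intended\n",
"print(net.training)  # True"
]
},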
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, loss 0.0048, train acc 0.526, test acc 0.743\n",
"epoch 2, loss 0.0023, train acc 0.779, test acc 0.764\n",
"epoch 3, loss 0.0020, train acc 0.815, test acc 0.819\n",
"epoch 4, loss 0.0018, train acc 0.836, test acc 0.814\n",
"epoch 5, loss 0.0016, train acc 0.848, test acc 0.842\n"
]
}
],
"source": [
"optimizer = torch.optim.SGD(net.parameters(), lr=0.5)\n",
"d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}