You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

291 lines
5.5 KiB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.4 池化层"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.1\n"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"# Show the installed PyTorch version so the recorded outputs can be traced to it.\n",
"print(torch.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.4.1 二维最大池化层和平均池化层"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def pool2d(X, pool_size, mode='max'):\n",
"    \"\"\"Pool a 2-D tensor with a sliding max or average window.\n",
"\n",
"    The window moves one element at a time with no padding, so the result\n",
"    has shape (X.shape[0] - p_h + 1, X.shape[1] - p_w + 1).\n",
"\n",
"    Args:\n",
"        X: 2-D tensor; converted to float before pooling.\n",
"        pool_size: (p_h, p_w) height and width of the pooling window.\n",
"        mode: 'max' for max pooling or 'avg' for average pooling.\n",
"\n",
"    Returns:\n",
"        A float tensor with one pooled value per window position.\n",
"\n",
"    Raises:\n",
"        ValueError: if mode is neither 'max' nor 'avg'.\n",
"    \"\"\"\n",
"    if mode not in ('max', 'avg'):\n",
"        # Reject unknown modes up front; otherwise Y would come back all zeros.\n",
"        raise ValueError('unsupported mode: {!r}'.format(mode))\n",
"    X = X.float()\n",
"    p_h, p_w = pool_size\n",
"    Y = torch.zeros(X.shape[0] - p_h + 1, X.shape[1] - p_w + 1)\n",
"    for i in range(Y.shape[0]):\n",
"        for j in range(Y.shape[1]):\n",
"            window = X[i: i + p_h, j: j + p_w]\n",
"            Y[i, j] = window.max() if mode == 'max' else window.mean()\n",
"    return Y"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[4., 5.],\n",
" [7., 8.]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 3x3 example input, pooled with a 2x2 max window.\n",
"X = torch.tensor([\n",
"    [0, 1, 2],\n",
"    [3, 4, 5],\n",
"    [6, 7, 8],\n",
"])\n",
"pool2d(X, pool_size=(2, 2))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[2., 3.],\n",
" [5., 6.]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same input, but each 2x2 window is averaged instead of maximised.\n",
"pool2d(X, (2, 2), mode='avg')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.4.2 填充和步幅"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]]]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Values 0..15 laid out as one example with one channel of size 4x4.\n",
"X = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[10.]]]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Use a separate name so this layer does not rebind the hand-written pool2d function.\n",
"# The stride defaults to the window size, so a 3x3 window fits the 4x4 input only once.\n",
"max_pool = nn.MaxPool2d(3)\n",
"max_pool(X)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 5., 7.],\n",
" [13., 15.]]]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Use a separate name so this layer does not rebind the hand-written pool2d function.\n",
"# Pad one row/column on every side and move the 3x3 window two steps at a time.\n",
"max_pool = nn.MaxPool2d(3, padding=1, stride=2)\n",
"max_pool(X)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 1., 3.],\n",
" [ 9., 11.],\n",
" [13., 15.]]]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Use a separate name so this layer does not rebind the hand-written pool2d function.\n",
"# Window size, padding and stride are each given as a (height, width) pair.\n",
"max_pool = nn.MaxPool2d((2, 4), padding=(1, 2), stride=(2, 3))\n",
"max_pool(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.4.3 多通道"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 0., 1., 2., 3.],\n",
" [ 4., 5., 6., 7.],\n",
" [ 8., 9., 10., 11.],\n",
" [12., 13., 14., 15.]],\n",
"\n",
" [[ 1., 2., 3., 4.],\n",
" [ 5., 6., 7., 8.],\n",
" [ 9., 10., 11., 12.],\n",
" [13., 14., 15., 16.]]]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Rebuild the single-channel input here so that re-running this cell cannot\n",
"# keep stacking extra channels onto an already concatenated X.\n",
"X_single = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)\n",
"X = torch.cat((X_single, X_single + 1), dim=1)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 5., 7.],\n",
" [13., 15.]],\n",
"\n",
" [[ 6., 8.],\n",
" [14., 16.]]]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Use a separate name so this layer does not rebind the hand-written pool2d function.\n",
"# Each channel is pooled on its own, so the output keeps the input's channel count.\n",
"max_pool = nn.MaxPool2d(3, padding=1, stride=2)\n",
"max_pool(X)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}