# 5.4 Pooling layers
import torch
from torch import nn

print(torch.__version__)


# 5.4.1 Two-dimensional max pooling and average pooling
def pool2d(X, pool_size, mode='max'):
    """Apply 2-D pooling to a matrix with stride 1 and no padding.

    A window of shape ``pool_size`` is slid over ``X`` one element at a
    time; each output entry is the maximum or the mean of the window that
    starts at the same (row, column) position.

    Args:
        X: tensor indexed as ``X[row, col]``. It is cast to float before
            pooling, so integer inputs are accepted.
        pool_size: ``(p_h, p_w)`` height and width of the pooling window.
        mode: ``'max'`` for max pooling or ``'avg'`` for average pooling.

    Returns:
        A float tensor of shape
        ``(X.shape[0] - p_h + 1, X.shape[1] - p_w + 1)``.

    Raises:
        ValueError: if ``mode`` is neither ``'max'`` nor ``'avg'``.
    """
    # Reject unknown modes up front; previously they fell through both
    # branches and the function silently returned an all-zero tensor.
    if mode not in ('max', 'avg'):
        raise ValueError(
            "mode must be 'max' or 'avg', got {!r}".format(mode))

    X = X.float()
    p_h, p_w = pool_size
    Y = torch.zeros(X.shape[0] - p_h + 1, X.shape[1] - p_w + 1)
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            window = X[i: i + p_h, j: j + p_w]
            Y[i, j] = window.max() if mode == 'max' else window.mean()
    return Y


# Max pooling with a 2x2 window; displays tensor([[4., 5.], [7., 8.]]).
X = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
pool2d(X, (2, 2))

# Average pooling with the same window; displays tensor([[2., 3.], [5., 6.]]).
pool2d(X, (2, 2), 'avg')

# 5.4.2 Padding and stride
# Input of shape (1, 1, 4, 4) holding 0..15, consumed by the nn.MaxPool2d
# examples in the following cells.
X = torch.arange(16, dtype=torch.float).view((1, 1, 4, 4))
X
# The pooling layers below get their own name (`max_pool`) instead of
# rebinding `pool2d`, so the hand-written pool2d function defined earlier
# stays callable when earlier cells are re-run.

# Window of shape (3, 3); the stride defaults to the window size, so the
# 4x4 input yields a single value: tensor([[[[10.]]]]).
max_pool = nn.MaxPool2d(3)
max_pool(X)

# Padding and stride set explicitly; displays
# tensor([[[[ 5.,  7.], [13., 15.]]]]).
max_pool = nn.MaxPool2d(3, padding=1, stride=2)
max_pool(X)

# Non-square window with per-dimension padding and stride; displays
# tensor([[[[ 1.,  3.], [ 9., 11.], [13., 15.]]]]).
max_pool = nn.MaxPool2d((2, 4), padding=(1, 2), stride=(2, 3))
max_pool(X)

# 5.4.3 Multiple channels
# Stack X and X + 1 along the channel dimension (dim=1). The result is bound
# to a new name so that re-running this cell does not keep adding channels.
X_multi = torch.cat((X, X + 1), dim=1)
X_multi

# Each channel is pooled independently, so the output keeps both channels:
# tensor([[[[ 5.,  7.], [13., 15.]], [[ 6.,  8.], [14., 16.]]]]).
max_pool = nn.MaxPool2d(3, padding=1, stride=2)
max_pool(X_multi)
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }