You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

225 lines
4.4 KiB

3 years ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.3 多输入通道和多输出通道\n",
"## 5.3.1 多输入通道"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.1\n"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def corr2d_multi_in(X, K):\n",
" # 沿着X和K的第0维通道维分别计算再相加\n",
" res = d2l.corr2d(X[0, :, :], K[0, :, :])\n",
" for i in range(1, X.shape[0]):\n",
" res += d2l.corr2d(X[i, :, :], K[i, :, :])\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 56., 72.],\n",
" [104., 120.]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],\n",
" [[1, 2, 3], [4, 5, 6], [7, 8, 9]]])\n",
"K = torch.tensor([[[0, 1], [2, 3]], [[1, 2], [3, 4]]])\n",
"\n",
"corr2d_multi_in(X, K)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.3.2 多输出通道"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def corr2d_multi_in_out(X, K):\n",
" # 对K的第0维遍历每次同输入X做互相关计算。所有结果使用stack函数合并在一起\n",
" return torch.stack([corr2d_multi_in(X, k) for k in K])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([3, 2, 2, 2])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"K = torch.stack([K, K + 1, K + 2])\n",
"K.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 56., 72.],\n",
" [104., 120.]],\n",
"\n",
" [[ 76., 100.],\n",
" [148., 172.]],\n",
"\n",
" [[ 96., 128.],\n",
" [192., 224.]]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"corr2d_multi_in_out(X, K)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.3.3 $1\\times 1$卷积层"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def corr2d_multi_in_out_1x1(X, K):\n",
" c_i, h, w = X.shape\n",
" c_o = K.shape[0]\n",
" X = X.view(c_i, h * w)\n",
" K = K.view(c_o, c_i)\n",
" Y = torch.mm(K, X) # 全连接层的矩阵乘法\n",
" return Y.view(c_o, h, w)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.rand(3, 3, 3)\n",
"K = torch.rand(2, 3, 1, 1)\n",
"\n",
"Y1 = corr2d_multi_in_out_1x1(X, K)\n",
"Y2 = corr2d_multi_in_out(X, K)\n",
"\n",
"(Y1 - Y2).norm().item() < 1e-6"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}