You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

225 lines
4.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.3 多输入通道和多输出通道\n",
"## 5.3.1 多输入通道"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.1\n"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def corr2d_multi_in(X, K):\n",
"    \"\"\"Cross-correlate a multi-channel input with a multi-channel kernel.\n",
"\n",
"    Each slice ``X[i]`` is cross-correlated with the matching slice ``K[i]``\n",
"    through ``d2l.corr2d`` and the per-channel results are added together.\n",
"\n",
"    Args:\n",
"        X: input indexed as ``X[channel, row, col]``.\n",
"        K: kernel indexed as ``K[channel, row, col]``; it is read for\n",
"            channels ``0 .. X.shape[0] - 1``.\n",
"\n",
"    Returns:\n",
"        The sum over channels of ``d2l.corr2d(X[i], K[i])``.\n",
"    \"\"\"\n",
"    # Start from channel 0, then accumulate the remaining channels in place.\n",
"    res = d2l.corr2d(X[0, :, :], K[0, :, :])\n",
"    for i in range(1, X.shape[0]):\n",
"        res += d2l.corr2d(X[i, :, :], K[i, :, :])\n",
"    return res"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 56., 72.],\n",
" [104., 120.]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Two 3x3 input channels and one 2x2 kernel slice per input channel.\n",
"X = torch.tensor([\n",
"    [[0, 1, 2], [3, 4, 5], [6, 7, 8]],\n",
"    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],\n",
"])\n",
"K = torch.tensor([\n",
"    [[0, 1], [2, 3]],\n",
"    [[1, 2], [3, 4]],\n",
"])\n",
"\n",
"corr2d_multi_in(X, K)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.3.2 多输出通道"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def corr2d_multi_in_out(X, K):\n",
"    \"\"\"Compute one multi-input cross-correlation per output channel.\n",
"\n",
"    Iterates over dimension 0 of ``K``, applies ``corr2d_multi_in`` to ``X``\n",
"    with each slice, and stacks the results along a new leading dimension.\n",
"    \"\"\"\n",
"    per_output = []\n",
"    for kernel in K:\n",
"        per_output.append(corr2d_multi_in(X, kernel))\n",
"    return torch.stack(per_output)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([3, 2, 2, 2])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Three output channels: the current kernel shifted by 0, 1 and 2.\n",
"K = torch.stack([K + offset for offset in range(3)])\n",
"K.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 56., 72.],\n",
" [104., 120.]],\n",
"\n",
" [[ 76., 100.],\n",
" [148., 172.]],\n",
"\n",
" [[ 96., 128.],\n",
" [192., 224.]]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"corr2d_multi_in_out(X, K)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5.3.3 $1\\times 1$卷积层"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def corr2d_multi_in_out_1x1(X, K):\n",
" c_i, h, w = X.shape\n",
" c_o = K.shape[0]\n",
" X = X.view(c_i, h * w)\n",
" K = K.view(c_o, c_i)\n",
" Y = torch.mm(K, X) # 全连接层的矩阵乘法\n",
" return Y.view(c_o, h, w)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.manual_seed(0)  # fix the RNG so the random inputs are reproducible\n",
"X = torch.rand(3, 3, 3)\n",
"K = torch.rand(2, 3, 1, 1)\n",
"\n",
"# The matmul-based 1x1 version must agree with the generic implementation.\n",
"Y1 = corr2d_multi_in_out_1x1(X, K)\n",
"Y2 = corr2d_multi_in_out(X, K)\n",
"\n",
"(Y1 - Y2).norm().item() < 1e-6"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}