You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
270 lines
6.3 KiB
270 lines
6.3 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 4.4 自定义层\n",
|
|
"## 4.4.1 不含模型参数的自定义层"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.4.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch import nn\n",
|
|
"\n",
|
|
"print(torch.__version__)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class CenteredLayer(nn.Module):\n",
|
|
" def __init__(self, **kwargs):\n",
|
|
" super(CenteredLayer, self).__init__(**kwargs)\n",
|
|
" def forward(self, x):\n",
|
|
" return x - x.mean()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([-2., -1., 0., 1., 2.])"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"layer = CenteredLayer()\n",
|
|
"layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.0"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = net(torch.rand(4, 8))\n",
|
|
"y.mean().item()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4.4.2 含模型参数的自定义层"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"MyListDense(\n",
|
|
" (params): ParameterList(\n",
|
|
" (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
|
|
" (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
|
|
" (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
|
|
" (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
|
|
" )\n",
|
|
")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"class MyListDense(nn.Module):\n",
|
|
" def __init__(self):\n",
|
|
" super(MyListDense, self).__init__()\n",
|
|
" self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])\n",
|
|
" self.params.append(nn.Parameter(torch.randn(4, 1)))\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" for i in range(len(self.params)):\n",
|
|
" x = torch.mm(x, self.params[i])\n",
|
|
" return x\n",
|
|
"net = MyListDense()\n",
|
|
"print(net)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"MyDictDense(\n",
|
|
" (params): ParameterDict(\n",
|
|
" (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
|
|
" (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
|
|
" (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
|
|
" )\n",
|
|
")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"class MyDictDense(nn.Module):\n",
|
|
" def __init__(self):\n",
|
|
" super(MyDictDense, self).__init__()\n",
|
|
" self.params = nn.ParameterDict({\n",
|
|
" 'linear1': nn.Parameter(torch.randn(4, 4)),\n",
|
|
" 'linear2': nn.Parameter(torch.randn(4, 1))\n",
|
|
" })\n",
|
|
" self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增\n",
|
|
"\n",
|
|
" def forward(self, x, choice='linear1'):\n",
|
|
" return torch.mm(x, self.params[choice])\n",
|
|
"\n",
|
|
"net = MyDictDense()\n",
|
|
"print(net)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[1.5082, 1.5574, 2.1651, 1.2409]], grad_fn=<MmBackward>)\n",
|
|
"tensor([[-0.8783]], grad_fn=<MmBackward>)\n",
|
|
"tensor([[ 2.2193, -1.6539]], grad_fn=<MmBackward>)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"x = torch.ones(1, 4)\n",
|
|
"print(net(x, 'linear1'))\n",
|
|
"print(net(x, 'linear2'))\n",
|
|
"print(net(x, 'linear3'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Sequential(\n",
|
|
" (0): MyDictDense(\n",
|
|
" (params): ParameterDict(\n",
|
|
" (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
|
|
" (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
|
|
" (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
|
|
" )\n",
|
|
" )\n",
|
|
" (1): MyListDense(\n",
|
|
" (params): ParameterList(\n",
|
|
" (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
|
|
" (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
|
|
" (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
|
|
" (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
|
|
" )\n",
|
|
" )\n",
|
|
")\n",
|
|
"tensor([[-101.2394]], grad_fn=<MmBackward>)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"net = nn.Sequential(\n",
|
|
" MyDictDense(),\n",
|
|
" MyListDense(),\n",
|
|
")\n",
|
|
"print(net)\n",
|
|
"print(net(x))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python [default]",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|