{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.4 自定义层\n",
    "## 4.4.1 不含模型参数的自定义层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.1\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class CenteredLayer(nn.Module):\n",
    "    \"\"\"Parameter-free layer that subtracts the mean of all elements from its input.\"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        # kwargs are forwarded unchanged to nn.Module.__init__.\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return `x` shifted so that its overall mean is zero.\"\"\"\n",
    "        return x - x.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = CenteredLayer()\n",
    "layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = net(torch.rand(4, 8))\n",
    "y.mean().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.4.2 含模型参数的自定义层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyListDense(\n",
      "  (params): ParameterList(\n",
      "      (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class MyListDense(nn.Module):\n",
    "    \"\"\"Chain of matrix multiplications whose weights are held in an nn.ParameterList.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Three 4x4 weights, then a 4x1 weight appended afterwards to show ParameterList.append.\n",
    "        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for _ in range(3)])\n",
    "        self.params.append(nn.Parameter(torch.randn(4, 1)))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Multiply `x` by every stored weight in order; an (n, 4) input yields an (n, 1) output.\"\"\"\n",
    "        for weight in self.params:\n",
    "            x = torch.mm(x, weight)\n",
    "        return x\n",
    "\n",
    "net = MyListDense()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyDictDense(\n",
      "  (params): ParameterDict(\n",
      "      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class MyDictDense(nn.Module):\n",
    "    \"\"\"Single matrix multiplication whose weight is picked by name from an nn.ParameterDict.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.params = nn.ParameterDict({\n",
    "            'linear1': nn.Parameter(torch.randn(4, 4)),\n",
    "            'linear2': nn.Parameter(torch.randn(4, 1))\n",
    "        })\n",
    "        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})  # add an entry after construction\n",
    "\n",
    "    def forward(self, x, choice='linear1'):\n",
    "        \"\"\"Return `x` multiplied by the weight stored under `choice`, which must be an existing key.\"\"\"\n",
    "        return torch.mm(x, self.params[choice])\n",
    "\n",
    "net = MyDictDense()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.5082, 1.5574, 2.1651, 1.2409]], grad_fn=<MmBackward>)\n",
      "tensor([[-0.8783]], grad_fn=<MmBackward>)\n",
      "tensor([[ 2.2193, -1.6539]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(1, 4)\n",
    "print(net(x, 'linear1'))\n",
    "print(net(x, 'linear2'))\n",
    "print(net(x, 'linear3'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): MyDictDense(\n",
      "    (params): ParameterDict(\n",
      "        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "    )\n",
      "  )\n",
      "  (1): MyListDense(\n",
      "    (params): ParameterList(\n",
      "        (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "    )\n",
      "  )\n",
      ")\n",
      "tensor([[-101.2394]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    MyDictDense(),\n",
    "    MyListDense(),\n",
    ")\n",
    "print(net)\n",
    "print(net(x))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}