{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 4.2 模型参数的访问、初始化和共享" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4.1\n" ] } ], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import init\n", "\n", "print(torch.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): Linear(in_features=4, out_features=3, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=3, out_features=1, bias=True)\n", ")\n" ] } ], "source": [ "net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) # pytorch已进行默认初始化\n", "\n", "print(net)\n", "X = torch.rand(2, 4)\n", "Y = net(X).sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.1 访问模型参数" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "0.weight torch.Size([3, 4])\n", "0.bias torch.Size([3])\n", "2.weight torch.Size([1, 3])\n", "2.bias torch.Size([1])\n" ] } ], "source": [ "print(type(net.named_parameters()))\n", "for name, param in net.named_parameters():\n", " print(name, param.size())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "weight torch.Size([3, 4]) \n", "bias torch.Size([3]) \n" ] } ], "source": [ "for name, param in net[0].named_parameters():\n", " print(name, param.size(), type(param))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "weight1\n" ] } ], "source": [ "class MyModel(nn.Module):\n", " def __init__(self, **kwargs):\n", " super(MyModel, self).__init__(**kwargs)\n", " self.weight1 = nn.Parameter(torch.rand(20, 20))\n", " self.weight2 = torch.rand(20, 20)\n", " def forward(self, x):\n", " pass\n", " \n", "n = MyModel()\n", "for name, param in n.named_parameters():\n", " print(name)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.2719, -0.0898, -0.2462, 0.0655],\n", " [-0.4669, -0.2703, 0.3230, 0.2067],\n", " [-0.2708, 0.1171, -0.0995, 0.3913]])\n", "None\n", "tensor([[-0.2281, -0.0653, -0.1646, -0.2569],\n", " [-0.1916, -0.0549, -0.1382, -0.2158],\n", " [ 0.0000, 0.0000, 0.0000, 0.0000]])\n" ] } ], "source": [ "weight_0 = list(net[0].parameters())[0]\n", "print(weight_0.data)\n", "print(weight_0.grad)\n", "Y.backward()\n", "print(weight_0.grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.2 初始化模型参数" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.weight tensor([[ 0.0030, 0.0094, 0.0070, -0.0010],\n", " [ 0.0001, 0.0039, 0.0105, -0.0126],\n", " [ 0.0105, -0.0135, -0.0047, -0.0006]])\n", "2.weight tensor([[-0.0074, 0.0051, 0.0066]])\n" ] } ], "source": [ "for name, param in net.named_parameters():\n", " if 'weight' in name:\n", " init.normal_(param, mean=0, std=0.01)\n", " print(name, param.data)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.bias tensor([0., 0., 0.])\n", "2.bias tensor([0.])\n" ] } ], "source": [ "for name, param in net.named_parameters():\n", " if 'bias' in name:\n", " init.constant_(param, val=0)\n", " print(name, param.data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.3 自定义初始化方法" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def init_weight_(tensor):\n", " with torch.no_grad():\n", " tensor.uniform_(-10, 10)\n", " tensor *= (tensor.abs() >= 5).float()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.weight tensor([[ 7.0403, 0.0000, -9.4569, 7.0111],\n", " [-0.0000, -0.0000, 0.0000, 0.0000],\n", " [ 9.8063, -0.0000, 0.0000, -9.7993]])\n", "2.weight tensor([[-5.8198, 7.7558, -5.0293]])\n" ] } ], "source": [ "for name, param in net.named_parameters():\n", " if 'weight' in name:\n", " init_weight_(param)\n", " print(name, param.data)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.bias tensor([1., 1., 1.])\n", "2.bias tensor([1.])\n" ] } ], "source": [ "for name, param in net.named_parameters():\n", " if 'bias' in name:\n", " param.data += 1\n", " print(name, param.data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4.2.4 共享模型参数" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): Linear(in_features=1, out_features=1, bias=False)\n", " (1): Linear(in_features=1, out_features=1, bias=False)\n", ")\n", "0.weight tensor([[3.]])\n" ] } ], "source": [ "linear = nn.Linear(1, 1, bias=False)\n", "net = nn.Sequential(linear, linear) \n", "print(net)\n", "for name, param in net.named_parameters():\n", " init.constant_(param, val=3)\n", " print(name, param.data)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n" ] } ], "source": [ "print(id(net[0]) == id(net[1]))\n", "print(id(net[0].weight) == id(net[1].weight))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(9., grad_fn=)\n", "tensor([[6.]])\n" ] } ], "source": [ "x = torch.ones(1, 1)\n", "y = net(x).sum()\n", "print(y)\n", "y.backward()\n", "print(net[0].weight.grad)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }