You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

379 lines
8.0 KiB

3 years ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4.2 模型参数的访问、初始化和共享"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.1\n"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn import init\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=4, out_features=3, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=3, out_features=1, bias=True)\n",
")\n"
]
}
],
"source": [
"# PyTorch has already applied its default initialization to these layers.\n",
"net = nn.Sequential(\n",
"    nn.Linear(4, 3),\n",
"    nn.ReLU(),\n",
"    nn.Linear(3, 1),\n",
")\n",
"\n",
"print(net)\n",
"X = torch.rand(2, 4)\n",
"Y = net(X).sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.2.1 访问模型参数"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'generator'>\n",
"0.weight torch.Size([3, 4])\n",
"0.bias torch.Size([3])\n",
"2.weight torch.Size([1, 3])\n",
"2.bias torch.Size([1])\n"
]
}
],
"source": [
"# named_parameters() hands back a generator of (name, parameter) pairs;\n",
"# names are prefixed with the index of the layer inside the Sequential.\n",
"print(type(net.named_parameters()))\n",
"for name, param in net.named_parameters():\n",
"    print(name, param.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>\n",
"bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>\n"
]
}
],
"source": [
"# Indexing the Sequential picks a single layer, so the names carry no layer prefix.\n",
"for name, param in net[0].named_parameters():\n",
"    print(name, param.shape, type(param))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"weight1\n"
]
}
],
"source": [
"class MyModel(nn.Module):\n",
"    # Demo module: only attributes wrapped in nn.Parameter show up in named_parameters().\n",
"    def __init__(self, **kwargs):\n",
"        super(MyModel, self).__init__(**kwargs)\n",
"        # Wrapped in nn.Parameter, so it is listed by named_parameters().\n",
"        self.weight1 = nn.Parameter(torch.rand(20, 20))\n",
"        # Plain tensor attribute: not listed by named_parameters().\n",
"        self.weight2 = torch.rand(20, 20)\n",
"\n",
"    def forward(self, x):\n",
"        # Intentionally empty; this class only illustrates parameter registration.\n",
"        pass\n",
"\n",
"n = MyModel()\n",
"for name, param in n.named_parameters():\n",
"    print(name)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.2719, -0.0898, -0.2462, 0.0655],\n",
" [-0.4669, -0.2703, 0.3230, 0.2067],\n",
" [-0.2708, 0.1171, -0.0995, 0.3913]])\n",
"None\n",
"tensor([[-0.2281, -0.0653, -0.1646, -0.2569],\n",
" [-0.1916, -0.0549, -0.1382, -0.2158],\n",
" [ 0.0000, 0.0000, 0.0000, 0.0000]])\n"
]
}
],
"source": [
"# Take the first parameter yielded by layer 0.\n",
"weight_0 = next(net[0].parameters())\n",
"print(weight_0.data)\n",
"# Printed before any backward pass, so no gradient is stored yet.\n",
"print(weight_0.grad)\n",
"Y.backward()\n",
"# After backpropagating Y the gradient is populated.\n",
"print(weight_0.grad)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.2.2 初始化模型参数"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.weight tensor([[ 0.0030, 0.0094, 0.0070, -0.0010],\n",
" [ 0.0001, 0.0039, 0.0105, -0.0126],\n",
" [ 0.0105, -0.0135, -0.0047, -0.0006]])\n",
"2.weight tensor([[-0.0074, 0.0051, 0.0066]])\n"
]
}
],
"source": [
"# Re-initialize every weight from a normal distribution (mean 0, std 0.01).\n",
"# Parameters whose name does not contain 'weight' are left untouched and not printed.\n",
"for name, param in net.named_parameters():\n",
"    if 'weight' in name:\n",
"        init.normal_(param, mean=0, std=0.01)\n",
"        print(name, param.data)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.bias tensor([0., 0., 0.])\n",
"2.bias tensor([0.])\n"
]
}
],
"source": [
"# Set every bias to the constant 0.\n",
"# Parameters whose name does not contain 'bias' are left untouched and not printed.\n",
"for name, param in net.named_parameters():\n",
"    if 'bias' in name:\n",
"        init.constant_(param, val=0)\n",
"        print(name, param.data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.2.3 自定义初始化方法"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def init_weight_(tensor):\n",
"    # Custom in-place initializer: fill `tensor` with uniform samples from\n",
"    # (-10, 10), then zero every entry whose absolute value is below 5.\n",
"    # Both in-place edits run under no_grad so autograd does not record them.\n",
"    with torch.no_grad():\n",
"        tensor.uniform_(-10, 10)\n",
"        # The comparison yields a 0/1 mask; multiplying keeps only |value| >= 5.\n",
"        tensor *= (tensor.abs() >= 5).float()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.weight tensor([[ 7.0403, 0.0000, -9.4569, 7.0111],\n",
" [-0.0000, -0.0000, 0.0000, 0.0000],\n",
" [ 9.8063, -0.0000, 0.0000, -9.7993]])\n",
"2.weight tensor([[-5.8198, 7.7558, -5.0293]])\n"
]
}
],
"source": [
"# Apply the custom initializer to every weight.\n",
"# Parameters whose name does not contain 'weight' are left untouched and not printed.\n",
"for name, param in net.named_parameters():\n",
"    if 'weight' in name:\n",
"        init_weight_(param)\n",
"        print(name, param.data)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.bias tensor([1., 1., 1.])\n",
"2.bias tensor([1.])\n"
]
}
],
"source": [
"# Add 1 to every bias in place by editing the parameter's .data directly.\n",
"# Parameters whose name does not contain 'bias' are left untouched and not printed.\n",
"for name, param in net.named_parameters():\n",
"    if 'bias' in name:\n",
"        param.data += 1\n",
"        print(name, param.data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.2.4 共享模型参数"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=1, out_features=1, bias=False)\n",
" (1): Linear(in_features=1, out_features=1, bias=False)\n",
")\n",
"0.weight tensor([[3.]])\n"
]
}
],
"source": [
"# The same Linear instance is placed in the Sequential twice.\n",
"linear = nn.Linear(1, 1, bias=False)\n",
"net = nn.Sequential(linear, linear)\n",
"print(net)\n",
"# named_parameters() reports the shared weight once, under the name '0.weight'.\n",
"for name, param in net.named_parameters():\n",
"    init.constant_(param, val=3)\n",
"    print(name, param.data)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n"
]
}
],
"source": [
"# Both positions of the Sequential refer to one and the same layer object...\n",
"print(net[0] is net[1])\n",
"# ...and therefore to one and the same weight Parameter.\n",
"print(net[0].weight is net[1].weight)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(9., grad_fn=<SumBackward0>)\n",
"tensor([[6.]])\n"
]
}
],
"source": [
"x = torch.ones(1, 1)\n",
"# The shared layer (weight 3) is applied twice: y = 3 * 3 * 1 = 9.\n",
"y = net(x).sum()\n",
"print(y)\n",
"y.backward()\n",
"# The gradients from both applications accumulate on the one weight: 3 + 3 = 6.\n",
"print(net[0].weight.grad)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}