You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
379 lines
8.0 KiB
379 lines
8.0 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 4.2 模型参数的访问、初始化和共享"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.4.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch import nn\n",
|
|
"from torch.nn import init\n",
|
|
"\n",
|
|
"print(torch.__version__)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Sequential(\n",
|
|
" (0): Linear(in_features=4, out_features=3, bias=True)\n",
|
|
" (1): ReLU()\n",
|
|
" (2): Linear(in_features=3, out_features=1, bias=True)\n",
|
|
")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) # pytorch已进行默认初始化\n",
|
|
"\n",
|
|
"print(net)\n",
|
|
"X = torch.rand(2, 4)\n",
|
|
"Y = net(X).sum()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4.2.1 访问模型参数"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"<class 'generator'>\n",
|
|
"0.weight torch.Size([3, 4])\n",
|
|
"0.bias torch.Size([3])\n",
|
|
"2.weight torch.Size([1, 3])\n",
|
|
"2.bias torch.Size([1])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(type(net.named_parameters()))\n",
|
|
"for name, param in net.named_parameters():\n",
|
|
"    print(name, param.size())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>\n",
|
|
"bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for name, param in net[0].named_parameters():\n",
|
|
"    print(name, param.size(), type(param))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"weight1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"class MyModel(nn.Module):\n",
|
|
"    def __init__(self, **kwargs):\n",
|
|
"        super(MyModel, self).__init__(**kwargs)\n",
|
|
"        self.weight1 = nn.Parameter(torch.rand(20, 20))\n",
|
|
"        self.weight2 = torch.rand(20, 20)\n",
|
|
"    def forward(self, x):\n",
|
|
"        pass\n",
|
|
"    \n",
|
|
"n = MyModel()\n",
|
|
"for name, param in n.named_parameters():\n",
|
|
"    print(name)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[ 0.2719, -0.0898, -0.2462, 0.0655],\n",
|
|
" [-0.4669, -0.2703, 0.3230, 0.2067],\n",
|
|
" [-0.2708, 0.1171, -0.0995, 0.3913]])\n",
|
|
"None\n",
|
|
"tensor([[-0.2281, -0.0653, -0.1646, -0.2569],\n",
|
|
" [-0.1916, -0.0549, -0.1382, -0.2158],\n",
|
|
" [ 0.0000, 0.0000, 0.0000, 0.0000]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"weight_0 = list(net[0].parameters())[0]\n",
|
|
"print(weight_0.data)\n",
|
|
"print(weight_0.grad)\n",
|
|
"Y.backward()\n",
|
|
"print(weight_0.grad)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4.2.2 初始化模型参数"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.weight tensor([[ 0.0030, 0.0094, 0.0070, -0.0010],\n",
|
|
" [ 0.0001, 0.0039, 0.0105, -0.0126],\n",
|
|
" [ 0.0105, -0.0135, -0.0047, -0.0006]])\n",
|
|
"2.weight tensor([[-0.0074, 0.0051, 0.0066]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for name, param in net.named_parameters():\n",
|
|
"    if 'weight' in name:\n",
|
|
"        init.normal_(param, mean=0, std=0.01)\n",
|
|
"        print(name, param.data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.bias tensor([0., 0., 0.])\n",
|
|
"2.bias tensor([0.])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for name, param in net.named_parameters():\n",
|
|
"    if 'bias' in name:\n",
|
|
"        init.constant_(param, val=0)\n",
|
|
"        print(name, param.data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4.2.3 自定义初始化方法"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def init_weight_(tensor):\n",
|
|
"    with torch.no_grad():\n",
|
|
"        tensor.uniform_(-10, 10)\n",
|
|
"        tensor *= (tensor.abs() >= 5).float()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.weight tensor([[ 7.0403, 0.0000, -9.4569, 7.0111],\n",
|
|
" [-0.0000, -0.0000, 0.0000, 0.0000],\n",
|
|
" [ 9.8063, -0.0000, 0.0000, -9.7993]])\n",
|
|
"2.weight tensor([[-5.8198, 7.7558, -5.0293]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for name, param in net.named_parameters():\n",
|
|
"    if 'weight' in name:\n",
|
|
"        init_weight_(param)\n",
|
|
"        print(name, param.data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.bias tensor([1., 1., 1.])\n",
|
|
"2.bias tensor([1.])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for name, param in net.named_parameters():\n",
|
|
"    if 'bias' in name:\n",
|
|
"        param.data += 1\n",
|
|
"        print(name, param.data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4.2.4 共享模型参数"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Sequential(\n",
|
|
" (0): Linear(in_features=1, out_features=1, bias=False)\n",
|
|
" (1): Linear(in_features=1, out_features=1, bias=False)\n",
|
|
")\n",
|
|
"0.weight tensor([[3.]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"linear = nn.Linear(1, 1, bias=False)\n",
|
|
"net = nn.Sequential(linear, linear) \n",
|
|
"print(net)\n",
|
|
"for name, param in net.named_parameters():\n",
|
|
"    init.constant_(param, val=3)\n",
|
|
"    print(name, param.data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"True\n",
|
|
"True\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(id(net[0]) == id(net[1]))\n",
|
|
"print(id(net[0].weight) == id(net[1].weight))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor(9., grad_fn=<SumBackward0>)\n",
|
|
"tensor([[6.]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"x = torch.ones(1, 1)\n",
|
|
"y = net(x).sum()\n",
|
|
"print(y)\n",
|
|
"y.backward()\n",
|
|
"print(net[0].weight.grad)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python [default]",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|