You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
4.9 KiB
255 lines
4.9 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 4.5 读取和存储"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.4.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch import nn\n",
|
|
"\n",
|
|
"print(torch.__version__)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4.5.1 读写`Tensor`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = torch.ones(3)\n",
|
|
"torch.save(x, 'x.pt')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([1., 1., 1.])"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x2 = torch.load('x.pt')\n",
|
|
"x2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = torch.zeros(4)\n",
|
|
"torch.save([x, y], 'xy.pt')\n",
|
|
"xy_list = torch.load('xy.pt')\n",
|
|
"xy_list"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"torch.save({'x': x, 'y': y}, 'xy_dict.pt')\n",
|
|
"xy = torch.load('xy_dict.pt')\n",
|
|
"xy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4.5.2 读写模型\n",
|
|
"### 4.5.2.1 `state_dict`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"OrderedDict([('hidden.weight', tensor([[ 0.1836, -0.1812, -0.1681],\n",
|
|
" [ 0.0406, 0.3061, 0.4599]])),\n",
|
|
" ('hidden.bias', tensor([-0.3384, 0.1910])),\n",
|
|
" ('output.weight', tensor([[0.0380, 0.4919]])),\n",
|
|
" ('output.bias', tensor([0.1451]))])"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"class MLP(nn.Module):\n",
|
|
" def __init__(self):\n",
|
|
" super(MLP, self).__init__()\n",
|
|
" self.hidden = nn.Linear(3, 2)\n",
|
|
" self.act = nn.ReLU()\n",
|
|
" self.output = nn.Linear(2, 1)\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" a = self.act(self.hidden(x))\n",
|
|
" return self.output(a)\n",
|
|
"\n",
|
|
"net = MLP()\n",
|
|
"net.state_dict()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'param_groups': [{'dampening': 0,\n",
|
|
" 'lr': 0.001,\n",
|
|
" 'momentum': 0.9,\n",
|
|
" 'nesterov': False,\n",
|
|
" 'params': [4624483024, 4624484608, 4624484680, 4624484752],\n",
|
|
" 'weight_decay': 0}],\n",
|
|
" 'state': {}}"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
|
|
"optimizer.state_dict()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 4.5.2.2 保存和加载模型"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[1],\n",
|
|
" [1]], dtype=torch.uint8)"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"X = torch.randn(2, 3)\n",
|
|
"Y = net(X)\n",
|
|
"\n",
|
|
"PATH = \"./net.pt\"\n",
|
|
"torch.save(net.state_dict(), PATH)\n",
|
|
"\n",
|
|
"net2 = MLP()\n",
|
|
"net2.load_state_dict(torch.load(PATH))\n",
|
|
"Y2 = net2(X)\n",
|
|
"Y2 == Y"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python [default]",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|