{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.5 读取和存储"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.1\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Show the installed PyTorch version.\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.5.1 读写`Tensor`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Serialize a single tensor to 'x.pt' (path is relative to the working directory).\n",
    "x = torch.ones(3)\n",
    "torch.save(x, 'x.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1.])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the tensor back from disk and display it.\n",
    "x2 = torch.load('x.pt')\n",
    "x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A list of tensors can be saved in one call; loading returns the list.\n",
    "y = torch.zeros(4)\n",
    "torch.save([x, y], 'xy.pt')\n",
    "xy_list = torch.load('xy.pt')\n",
    "xy_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A dict mapping names to tensors round-trips the same way.\n",
    "torch.save({'x': x, 'y': y}, 'xy_dict.pt')\n",
    "xy = torch.load('xy_dict.pt')\n",
    "xy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.5.2 读写模型\n",
    "### 4.5.2.1 `state_dict`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('hidden.weight', tensor([[ 0.1836, -0.1812, -0.1681],\n",
       " [ 0.0406, 0.3061, 0.4599]])),\n",
       " ('hidden.bias', tensor([-0.3384, 0.1910])),\n",
       " ('output.weight', tensor([[0.0380, 0.4919]])),\n",
       " ('output.bias', tensor([0.1451]))])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Two-layer perceptron: Linear(3, 2) -> ReLU -> Linear(2, 1).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Zero-argument super() is the Python 3 spelling of super(MLP, self).\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(3, 2)\n",
    "        self.act = nn.ReLU()\n",
    "        self.output = nn.Linear(2, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the hidden layer and ReLU to x, then the output layer.\"\"\"\n",
    "        a = self.act(self.hidden(x))\n",
    "        return self.output(a)\n",
    "\n",
    "net = MLP()\n",
    "# The displayed state_dict has entries for `hidden` and `output` only; `act` adds none.\n",
    "net.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'param_groups': [{'dampening': 0,\n",
       " 'lr': 0.001,\n",
       " 'momentum': 0.9,\n",
       " 'nesterov': False,\n",
       " 'params': [4624483024, 4624484608, 4624484680, 4624484752],\n",
       " 'weight_decay': 0}],\n",
       " 'state': {}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The optimizer exposes its own state_dict: hyperparameter groups plus a `state` entry.\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "optimizer.state_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.5.2.2 保存和加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1],\n",
       " [1]], dtype=torch.uint8)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.randn(2, 3)\n",
    "Y = net(X)\n",
    "\n",
    "# Save the parameters (state_dict) rather than the module object itself.\n",
    "PATH = \"./net.pt\"\n",
    "torch.save(net.state_dict(), PATH)\n",
    "\n",
    "# Build a second MLP and copy the saved parameters into it.\n",
    "net2 = MLP()\n",
    "net2.load_state_dict(torch.load(PATH))\n",
    "Y2 = net2(X)\n",
    "# Element-wise comparison of the two networks' outputs on the same input.\n",
    "Y2 == Y"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}