You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
193 lines
4.1 KiB
193 lines
4.1 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 8.3 自动并行计算"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2019-05-10T16:16:41.669018Z",
|
|
"start_time": "2019-05-10T16:16:36.457355Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import time\n",
|
|
"\n",
|
|
"assert torch.cuda.device_count() >= 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2019-05-10T16:17:29.013953Z",
|
|
"start_time": "2019-05-10T16:16:41.673871Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"x_gpu1 = torch.rand(size=(100, 100), device='cuda:0')\n",
|
|
"x_gpu2 = torch.rand(size=(100, 100), device='cuda:2')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2019-05-10T16:17:29.021652Z",
|
|
"start_time": "2019-05-10T16:17:29.017222Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Benchmark(): # 本类已保存在d2lzh_pytorch包中方便以后使用\n",
|
|
" def __init__(self, prefix=None):\n",
|
|
" self.prefix = prefix + ' ' if prefix else ''\n",
|
|
"\n",
|
|
" def __enter__(self):\n",
|
|
" self.start = time.time()\n",
|
|
"\n",
|
|
" def __exit__(self, *args):\n",
|
|
" print('%stime: %.4f sec' % (self.prefix, time.time() - self.start))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2019-05-10T16:17:29.069210Z",
|
|
"start_time": "2019-05-10T16:17:29.023602Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def run(x):\n",
|
|
" for _ in range(20000):\n",
|
|
" y = torch.mm(x, x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2019-05-10T16:17:29.767144Z",
|
|
"start_time": "2019-05-10T16:17:29.071262Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Run on GPU1. time: 0.2989 sec\n",
|
|
"Then run on GPU2. time: 0.3518 sec\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"with Benchmark('Run on GPU1.'):\n",
|
|
" run(x_gpu1)\n",
|
|
" torch.cuda.synchronize()\n",
|
|
"\n",
|
|
"with Benchmark('Then run on GPU2.'):\n",
|
|
" run(x_gpu2)\n",
|
|
" torch.cuda.synchronize()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2019-05-10T16:17:30.282318Z",
|
|
"start_time": "2019-05-10T16:17:29.770313Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Run on both GPU1 and GPU2 in parallel. time: 0.5076 sec\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"with Benchmark('Run on both GPU1 and GPU2 in parallel.'):\n",
|
|
" run(x_gpu1)\n",
|
|
" run(x_gpu2)\n",
|
|
" torch.cuda.synchronize()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python [conda env:py36]",
|
|
"language": "python",
|
|
"name": "conda-env-py36-py"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.8"
|
|
},
|
|
"varInspector": {
|
|
"cols": {
|
|
"lenName": 16,
|
|
"lenType": 16,
|
|
"lenVar": 40
|
|
},
|
|
"kernels_config": {
|
|
"python": {
|
|
"delete_cmd_postfix": "",
|
|
"delete_cmd_prefix": "del ",
|
|
"library": "var_list.py",
|
|
"varRefreshCmd": "print(var_dic_list())"
|
|
},
|
|
"r": {
|
|
"delete_cmd_postfix": ") ",
|
|
"delete_cmd_prefix": "rm(",
|
|
"library": "var_list.r",
|
|
"varRefreshCmd": "cat(var_dic_list()) "
|
|
}
|
|
},
|
|
"types_to_exclude": [
|
|
"module",
|
|
"function",
|
|
"builtin_function_or_method",
|
|
"instance",
|
|
"_Feature"
|
|
],
|
|
"window_display": false
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|