You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

193 lines
4.1 KiB

3 years ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 8.3 自动并行计算"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-10T16:16:41.669018Z",
"start_time": "2019-05-10T16:16:36.457355Z"
}
},
"outputs": [],
"source": [
"import time\n",
"\n",
"import torch\n",
"\n",
"# The benchmarks below place work on two separate GPUs, so fail fast with a clear message.\n",
"assert torch.cuda.device_count() >= 2, 'this notebook requires at least 2 CUDA devices'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-10T16:17:29.013953Z",
"start_time": "2019-05-10T16:16:41.673871Z"
}
},
"outputs": [],
"source": [
"# device_count() >= 2 only guarantees that cuda:0 and cuda:1 exist, so use exactly those two.\n",
"x_gpu1 = torch.rand(size=(100, 100), device='cuda:0')\n",
"x_gpu2 = torch.rand(size=(100, 100), device='cuda:1')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-10T16:17:29.021652Z",
"start_time": "2019-05-10T16:17:29.017222Z"
}
},
"outputs": [],
"source": [
"class Benchmark():  # This class is also saved in the d2lzh_pytorch package for later reuse\n",
"    \"\"\"Context manager that prints the elapsed time spent inside its ``with`` block.\n",
"\n",
"    Args:\n",
"        prefix: optional label printed before the timing; ``None`` or ``''`` means no label.\n",
"\n",
"    GPU work is queued asynchronously, so the ``with`` body should synchronize the\n",
"    devices it used before exiting; otherwise the printed time may not include that work.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, prefix=None):\n",
"        self.prefix = prefix + ' ' if prefix else ''\n",
"\n",
"    def __enter__(self):\n",
"        # perf_counter is monotonic and high-resolution, unlike time.time(),\n",
"        # which can jump when the system clock is adjusted.\n",
"        self.start = time.perf_counter()\n",
"        return self\n",
"\n",
"    def __exit__(self, *args):\n",
"        # Returns None, so any exception raised inside the block still propagates.\n",
"        print('%stime: %.4f sec' % (self.prefix, time.perf_counter() - self.start))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-10T16:17:29.069210Z",
"start_time": "2019-05-10T16:17:29.023602Z"
}
},
"outputs": [],
"source": [
"def run(x, n_iters=20000):\n",
"    \"\"\"Compute ``torch.mm(x, x)`` ``n_iters`` times, discarding every result.\n",
"\n",
"    Exists only to generate work for the timing cells. ``x`` must be a square\n",
"    2-D tensor for ``torch.mm(x, x)`` to be valid. For CUDA tensors the products\n",
"    are launched asynchronously, so this can return before they have finished.\n",
"    \"\"\"\n",
"    for _ in range(n_iters):\n",
"        torch.mm(x, x)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-10T16:17:29.767144Z",
"start_time": "2019-05-10T16:17:29.071262Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run on GPU1. time: 0.2989 sec\n",
"Then run on GPU2. time: 0.3518 sec\n"
]
}
],
"source": [
"with Benchmark('Run on GPU1.'):\n",
"    run(x_gpu1)\n",
"    # With no argument, synchronize() waits only for the current device (cuda:0 unless\n",
"    # changed); pass each tensor's own device so the timer covers all of its queued work.\n",
"    torch.cuda.synchronize(x_gpu1.device)\n",
"\n",
"with Benchmark('Then run on GPU2.'):\n",
"    run(x_gpu2)\n",
"    torch.cuda.synchronize(x_gpu2.device)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-05-10T16:17:30.282318Z",
"start_time": "2019-05-10T16:17:29.770313Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run on both GPU1 and GPU2 in parallel. time: 0.5076 sec\n"
]
}
],
"source": [
"with Benchmark('Run on both GPU1 and GPU2 in parallel.'):\n",
"    run(x_gpu1)\n",
"    run(x_gpu2)\n",
"    # run() only queues kernels, so the two GPUs can work concurrently. Wait for each\n",
"    # device explicitly: synchronize() with no argument only waits for the current one.\n",
"    torch.cuda.synchronize(x_gpu1.device)\n",
"    torch.cuda.synchronize(x_gpu2.device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:py36]",
"language": "python",
"name": "conda-env-py36-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}