You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

209 lines
320 KiB

3 years ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 9.6 目标检测数据集(皮卡丘)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.2.0\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import os\n",
"import json\n",
"import numpy as np\n",
"import torch\n",
"import torchvision\n",
"from PIL import Image\n",
"\n",
"import sys\n",
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"print(torch.__version__)\n",
"\n",
"data_dir = '../../data/pikachu'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9.6.1 下载数据集\n",
"请运行[脚本](https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/code/chapter09_computer-vision/9.6.0_prepare_pikachu.ipynb)准备好数据集。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"assert os.path.exists(os.path.join(data_dir, \"train\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9.6.2 读取数据集"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# 本类已保存在d2lzh_pytorch包中方便以后使用\n",
"class PikachuDetDataset(torch.utils.data.Dataset):\n",
" \"\"\"皮卡丘检测数据集类\"\"\"\n",
" def __init__(self, data_dir, part, image_size=(256, 256)):\n",
" assert part in [\"train\", \"val\"]\n",
" self.image_size = image_size\n",
" self.image_dir = os.path.join(data_dir, part, \"images\")\n",
" \n",
" with open(os.path.join(data_dir, part, \"label.json\")) as f:\n",
" self.label = json.load(f)\n",
" \n",
" self.transform = torchvision.transforms.Compose([\n",
" # 将 PIL 图片转换成位于[0.0, 1.0]的floatTensor, shape (C x H x W)\n",
" torchvision.transforms.ToTensor()])\n",
" \n",
" def __len__(self):\n",
" return len(self.label)\n",
" \n",
" def __getitem__(self, index):\n",
" image_path = str(index + 1) + \".png\"\n",
" \n",
" cls = self.label[image_path][\"class\"]\n",
" label = np.array([cls] + self.label[image_path][\"loc\"], \n",
" dtype=\"float32\")[None, :]\n",
" \n",
" PIL_img = Image.open(os.path.join(self.image_dir, image_path)\n",
" ).convert('RGB').resize(self.image_size)\n",
" img = self.transform(PIL_img)\n",
" \n",
" sample = {\n",
" \"label\": label, # shape: (1, 5) [class, xmin, ymin, xmax, ymax]\n",
" \"image\": img # shape: (3, *image_size)\n",
" }\n",
" \n",
" return sample\n",
" \n",
"\n",
"# 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
"def load_data_pikachu(batch_size, edge_size=256, data_dir = '../../data/pikachu'): \n",
" \"\"\"edge_size输出图像的宽和高\"\"\"\n",
" image_size = (edge_size, edge_size)\n",
" train_dataset = PikachuDetDataset(data_dir, 'train', image_size)\n",
" val_dataset = PikachuDetDataset(data_dir, 'val', image_size)\n",
" \n",
"\n",
" train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, \n",
" shuffle=True, num_workers=4)\n",
"\n",
" val_iter = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,\n",
" shuffle=False, num_workers=4)\n",
" return train_iter, val_iter"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([32, 3, 256, 256]) torch.Size([32, 1, 5])\n"
]
}
],
"source": [
"batch_size, edge_size = 32, 256\n",
"\n",
"train_iter, _ = load_data_pikachu(batch_size, edge_size, data_dir)\n",
"batch = iter(train_iter).next()\n",
"\n",
"print(batch[\"image\"].shape, batch[\"label\"].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9.6.3 图示数据"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkMAAADsCAYAAAB37KKJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvXm8bVdV5/sdY861m9Pcc+85997kJrkJNz0JbQhCQmMS\nURTR8lmAvTwtLQGBz6tSH88SxaqyfZY+fWpRZff02VCI+kSagCAEgQAFUbqE9H1y++Z0u1lrzjne\nH3OuvfdJokX0fD71B2d8Pueee/bea68115xzjN/4jWaJmbEjO7IjO7IjO7IjO/KVKvo/+wJ2ZEd2\nZEd2ZEd2ZEf+Z8oOGNqRHdmRHdmRHdmRr2jZAUM7siM7siM7siM78hUtO2BoR3ZkR3ZkR3ZkR76i\nZQcM7ciO7MiO7MiO7MhXtOyAoR3ZkR3ZkR3ZkR35ipYdMLQjO7IjO7IjO7IjX9GyA4Z2ZEd2ZEd2\nZEd25CtadsDQjuzIjuzIjuzIjnxFi38yH67mdllvaR8WA+o8iJBiwDmPCIBMP1z+qwgGiEh+SWzy\npohAfheR8mr5RybvP/b1cowUJCcz5yyfEx5/XD7P9LvzB2361+T8kq94cr0yOyqecJgi7TDYejWC\ntOd43DGToU9l5m+RfO+85pcEuP/++zlx4sRjBvxPk917lu3gwQu23D8p92Ty92PvbXvZ7cfK+yEm\nQtPQ63an4zFoQiCmhHeelCJOHeryHG5ubrK4OD+ZlZTy/ywZTVPT6XZJMRFjIIZAp9NhMBgQQmRl\nZZnBYMDq6iqLi4v0+nPU9ZjKV6xvbLB7926OHT0CZuxaWmI4GjEcbHLWWQfw3mGAqmApISI0IWIG\nIsapE6fYHK6xML+bXbuWOH78FJsbpzj7nP3UdWLfyj7ECSklVPJvEcE55eGHj7O2dpLLnnoZo40R\nDzx8P7v3LHPO2Wc95i4a99//wLbNJcDi4oLt3buMGagqIpr3XPmxcqdFhHqwjiOA2eR1ZhrRW/uP\ngIiW12RmPUjeF1pWpmi7YxEt+518/sd3uLfp79m3LAGGWSx/W/kbzBJN0xBjJCQjRUNUMbMZvdPq\nEaHXrXCqqCqdbp+qN1e+x8qYy6cN1jc3WVxY4OjhB1lamEPVTcaVxzvVD4+V2fEdOX6aM2ub2zKf\nvYVlW1w5+ITnnCiZdn6mam1Gw225yPz7sbqGrVpxetxjXi06Wlo9/rjLsYmOb48WARWhq5FB0K3T\nvGU5GDarqyevPsE4thw2O3ihbN7HfffjLlWkzL+Uc5fR2sxtMlg78SDD9ZPbMpd7V5bt/IMHn/B6\n/kHZshfLvn3cW4/5vsf9OaPH29fsH/jw416a3Jl/8PLaOygx0YzHmEV8r08cjSYzamXdJMD3eqir\nst6YGNupCBBjeOwCme6Bxy/U/4FMP/S5W28/YWb7/kdHPCkw1F/ay9Xf94t4IskSWnVYX12l8rCy\na4HgelkxlgGrczjAVFER1AmqBRipUoni1CHOcKJANpyq4JzDFYPbUUG8y0YHh3MCmlAFLx6n04Gr\nF9TyufPnBVXFiJhk0OZbZe6nxqIScGLZkHihEsE5wbkKKRqnch4ln1fIRsfMMKdoAlcpYuAVooBL\nStURxDKAUFW0cHFOFRGbKFNDQVI+jwjRjD1dYU/f0ykr4XlXP/fJTNc/KmeffS5/9b6bp0gLUJcB\nQXtPXDuXM/yhGFDut6oyHgUeeuRhLr7oKThT0ERMwsZgxNETpzl69Bh79y7jTDh06FzqcU3VgdGg\n5iMf/CDf+opvoSah5hiNRvhOFzXhrrvvwouxf98Kn7j5FprhkNAEXK/LtS+8hg9/4AN47xFf8fTn\nXMX99z7I5mCNF73oxaR6zM/89L+n1+ny4q95EadOrXH66GFe86Ov5wPv/wihXuCbXv4CRsMBvV6P\nJgUeuO9Bztq/zLv+4tP86Tt+kzf9u9ezsdmh1zmLj3/sj/jkJz7Pt333K7j+RV/LeQf3Mxxusri4\nC5OAkwqzhg/9za28+Sdfw4c/8Um++PHbeM2P/Cuef/31/NJP/SRzc/28JwTG4zHXXnvtts0lwNn7\nV3jja7+ThZXzUO/oVD263T6dTgfnHFVVkWKNuC5xvMaR2z5BFY/iEoDQxBojIjhiTFiEJNDpLQAQ\nUZzvoOoR9Zh2MJfoVAtQVah2UOnhvEN8hYoHMihTFZxzxBgRixmQ1YOJJZcYCM0Aiw0xDEAiGg0s\n0TQjBsMNDj96nDNn1ji5tkG/t0QzTjjnMMnfF0mIesSMKy4/l4Ven4XFRZ5y8VWcdelVWEiEEIhN\noCaSgqHm+cBHb+Jl113Dh971nzl35Ty63R5VpUTxqCrO+claF8m6ZFYza9nQ3/+//8q2zeXiykFe\n+ZYP4BCSZBA3e15RBcuAUGKaAl5V1FrgYo9zZlSEpIJEQ1QmYE4mzpxMwKVoetx15fELycBhqAre\nKZICvqoy0IiRqlfREbhg14DbT84RUiBExSShQGPFpZ0YPiOlPMZkCXBZLRlgEXUuv14ArSvjNwqA\nxmFmpFTAjkJE0ARGQESJZmj7nSip1bsT5A9IIiXjT37qhm2by/MPHuSjH3z35O9Z52D2HjzRb33M\na1v+Lyk7YjZ9PYlBBAdE4vR7Zo4Tmz1H3HreCVjMa6p1FrHWWTXEIhXG2pkznLr7HipLCIbTwK5L\nruTkl27P608jIUVM+hy45HLSOWfRxSFa4bwH8Xg1UqgZrp0hpQApFFYin1NaZF3Elb02WbPtfXwC\nryGhk2PPvuKaB76cuXpSYMjIgCYG6FQd1tbW2LVrCecd4xRQMcRNL1DNiCpUQKIscnGgRkcchhFI\nuCiIN9zEg3UkA3GKMwGXFZEXRV32BlWzh4/qlo2f53R6c0yFyNTAi4Dplns8XQiqGcs6IQJabmi0\nmDdqsgmYQYQoIN5NJ649nyXUe1xVri9OQQZYVnAxZu9djVh8K0dmg1xZEN7lk50OgqWauI2PkfOV\nQzSRsIlCb9chs4p0i5MoqAioYSgnTp5iYWEXhw6dT0eMpNlnCKkhWaTb63Pw4EF6/S4Wa5rQoC7R\n7/bpacW1L74BE6Ue1SzMeeo6sLa2wdzcHCvLezlx7Bi33nYPFxy6gCuvuIy3/vpvcO75h7j55k/y\noutfyuFHHmJ5eTcJ4SmHDnLixAmOPPoIl156GRdfdjH79ixzwVMuwdJ9fORD7+a3fvO3ueYFL+PK\nZ17MrXfcybOefilqMNhM7F5Z5pd/6ed5+JGHMGp+6ed/gaXdK/z27/8ub/7p9/O2P/wgv/nWn2Bl\nzwVcdNHZ9Hu7SSmiWhFCoqo8L33ps3nLW0YZ9IsQQsNguMniwlxZmDAeNwyGY5ombN9kAk4dhw8f\n5soDFxNSs4UVyms34XyHSoTQ380Fz7qBhz771zhOk2KDc466jqjaZK2DMa5rup3OZI1Au18KM5IA\nEpYaqLpZQacACioekfxdZnm87bZzzpEs5s+TQCLJGtTlBaemNHVN0zSMRzWhAYuOc/bvp3I9Whei\nNSjjAE2MxLoBjThvVJWwtGsvQhesIUEGA0lxGA3QUcef/r+/zlWXn0+/t4BKVebPZ7BVFK86wCqc\nThmoLU6sbRvJl50zVZKAtwx+Wn4m30PJwEKM5LO+yAxl8VMmjne5l1tYNMP78q5mfTMFRUzGJTMe\nkIgUo5vADC8tK5Pvj/ceJTuBVa/L5jjiK4OU6FfKOHicJBJKsIRLVoy0A004J1hUVLJeSZANMYKv\nKmIIWT+KluswAoalwkglQBK+yraAJFmjWgBcdkJTmOi0lDLQEK+QDBGlsQzUHhfh+GeLTdaQWR5T\nC0DSDBCxwnappfz5x7w/O5+ZNE1g2apOJEXEMkRsF6cAluLkHFvAmCXSDLtqaeZ6AElGsEhH4Mjt\ntzPeWKNLtuuI4IiU7UodA1GUaIKLgbF6dp13iP555xJcRac2UmU4M7zVHHvkfrqVQ1Ik0eqGPG9t\n8o6llqXNdrXVmKp+Em1
"text/plain": [
"<matplotlib.figure.Figure at 0x11f18e780>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"imgs = batch[\"image\"][0:10].permute(0,2,3,1)\n",
"bboxes = batch[\"label\"][0:10, 0, 1:]\n",
"\n",
"axes = d2l.show_images(imgs, 2, 5).flatten()\n",
"for ax, bb in zip(axes, bboxes):\n",
" d2l.show_bboxes(ax, [bb*edge_size], colors=['w'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}