{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 9.6.0 准备皮卡丘数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mxnet.gluon import utils as gutils # pip install mxnet\n",
    "from mxnet import image\n",
    "\n",
    "data_dir = '../../data/pikachu'\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 下载原始数据集\n",
    "见http://zh.d2l.ai/chapter_computer-vision/object-detection-dataset.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Maps each file of the raw dataset to its expected SHA-1 hash.\n",
    "_PIKACHU_FILES = {\n",
    "    'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',\n",
    "    'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',\n",
    "    'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520',\n",
    "}\n",
    "\n",
    "\n",
    "def _download_pikachu(data_dir):\n",
    "    \"\"\"Download the raw Pikachu dataset files into `data_dir`.\n",
    "\n",
    "    Every file listed in `_PIKACHU_FILES` is fetched with\n",
    "    `gutils.download`, which is given the expected SHA-1 hash through\n",
    "    its `sha1_hash` argument.\n",
    "\n",
    "    Args:\n",
    "        data_dir: directory the files are written to.\n",
    "    \"\"\"\n",
    "    root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/'\n",
    "                'gluon/dataset/pikachu/')\n",
    "    for name, sha1 in _PIKACHU_FILES.items():\n",
    "        gutils.download(root_url + name, os.path.join(data_dir, name),\n",
    "                        sha1_hash=sha1)\n",
    "\n",
    "\n",
    "# Check every file, not only train.rec: an interrupted earlier run can leave\n",
    "# train.rec on disk while train.idx or val.rec is still missing.\n",
    "if not all(os.path.exists(os.path.join(data_dir, name))\n",
    "           for name in _PIKACHU_FILES):\n",
    "    print(\"下载原始数据集到%s...\" % data_dir)\n",
    "    _download_pikachu(data_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 
MXNet数据迭代器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_data_pikachu(batch_size, edge_size=256):\n",
    "    \"\"\"Build the training and validation `ImageDetIter` iterators.\n",
    "\n",
    "    Both iterators read the RecordIO files stored under the notebook-level\n",
    "    `data_dir`.\n",
    "\n",
    "    Args:\n",
    "        batch_size: batch size handed to both iterators.\n",
    "        edge_size: width and height of the output images.\n",
    "\n",
    "    Returns:\n",
    "        A `(train_iter, val_iter)` tuple.\n",
    "    \"\"\"\n",
    "    data_shape = (3, edge_size, edge_size)  # shape of each output image\n",
    "    # `shuffle` and `rand_crop` are deliberately not passed for the training\n",
    "    # iterator. NOTE(review): min_object_covered / max_attempts look like\n",
    "    # random-crop settings and may have no effect without rand_crop -- confirm.\n",
    "    train_iter = image.ImageDetIter(\n",
    "        path_imgrec=os.path.join(data_dir, 'train.rec'),\n",
    "        path_imgidx=os.path.join(data_dir, 'train.idx'),\n",
    "        batch_size=batch_size,\n",
    "        data_shape=data_shape,\n",
    "        min_object_covered=0.95, max_attempts=200)\n",
    "    val_iter = image.ImageDetIter(\n",
    "        path_imgrec=os.path.join(data_dir, 'val.rec'), batch_size=batch_size,\n",
    "        data_shape=data_shape, shuffle=False)\n",
    "    return train_iter, val_iter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((3, 256, 256), (1, 5))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size, edge_size = 1, 256\n",
    "train_iter, val_iter = load_data_pikachu(batch_size, edge_size)\n",
    "batch = train_iter.next()\n",
    "batch.data[0][0].shape, batch.label[0][0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 
转换成PNG图片并保存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process(data_iter, save_dir):\n",
    "    \"\"\"Export every sample of `data_iter` as a PNG image plus one label file.\n",
    "\n",
    "    Images are written to `save_dir/images/N.png`, numbered from 1 in\n",
    "    iteration order. Labels are written to `save_dir/label.json`, keyed by\n",
    "    image file name. Only the first label row of each sample is kept.\n",
    "\n",
    "    Args:\n",
    "        data_iter: iterator with batch size 1 whose batches expose `data`\n",
    "            and `label`.\n",
    "        save_dir: output directory; created if it does not exist.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a batch holds more than one image, because every\n",
    "            image after the first would be dropped silently.\n",
    "    \"\"\"\n",
    "    data_iter.reset()  # start from the first sample\n",
    "    image_dir = os.path.join(save_dir, 'images')\n",
    "    os.makedirs(image_dir, exist_ok=True)\n",
    "\n",
    "    all_label = {}\n",
    "    for index, sample in enumerate(tqdm(data_iter), start=1):\n",
    "        images = sample.data[0]\n",
    "        if images.shape[0] != 1:\n",
    "            raise ValueError(\n",
    "                'process() expects batch size 1, got %d' % images.shape[0])\n",
    "        file_name = '%d.png' % index\n",
    "\n",
    "        # Move the channel axis last before handing the array to plt.imsave.\n",
    "        x = images[0].asnumpy().transpose((1, 2, 0))\n",
    "        plt.imsave(os.path.join(image_dir, file_name), x / 255.0)\n",
    "\n",
    "        y = sample.label[0][0][0].asnumpy()  # first label row only\n",
    "        all_label[file_name] = {'class': int(y[0]), 'loc': y[1:].tolist()}\n",
    "\n",
    "    with open(os.path.join(save_dir, 'label.json'), 'w') as f:\n",
    "        # indent=1 is the value the previous indent=True evaluated to.\n",
    "        json.dump(all_label, f, indent=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "900it [00:40, 22.03it/s]\n"
     ]
    }
   ],
   "source": [
    "process(data_iter = train_iter, save_dir = os.path.join(data_dir, \"train\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100it [00:04, 22.86it/s]\n"
     ]
    }
   ],
   "source": [
    "process(data_iter = val_iter, save_dir = os.path.join(data_dir, \"val\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}