You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
426 lines
219 KiB
426 lines
219 KiB
3 years ago
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# 9.9 语义分割和数据集"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"1.2.0 0.4.0a0+6b959ee\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%matplotlib inline\n",
|
||
|
"import time\n",
|
||
|
"import torch\n",
|
||
|
"import torch.nn.functional as F\n",
|
||
|
"import torchvision\n",
|
||
|
"import numpy as np\n",
|
||
|
"from PIL import Image\n",
|
||
|
"from tqdm import tqdm\n",
|
||
|
"\n",
|
||
|
"import sys\n",
|
||
|
"sys.path.append(\"..\") \n",
|
||
|
"import d2lzh_pytorch as d2l\n",
|
||
|
"\n",
|
||
|
"print(torch.__version__, torchvision.__version__)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## 9.9.2 Pascal VOC2012语义分割数据集"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\u001b[1m\u001b[36mAnnotations\u001b[m\u001b[m \u001b[1m\u001b[36mJPEGImages\u001b[m\u001b[m \u001b[1m\u001b[36mSegmentationObject\u001b[m\u001b[m\r\n",
|
||
|
"\u001b[1m\u001b[36mImageSets\u001b[m\u001b[m \u001b[1m\u001b[36mSegmentationClass\u001b[m\u001b[m\r\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"!ls ../../data/VOCdevkit/VOC2012"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {
|
||
|
"collapsed": true
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# 本函数已保存在d2lzh_pytorch中方便以后使用\n",
|
||
|
"def read_voc_images(root=\"../../data/VOCdevkit/VOC2012\", \n",
|
||
|
" is_train=True, max_num=None):\n",
|
||
|
" txt_fname = '%s/ImageSets/Segmentation/%s' % (\n",
|
||
|
" root, 'train.txt' if is_train else 'val.txt')\n",
|
||
|
" with open(txt_fname, 'r') as f:\n",
|
||
|
" images = f.read().split()\n",
|
||
|
" if max_num is not None:\n",
|
||
|
" images = images[:min(max_num, len(images))]\n",
|
||
|
" features, labels = [None] * len(images), [None] * len(images)\n",
|
||
|
" for i, fname in tqdm(enumerate(images)):\n",
|
||
|
" features[i] = Image.open('%s/JPEGImages/%s.jpg' % (root, fname)).convert(\"RGB\")\n",
|
||
|
" labels[i] = Image.open('%s/SegmentationClass/%s.png' % (root, fname)).convert(\"RGB\")\n",
|
||
|
" return features, labels # PIL image"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100it [00:01, 54.94it/s]\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"voc_dir = \"../../data/VOCdevkit/VOC2012\"\n",
|
||
|
"train_features, train_labels = read_voc_images(voc_dir, max_num=100)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkMAAADUCAYAAACfxSdJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvXm8bedd3vd9hzXu8cznnnNnXV1NlmXZkjxhYxEbOwEC\nDUMaxqQp2LjQUoY0JGmbtJBA0qQhlISUgIHYKcUGgwNmcHBsjCPbsixLsnylO+nOZz5nz2t6h/6x\n9r2yIQVTbj/8of2ccz9377XW2evd67fe933e3/As4b1nhhlmmGGGGWaY4aUK+efdgBlmmGGGGWaY\nYYY/T8zI0AwzzDDDDDPM8JLGjAzNMMMMM8wwwwwvaczI0AwzzDDDDDPM8JLGjAzNMMMMM8wwwwwv\naczI0AwzzDDDDDPM8JLGjAzNMMMMM8wwwwwvaczI0AwzzDDDDDPM8JLGjAzNMMMMM8wwwwwvaczI\n0AwzzDDDDDPM8JKG/vNuwAx/Pmi3W35peREhFNZWt7YLPFXlCMIQIcBZC9SPbPHeI6Xk1iNcvvBR\nLuLme4HHf/G+m58w3Xbzdf3W47zHe4F3Duc91hlkNUJKiXMOIQRCCvDgnEUIiRB1ewQghGC6AQRI\nITDWIoRAipvNsvQHEEUK7x1SKLyQIARCCNz03M45PAJ/q31i2lZPEAb1OUW9XQgxfS9uXZN634t/\nd+v73tw53ZYXFVVlbh39Z8XCwrw/cuwoWxvbrK0d4uxmn2Y3ZUF7ehWspBHgMWVFEARMsiHeebxz\nJM0mn3vqaZZWV9ne3SMKIvLxiJVmC4+nsBYXBTTimLTRYDgcUxU5SimklARRiLGOsshoNVtoHeC8\nIwpDxpOMvCgQzhAnKXlREAQBzWaTbDKmqixeQFVWtJtNrLM4PJPRmDAMCcIIUxmarSbeWfr9PghB\nFEVUVUUcR4j+COMcgVAYbxEI8ILKWZSUIGtbOilZOn4UASitEAisrY+xxlD0dnnu6ian1ucJlOPG\n9R4nH3gVtirQQQQ4QOKcRcovXkdeunSZ3d3d22LPMEp82OqAA49HSpBC133DVmitMNYjRX3vSxlg\nTAmuQCDxwiNkgJQh3ju8dzhXIQCpAgIBpTEgZd2XqPuPc9N727vpPWsBhRQCIUALR2ktENT9HcCV\neG8BXjyfkARS47wDqTCmRCoQToAKwJUvXkNXIJXHe3erL3rvcdagg2Dap+uT1eOKQEmB8x6JwrgS\nKTRC1H03Hxu8FUgpkVISBgFSCrRSaK2x1qJDTRylKKURAoyxRHGMnY4ZGxsb9Hq922LLRqvlF5eW\n6vFLCKT0OC9BeLQUZFlef0UlUAKSRhu8RXgFwtUf4gWTyZhJ3scLhXAeJSWNtInWijCM8L4eF521\nCCnBe6SSSFFfh1tjFgKB5EUDTk+BwXmLdx7r3K1rHYUBeV4Rx+H0yPoTbr6+NTd8wTsx9bE4b5FC\nUtoKcCDqdgkkHo8QDiVjjKmm5xe3xnYpBA7wzqFUhBSC0mQopZHo+j75z1zvm23wvm6/t45nnn5y\n13u/9CfZakaGXqJYWl7kn/7kjzDoFyhvAHDOoaTg3LlzPPiqt2HKHawrKIsc721NWG5O8N7WnVIp\nvPdYVyG8qwcy5/DUHdmamsx4Ww9+1loKUyJ9PQjZylAJgSkto6zEWstgPKE8835arQ5SWZJGjCkt\nRVESRSHWWsqyIgpiosQwGRuiGKIkxuYlzW6KKUrKQlAVOTqCqqr4wO9YDh3yzLcaqEBSRi129/dZ\nXV1ie2+XJG5SVCXWK3xlGFeCySSnLHO8CFhbXyVQ02nCeJrNJlVV4b1HB4IgCCmKAiHqSVZQXxut\nNVJKsiy7RSg//dTF22rPUa/H43vb9M9v81sfeg+nv+7tvP/qVX74VffxE+e2WCgGvOOBw1y5dI0j\nR07x7DO/RznKEQSsrS9y/JFHede7fpJvf8f3
c+bZT/B96TqL6yvQnedqPmThvvtYWVnkTW/9Cj76\n2Cd56onPogUsLi7zpje+gWFR8KEP/gb/1Xe9g6XuIs+cf5YTK4d47FOfZmNnn8Hede694x6efu7z\nvP6tf5HDh5a5vnmBZx57lkYUs7F9g0de8zD3nLqbJ597js899TRVNeTBB1/LeDjinvtfTieWfPhj\nH8c5x4kTJ7h8+Sp3HD7EC+/9bR6YX2R/koNwLM7Nc6O/w2u+91u48MHH2bjwPFvDjH/0iY8iKJEy\nYGtnC+89y0urWOHZv7HB8toKeW+bx9/z03zjf/tP+K1f/gesuwl2+Bzf/Xd/kw9sjNHTyeDylcsc\nO3ri1uD/0EMP3TZbtlopL/uL30FmLcpajLAEQYMoSrCVoah6pJ0lIhmS965TlTmd9gL9iUQ0JGQD\nyjInDZqMsh5BlNBMUoxO0NUIZ7eoXIc1OaYfrzMY9Gl1u/T7PQLlccYipcH6kqoqkAKSJGV/+wYv\nX1vnepWQRvuM9iaUYQo0QAjiRhMVzTMf5hT7Y/bLHYwucHaPlvY0m6+gNxqjVJ8iu0R73iDLHKcT\ngtijZIiQmkAKnDAU1uBthVYRgZRYPE46hNVo5xiPM+IwonSephc4EbHUWOY3fvVxANI0RQnNkfVF\nVubmWF1dJWomnDx5J3ff90r2D/q0Wh0Gg33m5+cBSVFkfOd3vv222XJxeYm//WP/mDBUBME8YZBT\nWSjo02zM8/ynnqDVChg5S6Ulj77+YZRuEoQplS/QXuBMxfOf/zgf+pV/TaCXKIoxpx58LadO38mj\nD/4lKjUmLyY0GzFYx1NPP07UjEB6FAtEDYlWFcLU45CVKc6VeGsIXIzxOY6M0HuCIMJ6T5GNkYHi\nNa94Pecvb3HqjiUUIXiBFAECgcOAnxLRW9QknC74PHnhaMQBv/Dr/xq8JAgW8N4ShA7nBFb20XYO\nZ6HEIbwiCKHwlnwcE4gxJgzRwiG9Ji8z/sKb3sbq3AJGNG4tRCU1AXLe4wVoaiJVWkeVVxxtxpe/\nFFvNyNBLFMaUfP7pJzly5J5biwSlFN461tYPMxocoFRFo9WkKCc4awCJlBrvPUrV/3tncN7Uqyxf\nr+yqqsJaD6K+WYX3SKVwriZIgVSAqFfrzqEBrwRKBSBBekelG2g8iABXCmRgCW2AtRZjSpSQxIHG\n+jFx3MCYnPFwTJq02d/rEQcxcWopc0UoAioXouJt4niOKFDYICAvC5zV6AAcAqk8ygW4oiQINV5J\njDF4HyC8JJIaqWsyo1NNnk8IghC8wuAR0+8nhEAphalcPfi4irIUhGGKxFOU2W23Z7q4iCXj333w\n5/nGb/hWOofafPJglSeubLIsFKlu4J3k+PHjCDxUiihuUlUF165dY/Opx/mpf/ZjvPfnf5r5M/v8\n4sEBpw6v8p1HjlC1NIFUnLrrNMODEaGI6C6sUA0PiBsRi2uHOf+JTzC/uMATj3+CBx5+GFcIJqai\n0U45//jjLEaK7soC/mnY391BUaFLQagc7cVFdnt7tNtthFZILA899AgvPP88YSvhYDLB+ZK8iMiH\nfdK5Qxwc9Dly+BAPPfgw5cgx+P0nabRShv0JggChIi7+/H9gPzK4sqIbhoyLHFvkNNM24+GAkyfv\nBSr6vV1W1o7isejmMm/8b36Uq9/+Tp791V/mje/8n/nw+/933v2zdzB57w/wi09M+Js/8k85evQo\nkCFcAyvNbbWlN4bxaBP8CKk7JCqhzHoMsn103MGWY8rNy0zsGGQX1TxBgcK4AY1M4ZMGc+15IrNL\nbhxpAJMqwxQZxztd5pIl5prH2S220RNJwwdU1S7OlIhonnJ0hbYumXhNoFuE6SKIiruOKDb715gL\nDeODFqpzlFgnOB8QBwJUghKezc0bHFkoyYxikvXRhPRH58mLMd25RYJ4wKSR44wkbLdIVUxhJ4BF\nKMfAViRJ
gswMvlL4KiMjQgUgrMMax+m1e7jqzyOiAFU4Qp2iKs/2wWXwBqkCrBG0WoqsMiTtJq1W\ni62tLZ6fFBw/fjdpnGC
|
||
|
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x125436a90>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"n = 5\n",
|
||
|
"imgs = train_features[0:n] + train_labels[0:n]\n",
|
||
|
"d2l.show_images(imgs, 2, n);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {
|
||
|
"collapsed": true
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# 本函数已保存在d2lzh_pytorch中方便以后使用\n",
|
||
|
"VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],\n",
|
||
|
" [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],\n",
|
||
|
" [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],\n",
|
||
|
" [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],\n",
|
||
|
" [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],\n",
|
||
|
" [0, 64, 128]]\n",
|
||
|
"# 本函数已保存在d2lzh_pytorch中方便以后使用\n",
|
||
|
"VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',\n",
|
||
|
" 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',\n",
|
||
|
" 'diningtable', 'dog', 'horse', 'motorbike', 'person',\n",
|
||
|
" 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {
|
||
|
"collapsed": true
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)\n",
|
||
|
"for i, colormap in enumerate(VOC_COLORMAP):\n",
|
||
|
" colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i\n",
|
||
|
"\n",
|
||
|
"# 本函数已保存在d2lzh_pytorch中方便以后使用\n",
|
||
|
"def voc_label_indices(colormap, colormap2label):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Convert colormap (PIL image) to a tensor of per-pixel class indices looked up in colormap2label.\n",
|
||
|
" \"\"\"\n",
|
||
|
" colormap = np.array(colormap.convert(\"RGB\")).astype('int32')\n",
|
||
|
" idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256\n",
|
||
|
" + colormap[:, :, 2])\n",
|
||
|
" return colormap2label[idx]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],\n",
|
||
|
" [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],\n",
|
||
|
" [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],\n",
|
||
|
" [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],\n",
|
||
|
" [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],\n",
|
||
|
" [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],\n",
|
||
|
" [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],\n",
|
||
|
" [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],\n",
|
||
|
" [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],\n",
|
||
|
" [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]], dtype=torch.uint8), 'aeroplane')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"y = voc_label_indices(train_labels[0], colormap2label)\n",
|
||
|
"y[105:115, 130:140], VOC_CLASSES[1]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### 9.9.2.1 预处理数据"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkMAAADMCAYAAABwQKe/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvWmwbcd13/db3Xs459zhTXgPwMNEQCRASSBBShw0UZRN\nMZ40xBmKUUofVHIqSZUrVc5Qdr6m7HLiOJWKSymnZMuJJVnlyLJmyRpiiSJBcRQgkoJIDCSAh+kB\neNMdzzl77+5e+dDde+9z73sPeOGlrBTvIh/uvWfYZ5/uXqv/a63/Wi2qyrEcy7Ecy7Ecy7Ecyzeq\nmH/XN3Asx3Isx3Isx3Isx/LvUo7B0LEcy7Ecy7Ecy7F8Q8sxGDqWYzmWYzmWYzmWb2g5BkPHcizH\ncizHcizH8g0tx2DoWI7lWI7lWI7lWL6h5RgMHcuxHMuxHMuxHMs3tByDoWM5lmM5lmM5lmP5hpZj\nMHQsx3Isx3Isx3Is39ByDIaO5ViO5ViO5ViO5RtajsHQsRzLsRzLsRzLsXxDS3ErL948fUrP3nUX\nAoAgAiEEVMEYQSRiK9WAyb+jCIJI/BdCQIzEx4ygIYAIaPzRi4CsfLoc+E1AdPTY6quv/87rPyMC\n+VQSufkb3uB6YxkfcyL9YxoCIYT4mMho7ATUx7elsSIEMIblYkk9qXnphRe5evnKm/v4N5CN06f1\n7F13k+dxLMP4Hv4G+ZdD4y3XmQEZXes6Iul9UfQ6rxw/d50r5SVw3fs/+M4by62+/ijkpeef5+rl\ny0f2USfO3KZ33HvvddfvoZETDj5yndfnPw7MSB6g9DP/erPrjZ/7sxrfg5/tg8cYe0vvG9+rHnpU\n8/8BePnCBa4d0Xyunzylp8/fFW1AEjk4D2nxy6F5HT1yHRt6/fVxeEEMa0bTdd7Qiq4oo6ArGn29\nd49VP20Bh173tQ7oG73/ep95lLp54swZPXfPfav2c2Wvy2OfR+vmH3twHaw8d93Xp19GX/TNfLGv\nh44evOZRHgR2cB519MufPv7YZVU9+0bXuCUwdPb8XfzDX/rXiAjGRLAjoiyXS0Qs+Zwz7zvWp+tM\np1MgAyaP9575fI61lhACk8kM7z11XTKbzSgrS9e0FEUFBApjEaOICFbj9cXEzzAjgBXvIy4Uo6N7\nY3XxGKS/x4OLyo4WZQZtVgyqijGGEALpsum9Bgj936qK1dB/37Zd0nUdrSqlsdR1HAtrhXmzBDX4\nICy7Bt81TCcV1ne0ndKkzxTv8GL4wh9/kQ996EP80Hd94Fam66Zy9q57+Lu//GtYa+NY2mFMbR4b\nYzDp+63O+eG/QbFiRvORnk/fJUpYfS7/rqvzkec8TUd8DQEhbmbWENfC6H0igk3z2z8GgKb7jYBX\nBMyNtFCUwGgt8LUZhYzVRfI3H+QH3vOer+HKh+WOe+/ln37s40BYmbM8FlYMKsN45bnOYswQJhaJ\nz/fPjcYYhvHLr+nX/3jcDm7eo0EXPfz88PgwT9pvGsN7+3scbbdvJF3XsWgbZrMZRsxom47X6MGc\nRn0PDDqfbUEWTbed15kGIQj89fe9903dy5uR0+fv4r//l78YbYASddSYXm+yTTLGYCX+rnlujEFU\nsdYCIepZiH+LCBIUNUJpLIrvrxkwiCh29BlZ/00aqThI6XNGOhwfHttiRYKuPG/sWB+kvxcAm76n\nqo/fVTlkp8WQXpPGfRVDJBsCoX88jOzAjbV4vBbyq45SN8/dcx8/8XuPxjHJuoigyQG2CEjo96Y8\nz/m7qw52VUb7j9FAEJPmymOU1fkYva6fGx1soaTxPWh7D4kMdnZVJwcxOoz7jWRscyPwPaC7euML\nHLz2weNUg4JJ95bXjSX+riHwtrq8cPO7i3JLYEhQvO8QEbyPSloay6yeAFBVVQxmlAW+7XDO0XUd\n1lqapqGqKqpqQlEUNE3DcrmkqCs8yvbeLnVR
4nxLXXrm8z3KsqQsDNPplMoWVFWFcw6CMm8aQghU\nVcXa2ho2G6lAb7yM5M0xKnQYKYiOFTMoKnExKRoXJUrcwoQQXHqPxI0GCMmQAL3RBJM+L4KfyWTG\n3DnatmV3scR7D8mABZSgBhcCt587R/AdYgr8vMPPF8y9IwSDEaVe3+Ty9i7O+1uZrjeQ+P2it6wQ\nYoTKjIxa3gTHCmrM8HNQCu1BVW9gAdQnhY9jGmQALVnBgg+ItUiCC70+qw6bejISmsAmId4nCdhE\nIxG3NTHSK/rKt03vNYfd5bg2iEpfjIDUsElyyPgO11WsEcJ1jMT4cwxKDIJ+PXwuAI2AsQc7Gj8+\njYXmMEaeX5EVMCO9MZbewc8GtQc5WWdGBtQoCfBq/4J4Hbuy+emKMRxA6xgwphtO+sqheVr9toej\nwjcCR2VZ0nVd1FM7gG/VqPeD8+zjmoK4dlEIZlgfeTcgRT5C/Gn16D1pEe03dFUf71UBDIWJdkEI\ngx3TLuopEVyQN8GkH1Y06oBAaYu4gZocyVeMxGh+Bp8ZVIUQMGWZohnSj9tYLBrXQHIeJTlXkqMd\nyZRaAxqkX5dCvseom4VE4HZIkTRAkLR802beD1Rag9l/SxvzODMBN56fN4rEfK0io30kAyJJ0Tab\n13taP8bISAfi66xNwFTS3Y7GxqIQfLSFJjmaAGlcLSN746MzL4AGHZyhfLkEhGFw3CIQS6/x0TaM\ngckY4JiUATroaIwGIn8Mq9Bz9fmDeq0o5jpAKQyqiDUQgq6syxAUI4oeckNvLLcEhiIKcBhTxg90\nLb4oEBGmdYURj7WKqMeWQtMs2VifIcDadJYU2tC1nnKtJjgPEgjOMa0qrFXqskYDbK5vUNc1k6rA\nIDjXEZyjMIbOd8wmNaj0yqQhIEbxQXojEBJaVq94onc6BkEQB9SKoASC5sc0AqmQELn6fnIEAc1x\noWQ0kocTjCAEvGvjY6LQObRp8D4QKAjGskypMpM+6+WLr2GtxVpLYQw2tNRovJ4YvPdcuXaVcJ1N\n/muRsiyHTelAJGElcmNM/7rsvYzHENIiNwFBVqMKOixwY0yvQL2XZG3ceAYUFB8feUcDqF3dlAev\nJj0XdNDtlc0axByIVIyUNc/3WD/zSOdNeWx8+9fkTZXrAKHryNPPPMlDD37zm3rtrYqQvOGRwewj\nCOk12ZgZkwCBJpcqGeS45odN3xB/RrUIq/N6KOYeXQRIcxc8oY+qDnMX5zLE+x0N7ACwrme8Rh5v\nj1APv66/JqtGVYOws7XNuTvOJicnAkMj8dXZiB7c5AF8+hxRzRmj4fMkGuCj1cq0lpK7ayWBQ8lR\ng7Ht8sk+xaiBEAGUNUlPZfX12akZf9+s39mByfZNiDphiwINKcKUwNdgK5JeJxAU711iZHcEUHJk\nF7KzlUGSHMpvSLLFeePPWjqOTMANorsrnkp2Cm5N2Q46O0cheQ6Ty4GI6UMs0eGL//qIHESwmObF\nZt3SIS0qRkkuZ9z0FURCem2aQ4kg0sTwXoxIBY1gOYkmvdYwOKu9qRzb0JH9NAxz3Y9bvq9VDRn9\nd+WhG8pBkCTEKGCvZTq6J8Az2OH8uIjgreLVY+ybp0XfEhgKITCfz5lMYiRINFDaAu8cnYlKZfOg\npZ/eOSQoRWFBFdc5CAGxBZOqoCwLVJWiiLdixaB4vFOsekLrUBE0TbxzrvdY8mCFMHgfGvywkRmD\ncw6DBQlIGIzt4EEHQu/xxO9orU1K59PPZHjFghpEDQaXFNTi0/Pqc3SjQmy8x7ouqespLnR4DTRN\nQ+MMHvAqFEWJC54QBHUOnzZ/o4r3gc45zpzY7EHlUUn0EDxGLIJgk0KON0iTl2YIESyoxnB3r9yj\nNCEeCdHQKXFD7IFKVj715MiE4tNmOxiBnKoZQrkk65fASgI92ecTIwlQ0lvJ3usfbcAwGH7IG+9I\n8cbGOD3X
74v9c4OVHNJJbz5V41HOnTuXNvGvjzc6BkFi4t9WhjU9fk0WzaDnBreU05HGDIDU9vri\nR585hOOjToZ8gWyeh+v
|
||
|
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x126c4a390>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# 本函数已保存在d2lzh_pytorch中方便以后使用\n",
|
||
|
"def voc_rand_crop(feature, label, height, width):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Random crop feature (PIL image) and label (PIL image).\n",
|
||
|
" \"\"\"\n",
|
||
|
" i, j, h, w = torchvision.transforms.RandomCrop.get_params(\n",
|
||
|
" feature, output_size=(height, width))\n",
|
||
|
" \n",
|
||
|
" feature = torchvision.transforms.functional.crop(feature, i, j, h, w)\n",
|
||
|
" label = torchvision.transforms.functional.crop(label, i, j, h, w) \n",
|
||
|
"\n",
|
||
|
" return feature, label\n",
|
||
|
"\n",
|
||
|
"imgs = []\n",
|
||
|
"for _ in range(n):\n",
|
||
|
" imgs += voc_rand_crop(train_features[0], train_labels[0], 200, 300)\n",
|
||
|
"d2l.show_images(imgs[::2] + imgs[1::2], 2, n);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### 9.9.2.2 自定义语义分割数据集类"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {
|
||
|
"collapsed": true
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# 本函数已保存在d2lzh_pytorch中方便以后使用\n",
|
||
|
"class VOCSegDataset(torch.utils.data.Dataset):\n",
|
||
|
" def __init__(self, is_train, crop_size, voc_dir, colormap2label, max_num=None):\n",
|
||
|
" \"\"\"\n",
|
||
|
" crop_size: (h, w)\n",
|
||
|
" \"\"\"\n",
|
||
|
" self.rgb_mean = np.array([0.485, 0.456, 0.406])\n",
|
||
|
" self.rgb_std = np.array([0.229, 0.224, 0.225])\n",
|
||
|
" self.tsf = torchvision.transforms.Compose([\n",
|
||
|
" torchvision.transforms.ToTensor(),\n",
|
||
|
" torchvision.transforms.Normalize(mean=self.rgb_mean, \n",
|
||
|
" std=self.rgb_std)\n",
|
||
|
" ])\n",
|
||
|
" \n",
|
||
|
" self.crop_size = crop_size # (h, w)\n",
|
||
|
" features, labels = read_voc_images(root=voc_dir, \n",
|
||
|
" is_train=is_train, \n",
|
||
|
" max_num=max_num)\n",
|
||
|
" self.features = self.filter(features) # PIL image\n",
|
||
|
" self.labels = self.filter(labels) # PIL image\n",
|
||
|
" self.colormap2label = colormap2label\n",
|
||
|
" print('read ' + str(len(self.features)) + ' valid examples')\n",
|
||
|
"\n",
|
||
|
" def filter(self, imgs):\n",
|
||
|
" return [img for img in imgs if (\n",
|
||
|
" img.size[1] >= self.crop_size[0] and\n",
|
||
|
" img.size[0] >= self.crop_size[1])]\n",
|
||
|
"\n",
|
||
|
" def __getitem__(self, idx):\n",
|
||
|
" feature, label = voc_rand_crop(self.features[idx], self.labels[idx],\n",
|
||
|
" *self.crop_size)\n",
|
||
|
" \n",
|
||
|
" return (self.tsf(feature),\n",
|
||
|
" voc_label_indices(label, self.colormap2label))\n",
|
||
|
"\n",
|
||
|
" def __len__(self):\n",
|
||
|
" return len(self.features)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### 9.9.2.3 读取数据集"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100it [00:00, 104.07it/s]\n",
|
||
|
"6it [00:00, 56.42it/s]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"read 75 valid examples\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100it [00:01, 56.74it/s]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"read 77 valid examples\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"crop_size = (320, 480)\n",
|
||
|
"max_num = 100\n",
|
||
|
"voc_train = VOCSegDataset(True, crop_size, voc_dir, colormap2label, max_num)\n",
|
||
|
"voc_test = VOCSegDataset(False, crop_size, voc_dir, colormap2label, max_num)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {
|
||
|
"collapsed": true
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"batch_size = 64\n",
|
||
|
"num_workers = 0 if sys.platform.startswith('win32') else 4\n",
|
||
|
"train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True,\n",
|
||
|
" drop_last=True, num_workers=num_workers)\n",
|
||
|
"test_iter = torch.utils.data.DataLoader(voc_test, batch_size, drop_last=True,\n",
|
||
|
" num_workers=num_workers)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"torch.float32 torch.Size([64, 3, 320, 480])\n",
|
||
|
"torch.uint8 torch.Size([64, 320, 480])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"for X, Y in train_iter:\n",
|
||
|
" print(X.dtype, X.shape)\n",
|
||
|
" print(Y.dtype, Y.shape)\n",
|
||
|
" break"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {
|
||
|
"collapsed": true
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.6.2"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|