{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 10.8 文本情感分类:使用卷积神经网络(textCNN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:24:30.611583Z",
     "start_time": "2019-07-04T15:24:28.120724Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0.0 cuda\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch import nn\n",
    "import torchtext.vocab as Vocab\n",
    "import torch.utils.data as Data\n",
    "import  torch.nn.functional as F\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "DATA_ROOT = \"/S1/CSCL/tangss/Datasets\"\n",
    "print(torch.__version__, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.8.1 一维卷积层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:24:30.618608Z",
     "start_time": "2019-07-04T15:24:30.614302Z"
    }
   },
   "outputs": [],
   "source": [
    "def corr1d(X, K):\n",
    "    w = K.shape[0]\n",
    "    Y = torch.zeros((X.shape[0] - w + 1))\n",
    "    for i in range(Y.shape[0]):\n",
    "        Y[i] = (X[i: i + w] * K).sum()\n",
    "    return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:24:30.634912Z",
     "start_time": "2019-07-04T15:24:30.621140Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2.,  5.,  8., 11., 14., 17.])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, K = torch.tensor([0, 1, 2, 3, 4, 5, 6]), torch.tensor([1, 2])\n",
    "corr1d(X, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:24:30.645344Z",
     "start_time": "2019-07-04T15:24:30.637083Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2.,  8., 14., 20., 26., 32.])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def corr1d_multi_in(X, K):\n",
    "    # 首先沿着X和K的第0维(通道维)遍历并计算一维互相关结果。然后将所有结果堆叠起来沿第0维累加\n",
    "    return torch.stack([corr1d(x, k) for x, k in zip(X, K)]).sum(dim=0)\n",
    "\n",
    "X = torch.tensor([[0, 1, 2, 3, 4, 5, 6],\n",
    "              [1, 2, 3, 4, 5, 6, 7],\n",
    "              [2, 3, 4, 5, 6, 7, 8]])\n",
    "K = torch.tensor([[1, 2], [3, 4], [-1, -3]])\n",
    "corr1d_multi_in(X, K)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.8.2 时序最大池化层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:24:30.650834Z",
     "start_time": "2019-07-04T15:24:30.647333Z"
    }
   },
   "outputs": [],
   "source": [
    "class GlobalMaxPool1d(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GlobalMaxPool1d, self).__init__()\n",
    "    def forward(self, x):\n",
    "         # x shape: (batch_size, channel, seq_len)\n",
    "        return F.max_pool1d(x, kernel_size=x.shape[2]) # shape: (batch_size, channel, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.8.3 读取和预处理IMDb数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:24:58.666425Z",
     "start_time": "2019-07-04T15:24:30.652855Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 12500/12500 [00:02<00:00, 4376.39it/s]\n",
      "100%|██████████| 12500/12500 [00:02<00:00, 4834.11it/s]\n",
      "100%|██████████| 12500/12500 [00:02<00:00, 4556.64it/s]\n",
      "100%|██████████| 12500/12500 [00:11<00:00, 1076.09it/s]\n"
     ]
    }
   ],
   "source": [
    "batch_size = 64\n",
    "train_data = d2l.read_imdb('train', data_root=os.path.join(DATA_ROOT, \"aclImdb\"))\n",
    "test_data = d2l.read_imdb('test', data_root=os.path.join(DATA_ROOT, \"aclImdb\"))\n",
    "vocab = d2l.get_vocab_imdb(train_data)\n",
    "train_set = Data.TensorDataset(*d2l.preprocess_imdb(train_data, vocab))\n",
    "test_set = Data.TensorDataset(*d2l.preprocess_imdb(test_data, vocab))\n",
    "train_iter = Data.DataLoader(train_set, batch_size, shuffle=True)\n",
    "test_iter = Data.DataLoader(test_set, batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.8.4 textCNN模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:24:58.674283Z",
     "start_time": "2019-07-04T15:24:58.668832Z"
    }
   },
   "outputs": [],
   "source": [
    "class TextCNN(nn.Module):\n",
    "    def __init__(self, vocab, embed_size, kernel_sizes, num_channels):\n",
    "        super(TextCNN, self).__init__()\n",
    "        self.embedding = nn.Embedding(len(vocab), embed_size)\n",
    "        # 不参与训练的嵌入层\n",
    "        self.constant_embedding = nn.Embedding(len(vocab), embed_size)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "        self.decoder = nn.Linear(sum(num_channels), 2)\n",
    "        # 时序最大池化层没有权重,所以可以共用一个实例\n",
    "        self.pool = GlobalMaxPool1d()\n",
    "        self.convs = nn.ModuleList()  # 创建多个一维卷积层\n",
    "        for c, k in zip(num_channels, kernel_sizes):\n",
    "            self.convs.append(nn.Conv1d(in_channels = 2*embed_size, \n",
    "                                        out_channels = c, \n",
    "                                        kernel_size = k))\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # 将两个形状是(批量大小, 词数, 词向量维度)的嵌入层的输出按词向量连结\n",
    "        embeddings = torch.cat((\n",
    "            self.embedding(inputs), \n",
    "            self.constant_embedding(inputs)), dim=2) # (batch, seq_len, 2*embed_size)\n",
    "        # 根据Conv1D要求的输入格式,将词向量维,即一维卷积层的通道维(即词向量那一维),变换到前一维\n",
    "        embeddings = embeddings.permute(0, 2, 1)\n",
    "        # 对于每个一维卷积层,在时序最大池化后会得到一个形状为(批量大小, 通道大小, 1)的\n",
    "        # Tensor。使用flatten函数去掉最后一维,然后在通道维上连结\n",
    "        encoding = torch.cat([self.pool(F.relu(conv(embeddings))).squeeze(-1) for conv in self.convs], dim=1)\n",
    "        # 应用丢弃法后使用全连接层得到输出\n",
    "        outputs = self.decoder(self.dropout(encoding))\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:24:58.764854Z",
     "start_time": "2019-07-04T15:24:58.675824Z"
    }
   },
   "outputs": [],
   "source": [
    "embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]\n",
    "net = TextCNN(vocab, embed_size, kernel_sizes, nums_channels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.8.4.1 加载预训练的词向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:25:00.616142Z",
     "start_time": "2019-07-04T15:24:58.766569Z"
    }
   },
   "outputs": [],
   "source": [
    "glove_vocab = Vocab.GloVe(name='6B', dim=100, cache=os.path.join(DATA_ROOT, \"glove\"))\n",
    "net.embedding.weight.data.copy_(\n",
    "    d2l.load_pretrained_embedding(vocab.itos, glove_vocab))\n",
    "net.constant_embedding.weight.data.copy_(\n",
    "    d2l.load_pretrained_embedding(vocab.itos, glove_vocab))\n",
    "net.constant_embedding.weight.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.8.4.2 训练并评价模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:28:36.938512Z",
     "start_time": "2019-07-04T15:25:00.618194Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.4811, train acc 0.762, test acc 0.848, time 42.6 sec\n",
      "epoch 2, loss 0.1601, train acc 0.864, test acc 0.869, time 42.3 sec\n",
      "epoch 3, loss 0.0714, train acc 0.915, test acc 0.879, time 42.3 sec\n",
      "epoch 4, loss 0.0289, train acc 0.958, test acc 0.867, time 42.3 sec\n",
      "epoch 5, loss 0.0124, train acc 0.979, test acc 0.861, time 42.3 sec\n"
     ]
    }
   ],
   "source": [
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)\n",
    "loss = nn.CrossEntropyLoss()\n",
    "d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:28:36.945999Z",
     "start_time": "2019-07-04T15:28:36.940672Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'positive'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d2l.predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'great'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-04T15:28:36.954105Z",
     "start_time": "2019-07-04T15:28:36.947516Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'negative'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d2l.predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'bad'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36]",
   "language": "python",
   "name": "conda-env-py36-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}