{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 10.3 word2vec的实现" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0.0\n" ] } ], "source": [ "import collections\n", "import math\n", "import random\n", "import sys\n", "import time\n", "import os\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "import torch.utils.data as Data\n", "\n", "sys.path.append(\"..\") \n", "import d2lzh_pytorch as d2l\n", "print(torch.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10.3.1 处理数据集" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "assert 'ptb.train.txt' in os.listdir(\"../../data/ptb\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# sentences: 42068'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with open('../../data/ptb/ptb.train.txt', 'r') as f:\n", " lines = f.readlines()\n", " # st是sentence的缩写\n", " raw_dataset = [st.split() for st in lines]\n", "\n", "'# sentences: %d' % len(raw_dataset)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']\n", "# tokens: 15 ['pierre', '', 'N', 'years', 'old']\n", "# tokens: 11 ['mr.', '', 'is', 'chairman', 'of']\n" ] } ], "source": [ "for st in raw_dataset[:3]:\n", " print('# tokens:', len(st), st[:5])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3.1.1 建立词语索引" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# tk是token的缩写\n", "counter = collections.Counter([tk for st in raw_dataset for tk in st])\n", "counter = dict(filter(lambda x: x[1] >= 5, counter.items()))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# tokens: 887100'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idx_to_token = [tk for tk, _ in counter.items()]\n", "token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}\n", "dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]\n", " for st in raw_dataset]\n", "num_tokens = sum([len(st) for st in dataset])\n", "'# tokens: %d' % num_tokens" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3.1.2 二次采样" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# tokens: 375647'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def discard(idx):\n", " return random.uniform(0, 1) < 1 - math.sqrt(\n", " 1e-4 / counter[idx_to_token[idx]] * num_tokens)\n", "\n", "subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]\n", "'# tokens: %d' % sum([len(st) for st in subsampled_dataset])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# the: before=50770, after=2043'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def compare_counts(token):\n", " return '# %s: before=%d, after=%d' % (token, sum(\n", " [st.count(token_to_idx[token]) for st in dataset]), sum(\n", " [st.count(token_to_idx[token]) for st in subsampled_dataset]))\n", "\n", 
"compare_counts('the')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'# join: before=45, after=45'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compare_counts('join')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3.1.3 提取中心词和背景词" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def get_centers_and_contexts(dataset, max_window_size):\n", " centers, contexts = [], []\n", " for st in dataset:\n", " if len(st) < 2: # 每个句子至少要有2个词才可能组成一对“中心词-背景词”\n", " continue\n", " centers += st\n", " for center_i in range(len(st)):\n", " window_size = random.randint(1, max_window_size)\n", " indices = list(range(max(0, center_i - window_size),\n", " min(len(st), center_i + 1 + window_size)))\n", " indices.remove(center_i) # 将中心词排除在背景词之外\n", " contexts.append([st[idx] for idx in indices])\n", " return centers, contexts" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]\n", "center 0 has contexts [1, 2]\n", "center 1 has contexts [0, 2]\n", "center 2 has contexts [0, 1, 3, 4]\n", "center 3 has contexts [1, 2, 4, 5]\n", "center 4 has contexts [3, 5]\n", "center 5 has contexts [4, 6]\n", "center 6 has contexts [4, 5]\n", "center 7 has contexts [8, 9]\n", "center 8 has contexts [7, 9]\n", "center 9 has contexts [7, 8]\n" ] } ], "source": [ "tiny_dataset = [list(range(7)), list(range(7, 10))]\n", "print('dataset', tiny_dataset)\n", "for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):\n", " print('center', center, 'has contexts', context)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10.3.2 负采样" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def get_negatives(all_contexts, sampling_weights, K):\n", " all_negatives, neg_candidates, i = [], [], 0\n", " population = list(range(len(sampling_weights)))\n", " for contexts in all_contexts:\n", " negatives = []\n", " while len(negatives) < len(contexts) * K:\n", " if i == len(neg_candidates):\n", " # 根据每个词的权重(sampling_weights)随机生成k个词的索引作为噪声词。\n", " # 为了高效计算,可以将k设得稍大一点\n", " i, neg_candidates = 0, random.choices(\n", " population, sampling_weights, k=int(1e5))\n", " neg, i = neg_candidates[i], i + 1\n", " # 噪声词不能是背景词\n", " if neg not in set(contexts):\n", " negatives.append(neg)\n", " all_negatives.append(negatives)\n", " return all_negatives\n", "\n", "sampling_weights = [counter[w]**0.75 for w in idx_to_token]\n", "all_negatives = get_negatives(all_contexts, sampling_weights, 5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10.3.3 读取数据" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def batchify(data):\n", " \"\"\"用作DataLoader的参数collate_fn: 输入是个长为batchsize的list, list中的每个元素都是__getitem__得到的结果\"\"\"\n", " max_len = max(len(c) + len(n) for _, c, n in data)\n", " centers, contexts_negatives, masks, labels = [], [], [], []\n", " for center, context, negative in data:\n", " cur_len = len(context) + len(negative)\n", " centers += [center]\n", " contexts_negatives += [context + 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10.3.3 Reading the Data" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def batchify(data):\n", "    \"\"\"Used as the collate_fn argument of DataLoader: the input is a list of\n", "    length batch_size, where each element is the (center, context, negative)\n", "    triple returned by Dataset.__getitem__.\"\"\"\n", "    max_len = max(len(c) + len(n) for _, c, n in data)\n", "    centers, contexts_negatives, masks, labels = [], [], [], []\n", "    for center, context, negative in data:\n", "        cur_len = len(context) + len(negative)\n", "        centers += [center]\n", "        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]\n", "        masks += [[1] * cur_len + [0] * (max_len - cur_len)]\n", "        labels += [[1] * len(context) + [0] * (max_len - len(context))]\n", "    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),\n", "            torch.tensor(masks), torch.tensor(labels))" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "centers shape: torch.Size([512, 1])\n", "contexts_negatives shape: torch.Size([512, 60])\n", "masks shape: torch.Size([512, 60])\n", "labels shape: torch.Size([512, 60])\n" ] } ], "source": [ "class MyDataset(torch.utils.data.Dataset):\n", "    def __init__(self, centers, contexts, negatives):\n", "        assert len(centers) == len(contexts) == len(negatives)\n", "        self.centers = centers\n", "        self.contexts = contexts\n", "        self.negatives = negatives\n", "\n", "    def __getitem__(self, index):\n", "        return (self.centers[index], self.contexts[index], self.negatives[index])\n", "\n", "    def __len__(self):\n", "        return len(self.centers)\n", "\n", "batch_size = 512\n", "num_workers = 0 if sys.platform.startswith('win32') else 4\n", "\n", "dataset = MyDataset(all_centers,\n", "                    all_contexts,\n", "                    all_negatives)\n", "data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,\n", "                            collate_fn=batchify,\n", "                            num_workers=num_workers)\n", "for batch in data_iter:\n", "    for name, data in zip(['centers', 'contexts_negatives', 'masks',\n", "                           'labels'], batch):\n", "        print(name, 'shape:', data.shape)\n", "    break" ] },
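{ "cell_type": "markdown", "metadata": {}, "source": [ "To make the padding scheme concrete, here is `batchify` applied to two hand-made samples of different lengths (the toy triples below are our own, purely illustrative): the shorter sample is padded with 0 up to `max_len`, the mask zeroes out the padding, and the labels mark only the context-word positions." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# two toy (center, contexts, negatives) triples; max_len = 2 + 3 = 5\n", "toy = [(1, [2, 3], [4, 5, 6]), (7, [8], [9])]\n", "toy_centers, toy_cn, toy_masks, toy_labels = batchify(toy)\n", "print(toy_cn)      # second row is [8, 9, 0, 0, 0]: padded with 0\n", "print(toy_masks)   # second row is [1, 1, 0, 0, 0]: padding masked out\n", "print(toy_labels)  # 1 only at context-word positions" ] },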
"text/plain": [ "torch.Size([2, 1, 6])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.ones((2, 1, 4))\n", "Y = torch.ones((2, 4, 6))\n", "torch.bmm(X, Y).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3.4.3 跳字模型前向计算" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def skip_gram(center, contexts_and_negatives, embed_v, embed_u):\n", " v = embed_v(center)\n", " u = embed_u(contexts_and_negatives)\n", " pred = torch.bmm(v, u.permute(0, 2, 1))\n", " return pred" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10.3.5 训练模型\n", "### 10.3.5.1 二元交叉熵损失函数" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "class SigmoidBinaryCrossEntropyLoss(nn.Module):\n", " def __init__(self): # none mean sum\n", " super(SigmoidBinaryCrossEntropyLoss, self).__init__()\n", " def forward(self, inputs, targets, mask=None):\n", " \"\"\"\n", " input – Tensor shape: (batch_size, len)\n", " target – Tensor of the same shape as input\n", " \"\"\"\n", " inputs, targets, mask = inputs.float(), targets.float(), mask.float()\n", " res = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\", weight=mask)\n", " return res.mean(dim=1)\n", "\n", "loss = SigmoidBinaryCrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.8740, 1.2100])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])\n", "# 标签变量label中的1和0分别代表背景词和噪声词\n", "label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])\n", "mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]]) # 掩码变量\n", "loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8740\n", "1.2100\n" ] } ], "source": [ "def sigmd(x):\n", " return - math.log(1 / (1 + math.exp(-x)))\n", "\n", "print('%.4f' % ((sigmd(1.5) + sigmd(-0.3) + sigmd(1) + sigmd(-2)) / 4)) # 注意1-sigmoid(x) = sigmoid(-x)\n", "print('%.4f' % ((sigmd(1.1) + sigmd(-0.6) + sigmd(-2.2)) / 3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3.5.2 初始化模型参数" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true }, "outputs": [], "source": [ "embed_size = 100\n", "net = nn.Sequential(\n", " nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),\n", " nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3.5.3 定义训练函数" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def train(net, lr, num_epochs):\n", " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", " print(\"train on\", device)\n", " net = net.to(device)\n", " optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n", " for epoch in range(num_epochs):\n", " start, l_sum, n = time.time(), 0.0, 0\n", " for batch in data_iter:\n", " center, context_negative, mask, label = [d.to(device) for d in batch]\n", " \n", " pred = skip_gram(center, context_negative, net[0], net[1])\n", " \n", " # 使用掩码变量mask来避免填充项对损失函数计算的影响\n", " l = (loss(pred.view(label.shape), label, mask) *\n", " mask.shape[1] / 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10.3.5 Training the Model\n", "### 10.3.5.1 The Binary Cross-Entropy Loss Function" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "class SigmoidBinaryCrossEntropyLoss(nn.Module):\n", "    def __init__(self):  # the reduction over positions is applied manually in forward\n", "        super(SigmoidBinaryCrossEntropyLoss, self).__init__()\n", "\n", "    def forward(self, inputs, targets, mask=None):\n", "        \"\"\"\n", "        inputs – Tensor of shape (batch_size, len)\n", "        targets – Tensor of the same shape as inputs\n", "        \"\"\"\n", "        inputs, targets = inputs.float(), targets.float()\n", "        if mask is not None:  # the mask is optional\n", "            mask = mask.float()\n", "        res = nn.functional.binary_cross_entropy_with_logits(\n", "            inputs, targets, reduction=\"none\", weight=mask)\n", "        return res.mean(dim=1)\n", "\n", "loss = SigmoidBinaryCrossEntropyLoss()" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.8740, 1.2100])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])\n", "# in the label variable, 1 marks context words and 0 marks noise words\n", "label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])\n", "mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # mask variable\n", "loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8740\n", "1.2100\n" ] } ], "source": [ "def sigmd(x):\n", "    # returns -log(sigmoid(x)), the per-position binary cross-entropy term\n", "    return - math.log(1 / (1 + math.exp(-x)))\n", "\n", "print('%.4f' % ((sigmd(1.5) + sigmd(-0.3) + sigmd(1) + sigmd(-2)) / 4))  # note: 1 - sigmoid(x) = sigmoid(-x)\n", "print('%.4f' % ((sigmd(1.1) + sigmd(-0.6) + sigmd(-2.2)) / 3))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3.5.2 Initializing Model Parameters" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true }, "outputs": [], "source": [ "embed_size = 100\n", "net = nn.Sequential(\n", "    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),\n", "    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size)\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 10.3.5.3 Defining the Training Function" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def train(net, lr, num_epochs):\n", "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "    print(\"train on\", device)\n", "    net = net.to(device)\n", "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n", "    for epoch in range(num_epochs):\n", "        start, l_sum, n = time.time(), 0.0, 0\n", "        for batch in data_iter:\n", "            center, context_negative, mask, label = [d.to(device) for d in batch]\n", "\n", "            pred = skip_gram(center, context_negative, net[0], net[1])\n", "\n", "            # use the mask to keep padding from affecting the loss\n", "            l = (loss(pred.view(label.shape), label, mask) *\n", "                 mask.shape[1] / mask.float().sum(dim=1)).mean()  # average loss over the batch\n", "            optimizer.zero_grad()\n", "            l.backward()\n", "            optimizer.step()\n", "            l_sum += l.cpu().item()\n", "            n += 1\n", "        print('epoch %d, loss %.2f, time %.2fs'\n", "              % (epoch + 1, l_sum / n, time.time() - start))" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train on cpu\n", "epoch 1, loss 1.97, time 74.53s\n", "epoch 2, loss 0.62, time 81.85s\n", "epoch 3, loss 0.45, time 74.49s\n", "epoch 4, loss 0.39, time 72.04s\n", "epoch 5, loss 0.37, time 72.21s\n", "epoch 6, loss 0.35, time 71.81s\n", "epoch 7, loss 0.34, time 72.00s\n", "epoch 8, loss 0.33, time 74.45s\n", "epoch 9, loss 0.32, time 72.08s\n", "epoch 10, loss 0.32, time 72.05s\n" ] } ], "source": [ "train(net, 0.01, 10)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10.3.6 Applying the Word Embedding Model" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cosine sim=0.478: hard-disk\n", "cosine sim=0.446: intel\n", "cosine sim=0.440: drives\n" ] } ], "source": [ "def get_similar_tokens(query_token, k, embed):\n", "    W = embed.weight.data\n", "    x = W[token_to_idx[query_token]]\n", "    # add 1e-9 for numerical stability\n", "    cos = torch.matmul(W, x) / (torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9).sqrt()\n", "    _, topk = torch.topk(cos, k=k+1)\n", "    topk = topk.cpu().numpy()\n", "    for i in topk[1:]:  # skip the query word itself\n", "        print('cosine sim=%.3f: %s' % (cos[i], (idx_to_token[i])))\n", "\n", "get_similar_tokens('chip', 3, net[0])" ] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }