{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 10.6 求近义词和类比词\n", "## 10.6.1 使用预训练的词向量" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0.0\n" ] }, { "data": { "text/plain": [ "dict_keys(['charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d', 'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d', 'glove.twitter.27B.50d', 'glove.twitter.27B.100d', 'glove.twitter.27B.200d', 'glove.6B.50d', 'glove.6B.100d', 'glove.6B.200d', 'glove.6B.300d'])" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "import os\n",
  "\n",
  "import torch\n",
  "import torchtext.vocab as vocab\n",
  "\n",
  "print(torch.__version__)\n",
  "vocab.pretrained_aliases.keys()"
 ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['glove.42B.300d',\n", " 'glove.840B.300d',\n", " 'glove.twitter.27B.25d',\n", " 'glove.twitter.27B.50d',\n", " 'glove.twitter.27B.100d',\n", " 'glove.twitter.27B.200d',\n", " 'glove.6B.50d',\n", " 'glove.6B.100d',\n", " 'glove.6B.200d',\n", " 'glove.6B.300d']" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[key for key in vocab.pretrained_aliases.keys()\n", " if \"glove\" in key]" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [
  "# Directory passed to torchtext as the GloVe cache. Override it with the\n",
  "# GLOVE_CACHE_DIR environment variable instead of hard-coding a\n",
  "# machine-specific absolute path.\n",
  "cache_dir = os.environ.get(\n",
  "    \"GLOVE_CACHE_DIR\", os.path.join(os.path.expanduser(\"~\"), \"Datasets\", \"glove\"))\n",
  "# glove = vocab.pretrained_aliases[\"glove.6B.50d\"](cache=cache_dir)\n",
  "glove = vocab.GloVe(name='6B', dim=50, cache=cache_dir)  # equivalent to the commented-out line above"
 ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "一共包含400000个词。\n" ] } ], "source": [ "print(\"一共包含%d个词。\" % len(glove.stoi))" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3366, 'beautiful')" ] }, "execution_count": 
5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "glove.stoi['beautiful'], glove.itos[3366]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 10.6.2 应用预训练词向量\n", "### 10.6.2.1 求近义词" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [
  "def knn(W, x, k):\n",
  "    \"\"\"Return the k rows of W most cosine-similar to the query vector x.\n",
  "\n",
  "    W is a 2-D tensor with one vector per row; x is flattened with view(-1),\n",
  "    so it must hold W.shape[1] elements. Returns (indices of the top-k rows\n",
  "    as a NumPy array, list of the corresponding cosine similarities).\n",
  "    \"\"\"\n",
  "    x = x.view((-1,))\n",
  "    # 1e-9 is added to both squared norms for numerical stability: a zero row\n",
  "    # in W or an all-zero query then yields cos = 0 instead of 0/0 = NaN.\n",
  "    w_norm = (torch.sum(W * W, dim=1) + 1e-9).sqrt()\n",
  "    x_norm = (torch.sum(x * x) + 1e-9).sqrt()\n",
  "    cos = torch.matmul(W, x) / (w_norm * x_norm)\n",
  "    _, topk = torch.topk(cos, k=k)\n",
  "    topk = topk.cpu().numpy()\n",
  "    return topk, [cos[i].item() for i in topk]"
 ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [
  "def get_similar_tokens(query_token, k, embed):\n",
  "    \"\"\"Print the k tokens whose vectors are most cosine-similar to query_token.\n",
  "\n",
  "    The query token is removed by its index rather than by assuming it ranks\n",
  "    first, which may not hold for an all-zero vector or when another token\n",
  "    has an identical vector.\n",
  "    \"\"\"\n",
  "    query_idx = embed.stoi[query_token]\n",
  "    topk, cos = knn(embed.vectors, embed.vectors[query_idx], k + 1)\n",
  "    neighbors = [(i, c) for i, c in zip(topk, cos) if i != query_idx][:k]\n",
  "    for i, c in neighbors:\n",
  "        print('cosine sim=%.3f: %s' % (c, (embed.itos[i])))"
 ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cosine sim=0.856: chips\n", "cosine sim=0.749: intel\n", "cosine sim=0.749: electronics\n" ] } ], "source": [ "get_similar_tokens('chip', 3, glove)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cosine sim=0.839: babies\n", "cosine sim=0.800: boy\n", "cosine sim=0.792: girl\n" ] } ], "source": [ "get_similar_tokens('baby', 3, glove)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cosine sim=0.921: lovely\n", "cosine sim=0.893: gorgeous\n", "cosine sim=0.830: wonderful\n" ] } ], "source": [ "get_similar_tokens('beautiful', 3, glove)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 10.6.2.2 求类比词" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], 
"source": [
  "def get_analogy(token_a, token_b, token_c, embed):\n",
  "    \"\"\"Answer the analogy token_a : token_b :: token_c : ? (3CosAdd).\n",
  "\n",
  "    Returns the vocabulary token whose vector is most cosine-similar to\n",
  "    vec(token_b) - vec(token_a) + vec(token_c). The three query tokens are\n",
  "    excluded from the candidates; otherwise one of them can be returned as\n",
  "    the nearest neighbour of the target vector.\n",
  "    \"\"\"\n",
  "    query_ids = [embed.stoi[t] for t in [token_a, token_b, token_c]]\n",
  "    vecs = [embed.vectors[i] for i in query_ids]\n",
  "    x = vecs[1] - vecs[0] + vecs[2]\n",
  "    excluded = set(query_ids)\n",
  "    # Fetch one more candidate than the number of excluded ids so that at\n",
  "    # least one non-query token is guaranteed to be among them.\n",
  "    topk, cos = knn(embed.vectors, x, len(excluded) + 1)\n",
  "    return next(embed.itos[i] for i in topk.tolist() if i not in excluded)"
 ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'daughter'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_analogy('man', 'woman', 'son', glove)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'japan'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_analogy('beijing', 'china', 'tokyo', glove)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'biggest'" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_analogy('bad', 'worst', 'big', glove)" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'went'" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_analogy('do', 'did', 'go', glove)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }