You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
354 lines
7.2 KiB
354 lines
7.2 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 10.6 求近义词和类比词\n",
|
|
"## 10.6.1 使用预训练的词向量"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1.0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"dict_keys(['charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d', 'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d', 'glove.twitter.27B.50d', 'glove.twitter.27B.100d', 'glove.twitter.27B.200d', 'glove.6B.50d', 'glove.6B.100d', 'glove.6B.200d', 'glove.6B.300d'])"
|
|
]
|
|
},
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"import torch\n",
"import torchtext.vocab as vocab\n",
"\n",
"# PyTorch version this notebook was run with.\n",
"print(torch.__version__)\n",
"# Keys of torchtext's table of pretrained-embedding aliases.\n",
"vocab.pretrained_aliases.keys()"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['glove.42B.300d',\n",
|
|
" 'glove.840B.300d',\n",
|
|
" 'glove.twitter.27B.25d',\n",
|
|
" 'glove.twitter.27B.50d',\n",
|
|
" 'glove.twitter.27B.100d',\n",
|
|
" 'glove.twitter.27B.200d',\n",
|
|
" 'glove.6B.50d',\n",
|
|
" 'glove.6B.100d',\n",
|
|
" 'glove.6B.200d',\n",
|
|
" 'glove.6B.300d']"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Keep only the GloVe variants, in the alias table's own order.\n",
"[name for name in vocab.pretrained_aliases if \"glove\" in name]"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
"import os\n",
"\n",
"# Cache directory passed to torchtext for the GloVe files. Set GLOVE_CACHE_DIR\n",
"# to override; the default is relative to the current user's home directory\n",
"# instead of a hard-coded, machine-specific absolute path.\n",
"cache_dir = os.environ.get(\n",
"    \"GLOVE_CACHE_DIR\", os.path.join(os.path.expanduser(\"~\"), \"Datasets\", \"glove\"))\n",
"# Equivalent alias-based form:\n",
"# glove = vocab.pretrained_aliases[\"glove.6B.50d\"](cache=cache_dir)\n",
"glove = vocab.GloVe(name='6B', dim=50, cache=cache_dir)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"一共包含400000个词。\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# Vocabulary size: number of entries in the token -> index mapping.\n",
"print(\"一共包含%d个词。\" % (len(glove.stoi),))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(3366, 'beautiful')"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Round-trip a token through the vocabulary: stoi maps token -> index and\n",
"# itos maps index -> token. Look the index up rather than hard-coding it,\n",
"# since the index is specific to this vocabulary.\n",
"beautiful_idx = glove.stoi['beautiful']\n",
"beautiful_idx, glove.itos[beautiful_idx]"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 10.6.2 应用预训练词向量\n",
|
|
"### 10.6.2.1 求近义词"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
"def knn(W, x, k):\n",
"    \"\"\"Return the k rows of W most cosine-similar to the query vector x.\n",
"\n",
"    W: 2-D tensor whose rows are candidate vectors (e.g. embed.vectors).\n",
"    x: query vector; it is flattened to 1-D before use.\n",
"    k: number of neighbours to return.\n",
"    Returns (indices, sims): a numpy array of row indices into W ordered from\n",
"    most to least similar, and a list of the matching cosine similarities.\n",
"    \"\"\"\n",
"    x = x.view((-1,))\n",
"    # The 1e-9 keeps both norms non-zero, so neither an all-zero row of W nor\n",
"    # an all-zero query (e.g. a degenerate analogy vector) divides by zero.\n",
"    w_norms = (torch.sum(W * W, dim=1) + 1e-9).sqrt()\n",
"    x_norm = (torch.sum(x * x) + 1e-9).sqrt()\n",
"    cos = torch.matmul(W, x) / (w_norms * x_norm)\n",
"    sims, topk = torch.topk(cos, k=k)\n",
"    return topk.cpu().numpy(), sims.tolist()"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
"def get_similar_tokens(query_token, k, embed):\n",
"    \"\"\"Print the k tokens whose vectors are most cosine-similar to query_token.\n",
"\n",
"    query_token: token to look up; must be present in embed.stoi.\n",
"    k: number of neighbours to print; the query token itself is excluded.\n",
"    embed: embedding object exposing .vectors, .stoi and .itos (e.g. GloVe).\n",
"    \"\"\"\n",
"    query_idx = embed.stoi[query_token]\n",
"    # Request one extra neighbour because the query is normally its own best\n",
"    # match. Drop it by index rather than assuming it ranks first: another\n",
"    # token with an identical vector can tie with it and be ordered ahead.\n",
"    topk, cos = knn(embed.vectors, embed.vectors[query_idx], k + 1)\n",
"    neighbours = [(i, c) for i, c in zip(topk, cos) if i != query_idx][:k]\n",
"    for i, c in neighbours:\n",
"        print('cosine sim=%.3f: %s' % (c, embed.itos[i]))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cosine sim=0.856: chips\n",
|
|
"cosine sim=0.749: intel\n",
|
|
"cosine sim=0.749: electronics\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"get_similar_tokens(query_token='chip', k=3, embed=glove)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cosine sim=0.839: babies\n",
|
|
"cosine sim=0.800: boy\n",
|
|
"cosine sim=0.792: girl\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"get_similar_tokens(query_token='baby', k=3, embed=glove)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cosine sim=0.921: lovely\n",
|
|
"cosine sim=0.893: gorgeous\n",
|
|
"cosine sim=0.830: wonderful\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"get_similar_tokens(query_token='beautiful', k=3, embed=glove)"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 10.6.2.2 求类比词"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
"def get_analogy(token_a, token_b, token_c, embed):\n",
"    \"\"\"Answer the analogy token_a : token_b :: token_c : ? by vector arithmetic.\n",
"\n",
"    Returns the token whose vector is most cosine-similar to\n",
"    vec(token_b) - vec(token_a) + vec(token_c), never one of the three input\n",
"    tokens (they can lie close to that vector and would otherwise be returned).\n",
"    All three tokens must be present in embed.stoi.\n",
"    \"\"\"\n",
"    idxs = [embed.stoi[t] for t in (token_a, token_b, token_c)]\n",
"    vec_a, vec_b, vec_c = (embed.vectors[i] for i in idxs)\n",
"    x = vec_b - vec_a + vec_c\n",
"    # Fetch one candidate more than the number of inputs so that at least one\n",
"    # non-input token is guaranteed to remain after filtering.\n",
"    topk, _ = knn(embed.vectors, x, len(idxs) + 1)\n",
"    excluded = set(idxs)\n",
"    return next(embed.itos[i] for i in topk if int(i) not in excluded)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'daughter'"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"get_analogy(token_a='man', token_b='woman', token_c='son', embed=glove)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'japan'"
|
|
]
|
|
},
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"get_analogy(token_a='beijing', token_b='china', token_c='tokyo', embed=glove)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'biggest'"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"get_analogy(token_a='bad', token_b='worst', token_c='big', embed=glove)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'went'"
|
|
]
|
|
},
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"get_analogy(token_a='do', token_b='did', token_c='go', embed=glove)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python [default]",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|