-
Notifications
You must be signed in to change notification settings - Fork 6
/
vocabulary.py
99 lines (83 loc) · 3.41 KB
/
vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
A Vocabulary maintains a mapping between words and corresponding unique
integers, holds special integers (tokens) for indicating start and end of
sequence, and offers functionality to map out-of-vocabulary words to the
corresponding token.
"""
import json
import os
from typing import List
class Vocabulary(object):
    """
    A simple Vocabulary class which maintains a mapping between words and
    integer tokens. Can be initialized either by word counts from the VisDial
    v1.0 train dataset, or a pre-saved vocabulary mapping (see ``from_saved``).

    Indices 0-3 are reserved for the padding, start-of-sequence,
    end-of-sequence and unknown-word tokens; regular words are numbered from
    4 onwards in descending order of their count.

    Parameters
    ----------
    word_counts_path: str
        Path to a json file containing counts of each word across captions,
        questions and answers of the VisDial v1.0 train dataset.
    min_count : int, optional (default=5)
        When initializing the vocabulary from word counts, you can specify a
        minimum count, and every token with a count less than this will be
        excluded from vocabulary.

    Raises
    ------
    FileNotFoundError
        If ``word_counts_path`` does not exist.
    """

    PAD_TOKEN = "<PAD>"
    SOS_TOKEN = "<S>"
    EOS_TOKEN = "</S>"
    UNK_TOKEN = "<UNK>"

    PAD_INDEX = 0
    SOS_INDEX = 1
    EOS_INDEX = 2
    UNK_INDEX = 3

    def __init__(self, word_counts_path: str, min_count: int = 5):
        if not os.path.exists(word_counts_path):
            raise FileNotFoundError(
                f"Word counts do not exist at {word_counts_path}"
            )

        with open(word_counts_path, "r") as word_counts_file:
            word_counts = json.load(word_counts_file)

        # form a list of (word, count) tuples and apply min_count threshold
        word_counts = [
            (word, count)
            for word, count in word_counts.items()
            if count >= min_count
        ]
        # sort in descending order of word counts; the sort is stable, so
        # words with equal counts keep the order they had in the json file
        word_counts = sorted(word_counts, key=lambda wc: -wc[1])
        words = [w[0] for w in word_counts]

        # special tokens always occupy the first four indices
        self.word2index = {
            self.PAD_TOKEN: self.PAD_INDEX,
            self.SOS_TOKEN: self.SOS_INDEX,
            self.EOS_TOKEN: self.EOS_INDEX,
            self.UNK_TOKEN: self.UNK_INDEX,
        }
        # regular words follow, starting at index 4
        for index, word in enumerate(words):
            self.word2index[word] = index + 4

        self.index2word = {
            index: word for word, index in self.word2index.items()
        }

    @classmethod
    def from_saved(cls, saved_vocabulary_path: str) -> "Vocabulary":
        """Build the vocabulary from a json file saved by ``save`` method.

        Parameters
        ----------
        saved_vocabulary_path : str
            Path to a json file containing word to integer mappings
            (saved vocabulary).

        Returns
        -------
        Vocabulary
            A new instance whose mappings are read from the saved file.
        """
        with open(saved_vocabulary_path, "r") as saved_vocabulary_file:
            word2index = json.load(saved_vocabulary_file)

        # Bypass __init__ (it requires a word counts file) and populate the
        # mappings on a fresh instance, so that class-level state shared by
        # every Vocabulary is never modified.
        vocabulary = cls.__new__(cls)
        vocabulary.word2index = word2index
        vocabulary.index2word = {
            index: word for word, index in word2index.items()
        }
        return vocabulary

    def to_indices(self, words: List[str]) -> List[int]:
        """Map words to their indices; unknown words map to ``UNK_INDEX``."""
        return [self.word2index.get(word, self.UNK_INDEX) for word in words]

    def to_words(self, indices: List[int]) -> List[str]:
        """Map indices to their words; unknown indices map to ``UNK_TOKEN``."""
        return [
            self.index2word.get(index, self.UNK_TOKEN) for index in indices
        ]

    def save(self, save_vocabulary_path: str) -> None:
        """Write the word to index mapping as json to ``save_vocabulary_path``."""
        with open(save_vocabulary_path, "w") as save_vocabulary_file:
            json.dump(self.word2index, save_vocabulary_file)

    def __len__(self) -> int:
        """Return the number of entries in the index to word mapping."""
        return len(self.index2word)