-
Notifications
You must be signed in to change notification settings - Fork 8
/
test_load_glove.lua
47 lines (34 loc) · 1.54 KB
/
test_load_glove.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
require 'nn'
require 'misc.DataLoader'
local utils = require 'misc.utils'
local net_utils = require 'misc.net_utils'
-- Smoke test for loading pre-trained GloVe vectors into an embedding matrix.
--
-- Results are published through globals so they can be inspected afterwards
-- (e.g. from an interactive session):
--   loader             : the DataLoader built from the cocotalk files
--   test_glove_table   : the table returned by net_utils.load_glove
--   test_glove_weights : a (vocab_size + 1) x glove_dim tensor, one row per
--                        vocabulary index
loader = DataLoader{h5_file = '../coco_data/cocotalk.h5', json_file = '../coco_data/cocotalk.json'}
local use_glove = true
-- dimension of the GloVe vectors; must match the glove file loaded below
local glove_dim = 300
test_glove_weights = nil
test_glove_table = nil
if use_glove then
  -- vocab is: ix_to_word
  -- ex: ix_to_word[tostring(1)] = 'woods'
  local ix_to_word = loader:getVocab()
  local vocab_size = loader:getVocabSize()

  -- The matrix has vocab_size + 1 rows; per the original author's note the
  -- extra row (index vocab_size + 1) belongs to the '<END>' token, which is
  -- not listed in ix_to_word. torch.Tensor(...) alone leaves its memory
  -- uninitialized, so every row is filled with uniform(-1, 1) noise up front:
  -- the extra row and any word without a GloVe vector keep this random
  -- initialization instead of arbitrary memory contents.
  local glove_weights = torch.Tensor(vocab_size + 1, glove_dim):uniform(-1, 1)

  -- glove_table['word'] = vector
  local glove_table = net_utils.load_glove('..//glove_word2vec/glove.6B.300d.txt', glove_dim)
  test_glove_table = glove_table

  for ix, word in pairs(ix_to_word) do
    local vector
    if word == 'UNK' then
      -- Our vocabulary spells the unknown token 'UNK'; look it up under
      -- 'unk' first and fall back to '<unk>' in case the glove file uses
      -- that spelling instead.
      vector = glove_table['unk'] or glove_table['<unk>']
    else
      vector = glove_table[word]
    end

    if vector == nil then
      -- no pre-trained vector: the row keeps its random initialization
      print(word .. ' not exists ' .. 'in glove files')
    else
      -- keys of ix_to_word are strings, tensor rows are indexed by number
      glove_weights[tonumber(ix)] = vector
    end
  end

  test_glove_weights = glove_weights
end