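"""Train and evaluate sequence-to-sequence semantic parser models.

In 'train' mode the script splits the training data into folds, trains on all but the
last fold with early stopping on the held-out fold, and saves the resulting model and
config to FLAGS.train_dir as 'final_model' / 'final_model.conf'. In 'test' mode it
restores that saved model and config and reports loss and accuracy on the test split.
The model_type argument selects between the standard parser model and the pointer
variant.
"""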
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import os, sys, time, random, argparse, pickle
import config, data_utils, parser_model, pointer_parser_model
parser = argparse.ArgumentParser()
parser.add_argument('-lr', '--learning_rate', type=float, default=0.5,
                    help='Learning rate.')
parser.add_argument('-lrdf', '--learning_rate_decay_factor', type=float, default=0.9,
                    help='Learning rate decays by this much.')
parser.add_argument('-mgn', '--max_gradient_norm', type=float, default=5.0,
                    help='Clip gradients to this norm.')
parser.add_argument('-b', '--batch_size', type=int, default=20,
                    help='Batch size to use during training.')
parser.add_argument('-nl', '--num_layers', type=int, default=1,
                    help='Number of layers in the model.')
parser.add_argument('-dd', '--data_dir', default="data/Geo",
                    help='Data directory.')
parser.add_argument('-td', '--train_dir', default="./tmp",
                    help='Training directory.')
parser.add_argument('-mtds', '--max_train_data_size', type=int, default=0,
                    help='Limit on the size of training data (0: no limit).')
parser.add_argument('-esp', '--early_stopping_patience', type=int, default=5,
                    help='How many epochs to wait before early stopping is enforced.')
parser.add_argument('-nf', '--num_folds', type=int, default=10,
                    help='Number of folds for cross-validation.')
parser.add_argument('-ls', '--layer_size', type=int, default=None,
                    help='Layer size (overrides the value chosen by the config search).')
parser.add_argument('mode', choices=['train', 'test'],
                    help='Way to run the app.')
parser.add_argument('model_type', choices=['standard', 'pointer'],
                    help='Parser model to use.')
FLAGS = parser.parse_args()
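
# Illustrative invocations (a sketch only; the flag values below simply echo the
# argparse defaults and are not prescribed anywhere in the original source):
#   python run_parser.py train standard -dd data/Geo -td ./tmp
#   python run_parser.py test pointer -td ./tmp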
def load_data():
    """Load the raw data, compute max source/target lengths, and split the training data into folds."""
    train_data, test_data, vocab_size = data_utils.load_raw_text(FLAGS.data_dir)
    source_max_len, target_max_len = 0, 0
    for entry in train_data + test_data:
        for sent in entry:
            source_max_len = max(source_max_len, len(sent[0]))
            target_max_len = max(target_max_len, len(sent[1]))
    folds = []
    fold_size = int(len(train_data) / FLAGS.num_folds)
    for i in range(FLAGS.num_folds - 1):
        folds.append(train_data[i*fold_size:(i+1)*fold_size])
    folds.append(train_data[(FLAGS.num_folds-1)*fold_size:])
    return folds, test_data, vocab_size, source_max_len, target_max_len
def create_model(session, conf, train_data):
    """Create model and initialize or load parameters in session."""
    if FLAGS.model_type == "standard":
        print("BUILDING STANDARD MODEL")
        model = parser_model.MultiSentParseModel(conf, train_data)
    else:
        print("BUILDING POINTER MODEL")
        model = pointer_parser_model.MultiSentPointerParseModel(conf, train_data)
    print("Created model with fresh parameters.")
    session.run(tf.global_variables_initializer())
    return model
def train(sess, train_data, validation_data, conf, num_epochs=None):
    print("Preparing model...")
    model = create_model(sess, conf, train_data)
    checkpoint_dir = os.path.join(FLAGS.train_dir, conf.get_dir())
    checkpoint_path = os.path.join(checkpoint_dir, "parse.ckpt")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # The training loop.
    step_time, loss = 0.0, 0.0
    current_step = 0
    previous_losses = []
    best_validation_loss = float("inf")
    best_validation_acc = 0.0
    best_validation_sentence_acc = 0.0
    best_validation_epoch = 0
    print("Starting training")
    sys.stdout.flush()
    perfect_count = 0
    epoch_count = 0
    while not num_epochs or epoch_count < num_epochs:
        # Get a batch and make a step.
        entries, step_loss, step_outputs = model.step(sess, False)
        loss += step_loss
        current_step += 1
        # At the end of each epoch, save a checkpoint, print statistics, and run evals.
        if model.complete_epoch:
            epoch_count += 1
            model.complete_epoch = False
            if not num_epochs:
                # Check the early stopping condition on the held-out validation data.
                # print("LAST BATCH:")
                temp_loss, temp_total_acc, temp_sentence_acc = test(sess, entries, model,
                                                                    conf.source_max_len)
                # print("VALIDATION:")
                validation_loss, validation_total_acc, validation_sentence_acc = test(
                    sess, validation_data, model, conf.source_max_len)
                if validation_sentence_acc == 1.0:
                    perfect_count += 1
                else:
                    perfect_count = 0
                print("Epoch %s learning rate %.4f training loss %.4f validation loss %.4f "
                      "validation total acc %.4f validation sent acc %.4f" %
                      (epoch_count, model.learning_rate.eval(),
                       loss, validation_loss, validation_total_acc, validation_sentence_acc))
                if (validation_sentence_acc > best_validation_sentence_acc or
                        (validation_sentence_acc == best_validation_sentence_acc and
                         validation_loss < best_validation_loss)):
                    best_validation_loss = validation_loss
                    best_validation_epoch = epoch_count
                    best_validation_acc = validation_total_acc
                    best_validation_sentence_acc = validation_sentence_acc
                    model.saver.save(sess, checkpoint_path, global_step=model.global_step)
                if (epoch_count - best_validation_epoch >= FLAGS.early_stopping_patience
                        or perfect_count == 5):
                    print("Early stopping triggered. Restoring previous model")
                    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    return model, best_validation_epoch
            else:
                print("\tEpoch %d of %d" % (epoch_count, num_epochs))
            # Decrease the learning rate if no improvement was seen over the last 3 epochs.
            if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                sess.run(model.learning_rate_decay_op)
            previous_losses.append(loss)
            if len(previous_losses) > 3:
                del previous_losses[0]
            # Zero the timer and accumulated loss for the next epoch.
            step_time, loss = 0.0, 0.0
            sys.stdout.flush()
    return model, epoch_count
def test(sess, test_data, model, source_max_len, dump_results=False):
    _, loss, output_logits = model.step(sess, True, test_data)
    total_acc, sentence_acc = evaluate_logits(output_logits, test_data,
                                              source_max_len, dump_results)
    return loss, total_acc, sentence_acc
def evaluate_logits(output_logits, test_data, source_max_len, dump_results=False):
    """Convert decoder logits to token ids and score them against the reference logical forms."""
    total_outputs = []
    for sent_ind in range(len(output_logits)):
        temp_outputs = [[int(np.argmax(logit)) for logit in output_logit]
                        for output_logit in output_logits[sent_ind]]
        # Transpose so that each row holds one output sequence.
        outputs = np.array(temp_outputs).T.tolist()
        for i in range(len(outputs)):
            if outputs[i][0] == data_utils.PAD_ID:
                outputs[i] = None
            elif data_utils.LOGIC_EOS_ID in outputs[i]:
                outputs[i] = outputs[i][:outputs[i].index(data_utils.LOGIC_EOS_ID)]
                for ind, val in enumerate(outputs[i]):
                    if val >= len(data_utils.id_to_logic):
                        outputs[i][ind] = (len(data_utils.id_to_logic) + source_max_len
                                           - (outputs[i][ind] - len(data_utils.id_to_logic)) - 1)
        total_outputs.append(outputs)
    total_outputs = list(zip(*total_outputs))
    # print("CORRECT OUTPUTS:")
    # print(data_utils.ids_to_logics(test_data[0][1][1:-1]))
    # print("GIVEN OUTPUTS")
    # print(data_utils.ids_to_logics(outputs[0]))
    if dump_results:
        print("==============TEST FAILURES====================")
    total_correct = 0.0
    sentence_correct = 0.0
    num_sentences = 0.0
    num_entries = 0.0
    for entry_ind in range(len(test_data)):
        # print("ENTRY: %d" % entry_ind)
        num_entries += 1.0
        all_correct = True
        for sent_ind in range(len(test_data[entry_ind])):
            num_sentences += 1.0
            if test_data[entry_ind][sent_ind][1][1:-1] != total_outputs[entry_ind][sent_ind]:  # TODO: make sure this is correct
                all_correct = False
            else:
                sentence_correct += 1.0
        if all_correct:
            total_correct += 1.0
        elif dump_results:
            for sent_ind in range(len(test_data[entry_ind])):
                print(' '.join(data_utils.ids_to_words(test_data[entry_ind][sent_ind][0])))
                print("\tCorrect: " + ''.join(data_utils.ids_to_logics(test_data[entry_ind][sent_ind][1][1:-1],
                                                                       test_data[entry_ind][sent_ind][0],
                                                                       source_max_len, False)))
                print("\tFound: " + ''.join(data_utils.ids_to_logics(total_outputs[entry_ind][sent_ind],
                                                                     test_data[entry_ind][sent_ind][0],
                                                                     source_max_len, False)))
                print("")
    return total_correct / num_entries, sentence_correct / num_sentences
def cross_validate(splits, conf):
    performance = 0
    for i in range(len(splits)):
        print("===================Beginning split %d========================" % i)
        conf.fold = i
        train_data = sum(splits[:i] + splits[i+1:], [])
        validation_data = splits[i]
        with tf.Session() as sess:
            model, _ = train(sess, train_data, validation_data, conf)
            loss, total_acc, sentence_acc = test(sess, validation_data, model,
                                                 conf.source_max_len)
            performance += loss
        tf.reset_default_graph()
    return performance / len(splits)
def parameter_tuning(folds, source_vocab_size, target_vocab_size, source_max_len, target_max_len):
    best_loss = None
    best_config = None
    for conf in config.config_beam_search(source_vocab_size, target_vocab_size, FLAGS.num_layers,
                                          FLAGS.batch_size, FLAGS.learning_rate,
                                          FLAGS.learning_rate_decay_factor, source_max_len,
                                          target_max_len, data_utils.words_to_id, data_utils.logic_to_id,
                                          data_utils.id_to_words, data_utils.id_to_logic):
        print("+++++++++++++++++++++++Beginning cross-validation with dropout_rate = %0.1f, vector_size=%d++++++++++++++++++" %
              (conf.dropout_rate, conf.layer_size))
        loss = cross_validate(folds, conf)
        if not best_loss or loss < best_loss:
            best_loss = loss
            best_config = conf
    best_config.fold = None
    print("Best config:")
    print("\tdropout: %.1f, param size: %d" % (best_config.dropout_rate, best_config.layer_size))
    return best_config
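
# parameter_tuning is only referenced from a commented-out line in main_train below,
# where the call is missing the two max-length arguments. A sketch of a call that
# matches the signature above (illustrative, not taken from the original source):
#   best_conf = parameter_tuning(folds, source_vocab_size, target_vocab_size,
#                                source_max_len, target_max_len)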
def main_train():
    folds, test_data, (source_vocab_size, target_vocab_size), source_max_len, target_max_len = load_data()
    train_data = sum(folds[:-1], [])
    validation_data = folds[-1]
    # conf = parameter_tuning(folds, source_vocab_size, target_vocab_size)
    conf = list(config.config_beam_search(source_vocab_size, target_vocab_size, FLAGS.num_layers,
                                          FLAGS.batch_size, FLAGS.learning_rate,
                                          FLAGS.learning_rate_decay_factor, source_max_len,
                                          target_max_len, data_utils.words_to_id, data_utils.logic_to_id,
                                          data_utils.id_to_words, data_utils.id_to_logic))[0]
    if FLAGS.layer_size is not None:
        conf.layer_size = FLAGS.layer_size
    # First, train with held-out data to find the number of iterations.
    with tf.Session() as sess:
        model, num_steps = train(sess, train_data, validation_data, conf)
        loss, total_acc, sentence_acc = test(sess, train_data, model, source_max_len)
        print("INTERMEDIATE RESULTS:")
        print("    loss = %0.4f" % loss)
        print("    total_acc = %0.4f" % total_acc)
        print("    sentence_acc = %0.4f" % sentence_acc)
        # Now train on the full data set
        # tf.reset_default_graph()
        # train_data += validation_data
        # with tf.Session() as sess:
        #     model, _ = train(sess, train_data, None, conf, num_steps)
        model_path = os.path.join(FLAGS.train_dir, 'final_model')
        model.saver.save(sess, model_path)
        # Pickle the config alongside the model so that test mode can rebuild it.
        conf_out = open(os.path.join(FLAGS.train_dir, 'final_model.conf'), 'wb')
        pickle.dump(conf, conf_out)
        conf_out.close()
        loss, total_acc, sentence_acc = test(sess, test_data, model, source_max_len)
        print("FINAL RESULTS:")
        print("    loss = %0.4f" % loss)
        print("    total_acc = %0.4f" % total_acc)
        print("    sentence_acc = %0.4f" % sentence_acc)
def main_test():
    train_data, test_data, (source_vocab_size, target_vocab_size), source_max_len, _ = load_data()
    test_conf_path = os.path.join(FLAGS.train_dir, 'final_model.conf')
    conf_in = open(test_conf_path, 'rb')
    conf = pickle.load(conf_in)
    conf_in.close()
    with tf.Session() as sess:
        model = create_model(sess, conf, None)
        model.saver.restore(sess, os.path.join(FLAGS.train_dir, 'final_model'))
        loss, total_acc, sentence_acc = test(sess, test_data, model, source_max_len, True)
        print("FINAL RESULTS:")
        print("    loss = %0.4f" % loss)
        print("    total_acc = %0.4f" % total_acc)
        print("    sentence_acc = %0.4f" % sentence_acc)
def main(_):
    if FLAGS.mode == "train":
        main_train()
    else:
        main_test()


if __name__ == "__main__":
    tf.app.run()