#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import logging
import os
import re
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# Probably remove
#import matplotlib.pyplot as plt

# Constants
MAX_SENTENCE_LENGTH = 40
BATCH_SIZE = 64
BUFFER_SIZE = 20000

# Hyper-parameters
# The notebook points out that num_layers, d_model, and units have been reduced; see
# https://arxiv.org/abs/1706.03762 for more information.
#NUM_LAYERS = 2
NUM_LAYERS = 4
#D_MODEL = 256
D_MODEL = 512
NUM_HEADS = 8
#UNITS = 512
UNITS = 1024
DROPOUT = 0.1

SAMPLE_QUESTIONS = [
    "Hi.",
    "What is your name?",
    "My name is Fred.",
    "What's your name?",
    "How ya doin?",
    "What do you do for a living?",
    "Would you like to have sex with me?",
    "Which sexy cartoon character do you like best?",
    "Anything else you want to ask?",
    "Are you human?",
    "Do you like me?",
    "What did you do today?",
    "Do you like pizza?",
    "Are you a man or a woman?"
]

# Globals, set elsewhere
START_TOKEN = None
END_TOKEN = None
VOCAB_SIZE = None
EPOCHS = None

# random but predictable results:
#tf.random.set_seed(4242)

logger = logging.getLogger()

def initialize_dataset(questions, answers):
    # decoder inputs use the previous target as input
    # remove START_TOKEN from targets
    dataset = tf.data.Dataset.from_tensor_slices((
        {
            'inputs': questions,
            'dec_inputs': answers[:, :-1]
        },
        {
            'outputs': answers[:, 1:]
        },
    ))
    dataset = dataset.cache()
    dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
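
# Teacher forcing, illustrated with made-up token ids:
#   answer     = [START, 7, 12, END, 0, 0]
#   dec_inputs = [START, 7, 12, END, 0]        # answers[:, :-1]
#   outputs    = [7, 12, END, 0, 0]            # answers[:, 1:]
# so at every position the decoder is trained to predict the next target token.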

def preprocess_sentence(sentence):
    sentence = sentence.lower().strip()
    # creating a space between a word and the punctuation following it
    # eg: "he is a boy." => "he is a boy ."
    sentence = re.sub(r"([?.!,])", r" \1 ", sentence)
    sentence = re.sub(r'[" "]+', " ", sentence)
    # replacing everything with space except (a-z, A-Z, ".", "?", "!", ",")
    sentence = re.sub(r"[^a-zA-Z?.!,]+", " ", sentence)
    sentence = sentence.strip()
    # start and end tokens are added later, in tokenize_and_filter()
    return sentence
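
# A quick worked example of the preprocessing above:
#   preprocess_sentence("What's up, Doc?!")  ->  "what s up , doc ? !"
# (the apostrophe becomes a space because only a-z, "?", ".", "!" and "," survive)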

# Tokenize, filter and pad sentences
def tokenize_and_filter(tokenizer, inputs, outputs):
    tokenized_inputs, tokenized_outputs = [], []
    for (sentence1, sentence2) in zip(inputs, outputs):
        # tokenize sentence
        sentence1 = START_TOKEN + tokenizer.encode(sentence1) + END_TOKEN
        sentence2 = START_TOKEN + tokenizer.encode(sentence2) + END_TOKEN
        # check tokenized sentence max length
        if len(sentence1) <= MAX_SENTENCE_LENGTH and len(sentence2) <= MAX_SENTENCE_LENGTH:
            tokenized_inputs.append(sentence1)
            tokenized_outputs.append(sentence2)
    # pad tokenized sentences
    tokenized_inputs = tf.keras.preprocessing.sequence.pad_sequences(
        tokenized_inputs, maxlen=MAX_SENTENCE_LENGTH, padding='post')
    tokenized_outputs = tf.keras.preprocessing.sequence.pad_sequences(
        tokenized_outputs, maxlen=MAX_SENTENCE_LENGTH, padding='post')
    return tokenized_inputs, tokenized_outputs
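
# After padding, both returned arrays have shape (num_pairs, MAX_SENTENCE_LENGTH),
# with 0 used as the pad id; pairs where either side exceeds the limit are dropped.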

def load_file(filename):
    ''' Returns an array of questions and answers '''
    questions, answers = [], []
    count = 0
    with open(filename, 'r') as file:
        while True:
            q, a = file.readline(), file.readline()
            if not q or not a:
                logger.info(f'{count} question and answer pairs read.')
                return questions, answers
            count += 1
            questions.append(preprocess_sentence(q))
            answers.append(preprocess_sentence(a))
    # Unreachable

def tokenizer_init(questions, answers):
    global START_TOKEN
    global END_TOKEN
    global VOCAB_SIZE
    # Build tokenizer using tfds for both questions and answers
    # Not great that this is deprecated. TODO: Update to use tensorflow_text
    try:
        tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(
            questions + answers, target_vocab_size=2**13)
    except AttributeError:
        tokenizer = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
            questions + answers, target_vocab_size=2**13)
    # Define start and end token to indicate the start and end of a sentence
    START_TOKEN, END_TOKEN = [tokenizer.vocab_size], [tokenizer.vocab_size + 1]
    # Vocabulary size plus start and end token
    VOCAB_SIZE = tokenizer.vocab_size + 2
    logger.debug(f'Vocab size: {VOCAB_SIZE}')
    logger.debug(f'Tokenized sample question: {tokenizer.encode(questions[20])}')
    return tokenizer

def scaled_dot_product_attention(query, key, value, mask):
    """Calculate the attention weights."""
    matmul_qk = tf.matmul(query, key, transpose_b=True)
    # scale matmul_qk
    depth = tf.cast(tf.shape(key)[-1], tf.float32)
    logits = matmul_qk / tf.math.sqrt(depth)
    # add the mask to zero out padding tokens
    if mask is not None:
        logits += (mask * -1e9)
    # softmax is normalized on the last axis (seq_len_k)
    attention_weights = tf.nn.softmax(logits, axis=-1)
    output = tf.matmul(attention_weights, value)
    return output
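
# The function above computes Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V from
# "Attention Is All You Need"; masked positions get -1e9 added to their logits, so
# their softmax weights are effectively zero.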

class MultiHeadAttention(tf.keras.layers.Layer):

    def __init__(self, d_model, num_heads, name="multi_head_attention"):
        super(MultiHeadAttention, self).__init__(name=name)
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.query_dense = tf.keras.layers.Dense(units=d_model)
        self.key_dense = tf.keras.layers.Dense(units=d_model)
        self.value_dense = tf.keras.layers.Dense(units=d_model)
        self.dense = tf.keras.layers.Dense(units=d_model)

    def split_heads(self, inputs, batch_size):
        inputs = tf.reshape(
            inputs, shape=(batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(inputs, perm=[0, 2, 1, 3])

    def call(self, inputs):
        query, key, value, mask = inputs['query'], inputs['key'], inputs[
            'value'], inputs['mask']
        batch_size = tf.shape(query)[0]
        # linear layers
        query = self.query_dense(query)
        key = self.key_dense(key)
        value = self.value_dense(value)
        # split heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)
        # scaled dot-product attention
        scaled_attention = scaled_dot_product_attention(query, key, value, mask)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        # concatenation of heads
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))
        # final linear layer
        outputs = self.dense(concat_attention)
        return outputs
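
# Shape flow through MultiHeadAttention for a (batch, seq_len, d_model) input:
#   linear layers -> (batch, seq_len, d_model)
#   split_heads   -> (batch, num_heads, seq_len, depth), depth = d_model // num_heads
#   attention     -> (batch, num_heads, seq_len, depth)
#   transpose + reshape -> (batch, seq_len, d_model) before the final dense layer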

def create_padding_mask(x):
    mask = tf.cast(tf.math.equal(x, 0), tf.float32)
    # (batch_size, 1, 1, sequence length)
    return mask[:, tf.newaxis, tf.newaxis, :]


def create_look_ahead_mask(x):
    seq_len = tf.shape(x)[1]
    look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
    padding_mask = create_padding_mask(x)
    return tf.maximum(look_ahead_mask, padding_mask)
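
# For example, for tf.constant([[1, 2, 0, 4, 5]]) (0 is the pad id), row i of the
# combined mask allows positions <= i except position 2, which stays masked as padding;
# main() logs this exact example at debug level.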

class PositionalEncoding(tf.keras.layers.Layer):

    def __init__(self, position, d_model):
        super(PositionalEncoding, self).__init__()
        self.pos_encoding = self.positional_encoding(position, d_model)

    def get_angles(self, position, i, d_model):
        angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return position * angles

    def positional_encoding(self, position, d_model):
        angle_rads = self.get_angles(
            position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
            i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
            d_model=d_model)
        # apply sin to even indices in the array
        sines = tf.math.sin(angle_rads[:, 0::2])
        # apply cos to odd indices in the array
        cosines = tf.math.cos(angle_rads[:, 1::2])
        pos_encoding = tf.concat([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[tf.newaxis, ...]
        return tf.cast(pos_encoding, tf.float32)

    def call(self, inputs):
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]
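
# The paper defines PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and
# PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)). Like the notebook this file follows,
# the implementation above concatenates all sine channels followed by all cosine
# channels rather than interleaving them; each position still gets a unique encoding,
# just with the channels permuted.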

def encoder_layer(units, d_model, num_heads, dropout, name="encoder_layer"):
    inputs = tf.keras.Input(shape=(None, d_model), name="inputs")
    padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")
    attention = MultiHeadAttention(
        d_model, num_heads, name="attention")({
            'query': inputs,
            'key': inputs,
            'value': inputs,
            'mask': padding_mask
        })
    attention = tf.keras.layers.Dropout(rate=dropout)(attention)
    attention = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(inputs + attention)
    outputs = tf.keras.layers.Dense(units=units, activation='relu')(attention)
    outputs = tf.keras.layers.Dense(units=d_model)(outputs)
    outputs = tf.keras.layers.Dropout(rate=dropout)(outputs)
    outputs = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(attention + outputs)
    return tf.keras.Model(
        inputs=[inputs, padding_mask], outputs=outputs, name=name)

def encoder(vocab_size,
            num_layers,
            units,
            d_model,
            num_heads,
            dropout,
            name="encoder"):
    inputs = tf.keras.Input(shape=(None,), name="inputs")
    padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")
    embeddings = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)
    embeddings *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    embeddings = PositionalEncoding(vocab_size, d_model)(embeddings)
    outputs = tf.keras.layers.Dropout(rate=dropout)(embeddings)
    for i in range(num_layers):
        outputs = encoder_layer(
            units=units,
            d_model=d_model,
            num_heads=num_heads,
            dropout=dropout,
            name="encoder_layer_{}".format(i),
        )([outputs, padding_mask])
    return tf.keras.Model(
        inputs=[inputs, padding_mask], outputs=outputs, name=name)

def decoder_layer(units, d_model, num_heads, dropout, name="decoder_layer"):
    inputs = tf.keras.Input(shape=(None, d_model), name="inputs")
    enc_outputs = tf.keras.Input(shape=(None, d_model), name="encoder_outputs")
    look_ahead_mask = tf.keras.Input(
        shape=(1, None, None), name="look_ahead_mask")
    padding_mask = tf.keras.Input(shape=(1, 1, None), name='padding_mask')
    attention1 = MultiHeadAttention(
        d_model, num_heads, name="attention_1")(inputs={
            'query': inputs,
            'key': inputs,
            'value': inputs,
            'mask': look_ahead_mask
        })
    attention1 = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(attention1 + inputs)
    attention2 = MultiHeadAttention(
        d_model, num_heads, name="attention_2")(inputs={
            'query': attention1,
            'key': enc_outputs,
            'value': enc_outputs,
            'mask': padding_mask
        })
    attention2 = tf.keras.layers.Dropout(rate=dropout)(attention2)
    attention2 = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(attention2 + attention1)
    outputs = tf.keras.layers.Dense(units=units, activation='relu')(attention2)
    outputs = tf.keras.layers.Dense(units=d_model)(outputs)
    outputs = tf.keras.layers.Dropout(rate=dropout)(outputs)
    outputs = tf.keras.layers.LayerNormalization(
        epsilon=1e-6)(outputs + attention2)
    return tf.keras.Model(
        inputs=[inputs, enc_outputs, look_ahead_mask, padding_mask],
        outputs=outputs,
        name=name)

def decoder(vocab_size,
            num_layers,
            units,
            d_model,
            num_heads,
            dropout,
            name='decoder'):
    inputs = tf.keras.Input(shape=(None,), name='inputs')
    enc_outputs = tf.keras.Input(shape=(None, d_model), name='encoder_outputs')
    look_ahead_mask = tf.keras.Input(
        shape=(1, None, None), name='look_ahead_mask')
    padding_mask = tf.keras.Input(shape=(1, 1, None), name='padding_mask')
    embeddings = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)
    embeddings *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    embeddings = PositionalEncoding(vocab_size, d_model)(embeddings)
    outputs = tf.keras.layers.Dropout(rate=dropout)(embeddings)
    for i in range(num_layers):
        outputs = decoder_layer(
            units=units,
            d_model=d_model,
            num_heads=num_heads,
            dropout=dropout,
            name='decoder_layer_{}'.format(i),
        )(inputs=[outputs, enc_outputs, look_ahead_mask, padding_mask])
    return tf.keras.Model(
        inputs=[inputs, enc_outputs, look_ahead_mask, padding_mask],
        outputs=outputs,
        name=name)

def transformer(vocab_size,
                num_layers,
                units,
                d_model,
                num_heads,
                dropout,
                name="transformer"):
    inputs = tf.keras.Input(shape=(None,), name="inputs")
    dec_inputs = tf.keras.Input(shape=(None,), name="dec_inputs")
    enc_padding_mask = tf.keras.layers.Lambda(
        create_padding_mask, output_shape=(1, 1, None),
        name='enc_padding_mask')(inputs)
    # mask the future tokens for decoder inputs at the 1st attention block
    look_ahead_mask = tf.keras.layers.Lambda(
        create_look_ahead_mask,
        output_shape=(1, None, None),
        name='look_ahead_mask')(dec_inputs)
    # mask the encoder outputs for the 2nd attention block
    dec_padding_mask = tf.keras.layers.Lambda(
        create_padding_mask, output_shape=(1, 1, None),
        name='dec_padding_mask')(inputs)
    enc_outputs = encoder(
        vocab_size=vocab_size,
        num_layers=num_layers,
        units=units,
        d_model=d_model,
        num_heads=num_heads,
        dropout=dropout,
    )(inputs=[inputs, enc_padding_mask])
    dec_outputs = decoder(
        vocab_size=vocab_size,
        num_layers=num_layers,
        units=units,
        d_model=d_model,
        num_heads=num_heads,
        dropout=dropout,
    )(inputs=[dec_inputs, enc_outputs, look_ahead_mask, dec_padding_mask])
    outputs = tf.keras.layers.Dense(units=vocab_size, name="outputs")(dec_outputs)
    return tf.keras.Model(inputs=[inputs, dec_inputs], outputs=outputs, name=name)
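
# A minimal usage sketch (shapes are for the padded training data built above):
#   model = transformer(vocab_size=VOCAB_SIZE, num_layers=NUM_LAYERS, units=UNITS,
#                       d_model=D_MODEL, num_heads=NUM_HEADS, dropout=DROPOUT)
#   logits = model([inputs, dec_inputs])   # (batch, dec_len, VOCAB_SIZE) logits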

def train_model():
    # Despite the name, this only builds a fresh, untrained transformer;
    # the actual training happens via model.fit() in main().
    tf.keras.backend.clear_session()
    return transformer(
        vocab_size=VOCAB_SIZE,
        num_layers=NUM_LAYERS,
        units=UNITS,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        dropout=DROPOUT)

def loss_function(y_true, y_pred):
    y_true = tf.reshape(y_true, shape=(-1, MAX_SENTENCE_LENGTH - 1))
    loss = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')(y_true, y_pred)
    mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
    loss = tf.multiply(loss, mask)
    return tf.reduce_mean(loss)
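
# The cross-entropy is computed per token, multiplied by a 0/1 mask so that padded
# positions (pad id 0) contribute nothing, and then reduced with tf.reduce_mean.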

# A custom learning rate
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps**-1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        # TODO: This is related to saving, and may not be correct at all.
        config = {
            'd_model': self.d_model,
            'warmup_steps': self.warmup_steps,
        }
        return config
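
# This is the schedule from the paper:
#   lrate = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
# i.e. a linear warm-up over the first warmup_steps steps, followed by decay
# proportional to the inverse square root of the step number.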

def accuracy(y_true, y_pred):
    # ensure labels have shape (batch_size, MAX_SENTENCE_LENGTH - 1)
    y_true = tf.reshape(y_true, shape=(-1, MAX_SENTENCE_LENGTH - 1))
    return tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)

def evaluate(tokenizer, model, sentence):
    sentence = preprocess_sentence(sentence)
    sentence = tf.expand_dims(
        START_TOKEN + tokenizer.encode(sentence) + END_TOKEN, axis=0)
    output = tf.expand_dims(START_TOKEN, 0)
    for i in range(MAX_SENTENCE_LENGTH):
        predictions = model(inputs=[sentence, output], training=False)
        # select the last word from the seq_len dimension
        predictions = predictions[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
        # stop if the predicted_id is equal to the end token
        if tf.equal(predicted_id, END_TOKEN[0]):
            break
        # concatenate the predicted_id to the output, which is given to the decoder
        # as its next input
        output = tf.concat([output, predicted_id], axis=-1)
    return tf.squeeze(output, axis=0)
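
# evaluate() performs greedy autoregressive decoding: starting from START_TOKEN it feeds
# the sequence generated so far back into the decoder, appends the argmax token, and
# stops at END_TOKEN or after MAX_SENTENCE_LENGTH steps.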

def predict(tokenizer, model, sentence):
    prediction = evaluate(tokenizer, model, sentence)
    predicted_sentence = tokenizer.decode(
        [i for i in prediction if i < tokenizer.vocab_size])
    #logger.debug('Input: {}'.format(sentence))
    #logger.debug('Output: {}'.format(predicted_sentence))
    return predicted_sentence

def load_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('source', help='Source File')
    parser.add_argument('output', help='File to save the model')
    parser.add_argument('--debug', help='Extra debugging messages', action='store_true')
    parser.add_argument('--epochs', help='Number of epochs to train', type=int, default=20)
    parser.add_argument('--checkpointdir', help='Directory for checkpoint state', default='checkpoints')
    parser.add_argument('--resume', help='Resume and keep training', action='store_true')
    parser.add_argument('--loadcheckpoint', help='Load checkpoint but skip training', action='store_true')
    args = parser.parse_args()
    global logger
    if args.debug:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)
    global EPOCHS
    EPOCHS = args.epochs
    return args
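
# Example invocation (paths are illustrative; `source` must contain alternating
# question and answer lines, as read by load_file):
#   python3 train.py data/dialogue_pairs.txt chatbot --epochs 40 --debug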

def main():
    global logger
    handler = logging.StreamHandler()
    logger.addHandler(handler)
    args = load_args()
    logger.debug(f'Loading file {args.source}')
    questions, answers = load_file(args.source)
    logger.debug('Sample question: {}'.format(questions[25]))
    logger.debug('Sample answer: {}'.format(answers[25]))
    tokenizer = tokenizer_init(questions, answers)
    questions, answers = tokenize_and_filter(tokenizer, questions, answers)
    dataset = initialize_dataset(questions, answers)
    logger.debug(create_look_ahead_mask(tf.constant([[1, 2, 0, 4, 5]])))
    model = train_model()
    learning_rate = CustomSchedule(D_MODEL)
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    model.compile(optimizer=optimizer, loss=loss_function, metrics=[accuracy])
    # Create a callback that saves the model's weights
    # TODO: Add a resume learning step. See https://www.tensorflow.org/tutorials/keras/save_and_load
    # for information on how to resume from a checkpoint
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(args.checkpointdir, f'{args.output}.ckpt'),
        save_weights_only=True,
        verbose=1)
    if args.resume:
        logger.debug('Loading previously saved checkpoint')
        model.load_weights(os.path.join(args.checkpointdir, f'{args.output}.ckpt')).expect_partial()
    if args.loadcheckpoint:
        logger.debug('Loading previously saved checkpoint but expecting partial')
        model.load_weights(os.path.join(args.checkpointdir, f'{args.output}.ckpt')).expect_partial()
    if not args.loadcheckpoint:
        logger.debug('Untrained model:')
        model.summary(print_fn=logger.debug)
        model.fit(dataset, epochs=EPOCHS, callbacks=[cp_callback])
        logger.debug('Trained model:')
        model.summary(print_fn=logger.debug)
    # save the model - This doesn't work yet
    #model.save(args.output)
    for sentence in SAMPLE_QUESTIONS:
        print(f'Q: {sentence}')
        print(f'A: {predict(tokenizer, model, sentence)}')


if __name__ == "__main__":
    main()