Selaa lähdekoodia

Minor updates

Fred Damstra [Titan] 3 vuotta sitten
vanhempi
sitoutus
d719288777
1 muutettua tiedostoa jossa 8 lisäystä ja 4 poistoa
  1. 8 4
      RNN-Chatbots/mybot/train.py

+ 8 - 4
RNN-Chatbots/mybot/train.py

@@ -44,7 +44,8 @@ SAMPLE_QUESTIONS = [
     "Are you human?",
     "Do you like me?",
     "What did you to today?",
-    "Do you like pizza?"
+    "Do you like pizza?",
+    "Are you a man or a woman?"
 ]
 
 # Globals, set elsewhere
@@ -54,7 +55,7 @@ VOCAB_SIZE = None
 EPOCHS = None
 
 # random but predictable results:
-tf.random.set_seed(4242)
+#tf.random.set_seed(4242)
 logger = logging.getLogger()
 
 
@@ -134,7 +135,10 @@ def tokenizer_init(questions, answers):
     global VOCAB_SIZE
     # Build tokenizer using tfds for both questions and answers
     # not great taht this is depecated. TODO: Update to use tensorflow_text
-    tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(questions + answers, target_vocab_size=2**13)
+    try:
+        tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(questions + answers, target_vocab_size=2**13)
+    except:
+        tokenizer = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(questions + answers, target_vocab_size=2**13)
     # Define start and end token to indicate the start and end of a sentence
     START_TOKEN, END_TOKEN = [tokenizer.vocab_size], [tokenizer.vocab_size + 1]
     # Vocabulary size plus start and end token
@@ -584,7 +588,7 @@ def main():
 
     if args.resume:
         logger.debug('Loading previously saved checkpoint')
-        model.load_weights(os.path.join(args.checkpointdir, f'{args.output}.ckpt'))
+        model.load_weights(os.path.join(args.checkpointdir, f'{args.output}.ckpt')).expect_partial()
 
     if args.loadcheckpoint:
         logger.debug('Loading previously saved checkpoint but expecting partial')