Set graph's random seed when training

This commit is contained in:
Reuben Morais 2018-12-11 21:49:19 -02:00 committed by GitHub
parent da135ca3f9
commit 8ebfe80dd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 0 deletions

View File

@ -890,6 +890,7 @@ def main(_):
if len(FLAGS.worker_hosts) == 0: if len(FLAGS.worker_hosts) == 0:
# Only one local task: this process (default case - no cluster) # Only one local task: this process (default case - no cluster)
with tf.Graph().as_default(): with tf.Graph().as_default():
tf.set_random_seed(FLAGS.random_seed)
train() train()
# Now do a final test epoch # Now do a final test epoch
if FLAGS.test: if FLAGS.test: