From 57d6a3ee564e89cf8318b5d2f3b851888f21b86e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 8 Sep 2016 12:28:28 -0800 Subject: [PATCH] Add seeds to RunConfig for dynamic_rnn_estimator_test to eliminate flake. Change: 132593375 --- tensorflow/contrib/learn/BUILD | 2 +- .../estimators/dynamic_rnn_estimator_test.py | 28 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index c639359a5ff..84a016b65b1 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -497,7 +497,7 @@ py_test( py_test( name = "dynamic_rnn_estimator_test", - size = "small", + size = "medium", srcs = ["python/learn/estimators/dynamic_rnn_estimator_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index f79665cea6f..37e829cf12e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np import tensorflow as tf -from tensorflow.contrib import layers from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator @@ -264,11 +263,11 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase): np.testing.assert_almost_equal(flattened_logits, target_column_input_logits) np.testing.assert_equal(expected_predictions, predictions) - def testLearnLinearExtrapolation(self): - """Tests that `_MultiValueRNNEstimator` can learn a linear function.""" + def testLearnSineFunction(self): + """Tests that `_MultiValueRNNEstimator` can learn a sine function.""" batch_size = 8 sequence_length = 64 - train_steps = 100 + train_steps = 200 eval_steps = 20 cell_size = 4 learning_rate = 0.1 @@ -291,8 +290,9 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase): return input_fn + config = tf.contrib.learn.RunConfig(tf_random_seed=1234) sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_regressor( - num_units=cell_size, learning_rate=learning_rate) + num_units=cell_size, learning_rate=learning_rate, config=config) train_input_fn = get_sin_input_fn( batch_size, sequence_length, np.pi / 32, seed=1234) @@ -332,11 +332,15 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase): return {'inputs': inputs}, labels return input_fn + config = tf.contrib.learn.RunConfig(tf_random_seed=21212) sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier( - num_classes=2, num_units=cell_size, learning_rate=learning_rate) + num_classes=2, + num_units=cell_size, + learning_rate=learning_rate, + config=config) - train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=1234) - eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=4321) + train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321) + eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123) sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps) evaluation = sequence_estimator.evaluate( @@ -412,12 +416,14 @@ class SingleValueRNNEstimatorTest(tf.test.TestCase): return {'inputs': inputs}, labels return input_fn + config = tf.contrib.learn.RunConfig(tf_random_seed=6) sequence_regressor = dynamic_rnn_estimator.single_value_rnn_regressor( num_units=cell_size, cell_type=cell_type, optimizer_type=optimizer_type, learning_rate=learning_rate, - momentum=momentum) + momentum=momentum, + config=config) train_input_fn = get_mean_input_fn(batch_size, sequence_length, 121) eval_input_fn = get_mean_input_fn(batch_size, sequence_length, 212) @@ -456,13 +462,15 @@ class SingleValueRNNEstimatorTest(tf.test.TestCase): return {'inputs': inputs}, labels return input_fn + config = tf.contrib.learn.RunConfig(tf_random_seed=77) sequence_classifier = dynamic_rnn_estimator.single_value_rnn_classifier( num_classes=2, num_units=cell_size, cell_type=cell_type, optimizer_type=optimizer_type, learning_rate=learning_rate, - momentum=momentum) + momentum=momentum, + config=config) train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111) eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)