Add seeds to RunConfig for dynamic_rnn_estimator_test to eliminate flake.

Change: 132593375
This commit is contained in:
A. Unique TensorFlower 2016-09-08 12:28:28 -08:00 committed by TensorFlower Gardener
parent 24728ede8c
commit 57d6a3ee56
2 changed files with 19 additions and 11 deletions

View File

@ -497,7 +497,7 @@ py_test(
py_test(
name = "dynamic_rnn_estimator_test",
size = "small",
size = "medium",
srcs = ["python/learn/estimators/dynamic_rnn_estimator_test.py"],
srcs_version = "PY2AND3",
deps = [

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator
@ -264,11 +263,11 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
np.testing.assert_almost_equal(flattened_logits, target_column_input_logits)
np.testing.assert_equal(expected_predictions, predictions)
def testLearnLinearExtrapolation(self):
"""Tests that `_MultiValueRNNEstimator` can learn a linear function."""
def testLearnSineFunction(self):
"""Tests that `_MultiValueRNNEstimator` can learn a sine function."""
batch_size = 8
sequence_length = 64
train_steps = 100
train_steps = 200
eval_steps = 20
cell_size = 4
learning_rate = 0.1
@ -291,8 +290,9 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
return input_fn
config = tf.contrib.learn.RunConfig(tf_random_seed=1234)
sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_regressor(
num_units=cell_size, learning_rate=learning_rate)
num_units=cell_size, learning_rate=learning_rate, config=config)
train_input_fn = get_sin_input_fn(
batch_size, sequence_length, np.pi / 32, seed=1234)
@ -332,11 +332,15 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
return {'inputs': inputs}, labels
return input_fn
config = tf.contrib.learn.RunConfig(tf_random_seed=21212)
sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
num_classes=2, num_units=cell_size, learning_rate=learning_rate)
num_classes=2,
num_units=cell_size,
learning_rate=learning_rate,
config=config)
train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=1234)
eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=4321)
train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)
sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
evaluation = sequence_estimator.evaluate(
@ -412,12 +416,14 @@ class SingleValueRNNEstimatorTest(tf.test.TestCase):
return {'inputs': inputs}, labels
return input_fn
config = tf.contrib.learn.RunConfig(tf_random_seed=6)
sequence_regressor = dynamic_rnn_estimator.single_value_rnn_regressor(
num_units=cell_size,
cell_type=cell_type,
optimizer_type=optimizer_type,
learning_rate=learning_rate,
momentum=momentum)
momentum=momentum,
config=config)
train_input_fn = get_mean_input_fn(batch_size, sequence_length, 121)
eval_input_fn = get_mean_input_fn(batch_size, sequence_length, 212)
@ -456,13 +462,15 @@ class SingleValueRNNEstimatorTest(tf.test.TestCase):
return {'inputs': inputs}, labels
return input_fn
config = tf.contrib.learn.RunConfig(tf_random_seed=77)
sequence_classifier = dynamic_rnn_estimator.single_value_rnn_classifier(
num_classes=2,
num_units=cell_size,
cell_type=cell_type,
optimizer_type=optimizer_type,
learning_rate=learning_rate,
momentum=momentum)
momentum=momentum,
config=config)
train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)