Add seeds to RunConfig for dynamic_rnn_estimator_test to eliminate flake.
Change: 132593375
This commit is contained in:
parent
24728ede8c
commit
57d6a3ee56
@ -497,7 +497,7 @@ py_test(
|
|||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "dynamic_rnn_estimator_test",
|
name = "dynamic_rnn_estimator_test",
|
||||||
size = "small",
|
size = "medium",
|
||||||
srcs = ["python/learn/estimators/dynamic_rnn_estimator_test.py"],
|
srcs = ["python/learn/estimators/dynamic_rnn_estimator_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.contrib import layers
|
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator
|
from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator
|
||||||
|
|
||||||
|
|
||||||
@ -264,11 +263,11 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
|
|||||||
np.testing.assert_almost_equal(flattened_logits, target_column_input_logits)
|
np.testing.assert_almost_equal(flattened_logits, target_column_input_logits)
|
||||||
np.testing.assert_equal(expected_predictions, predictions)
|
np.testing.assert_equal(expected_predictions, predictions)
|
||||||
|
|
||||||
def testLearnLinearExtrapolation(self):
|
def testLearnSineFunction(self):
|
||||||
"""Tests that `_MultiValueRNNEstimator` can learn a linear function."""
|
"""Tests that `_MultiValueRNNEstimator` can learn a sine function."""
|
||||||
batch_size = 8
|
batch_size = 8
|
||||||
sequence_length = 64
|
sequence_length = 64
|
||||||
train_steps = 100
|
train_steps = 200
|
||||||
eval_steps = 20
|
eval_steps = 20
|
||||||
cell_size = 4
|
cell_size = 4
|
||||||
learning_rate = 0.1
|
learning_rate = 0.1
|
||||||
@ -291,8 +290,9 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
|
|||||||
|
|
||||||
return input_fn
|
return input_fn
|
||||||
|
|
||||||
|
config = tf.contrib.learn.RunConfig(tf_random_seed=1234)
|
||||||
sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_regressor(
|
sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_regressor(
|
||||||
num_units=cell_size, learning_rate=learning_rate)
|
num_units=cell_size, learning_rate=learning_rate, config=config)
|
||||||
|
|
||||||
train_input_fn = get_sin_input_fn(
|
train_input_fn = get_sin_input_fn(
|
||||||
batch_size, sequence_length, np.pi / 32, seed=1234)
|
batch_size, sequence_length, np.pi / 32, seed=1234)
|
||||||
@ -332,11 +332,15 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
|
|||||||
return {'inputs': inputs}, labels
|
return {'inputs': inputs}, labels
|
||||||
return input_fn
|
return input_fn
|
||||||
|
|
||||||
|
config = tf.contrib.learn.RunConfig(tf_random_seed=21212)
|
||||||
sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
|
sequence_estimator = dynamic_rnn_estimator.multi_value_rnn_classifier(
|
||||||
num_classes=2, num_units=cell_size, learning_rate=learning_rate)
|
num_classes=2,
|
||||||
|
num_units=cell_size,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
config=config)
|
||||||
|
|
||||||
train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=1234)
|
train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
|
||||||
eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=4321)
|
eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)
|
||||||
|
|
||||||
sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
|
sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
|
||||||
evaluation = sequence_estimator.evaluate(
|
evaluation = sequence_estimator.evaluate(
|
||||||
@ -412,12 +416,14 @@ class SingleValueRNNEstimatorTest(tf.test.TestCase):
|
|||||||
return {'inputs': inputs}, labels
|
return {'inputs': inputs}, labels
|
||||||
return input_fn
|
return input_fn
|
||||||
|
|
||||||
|
config = tf.contrib.learn.RunConfig(tf_random_seed=6)
|
||||||
sequence_regressor = dynamic_rnn_estimator.single_value_rnn_regressor(
|
sequence_regressor = dynamic_rnn_estimator.single_value_rnn_regressor(
|
||||||
num_units=cell_size,
|
num_units=cell_size,
|
||||||
cell_type=cell_type,
|
cell_type=cell_type,
|
||||||
optimizer_type=optimizer_type,
|
optimizer_type=optimizer_type,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
momentum=momentum)
|
momentum=momentum,
|
||||||
|
config=config)
|
||||||
|
|
||||||
train_input_fn = get_mean_input_fn(batch_size, sequence_length, 121)
|
train_input_fn = get_mean_input_fn(batch_size, sequence_length, 121)
|
||||||
eval_input_fn = get_mean_input_fn(batch_size, sequence_length, 212)
|
eval_input_fn = get_mean_input_fn(batch_size, sequence_length, 212)
|
||||||
@ -456,13 +462,15 @@ class SingleValueRNNEstimatorTest(tf.test.TestCase):
|
|||||||
return {'inputs': inputs}, labels
|
return {'inputs': inputs}, labels
|
||||||
return input_fn
|
return input_fn
|
||||||
|
|
||||||
|
config = tf.contrib.learn.RunConfig(tf_random_seed=77)
|
||||||
sequence_classifier = dynamic_rnn_estimator.single_value_rnn_classifier(
|
sequence_classifier = dynamic_rnn_estimator.single_value_rnn_classifier(
|
||||||
num_classes=2,
|
num_classes=2,
|
||||||
num_units=cell_size,
|
num_units=cell_size,
|
||||||
cell_type=cell_type,
|
cell_type=cell_type,
|
||||||
optimizer_type=optimizer_type,
|
optimizer_type=optimizer_type,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
momentum=momentum)
|
momentum=momentum,
|
||||||
|
config=config)
|
||||||
|
|
||||||
train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
|
train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
|
||||||
eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)
|
eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)
|
||||||
|
Loading…
Reference in New Issue
Block a user