Update seq2seq tests to be 2.0 compatible.

Previously they were only tested under eager mode, and a new issue was discovered when "TF2_BEHAVIOR=1" is set.

The v2 tests have also been updated to use keras.LSTMCell to ensure correctness in v2.

PiperOrigin-RevId: 236470595
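For context, a minimal before/after sketch of the cell migration applied throughout the updated v2 tests. The cell arguments and the zero_state/get_initial_state calls mirror the diff below; the "from tensorflow.python import keras" alias is an assumption about how the test files reference the Keras cell.

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import rnn_cell   # v1 cell
    from tensorflow.python import keras          # assumed import alias for the v2 cell

    # v1 style: graph-mode LSTMCell with zero_state (returns an LSTMStateTuple(c, h)).
    v1_cell = rnn_cell.LSTMCell(9, initializer="ones")
    v1_state = v1_cell.zero_state(batch_size=5, dtype=dtypes.float32)

    # v2 style used by the updated tests: Keras LSTMCell configured to match,
    # with get_initial_state (returns a plain [h, c] list).
    v2_cell = keras.layers.LSTMCell(9,
                                    recurrent_activation="sigmoid",
                                    kernel_initializer="ones",
                                    recurrent_initializer="ones")
    v2_state = v2_cell.get_initial_state(batch_size=5, dtype=dtypes.float32)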
Scott Zhu 2019-03-02 10:51:52 -08:00 committed by TensorFlower Gardener
parent 7d58fd5675
commit f1bbf1d83e
9 changed files with 109 additions and 96 deletions

View File

@ -13,11 +13,9 @@
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.ops.attention_wrapper."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import collections
import functools
@ -30,6 +28,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@ -66,6 +65,7 @@ def get_result_summary(x):
return x
@test_util.run_v1_only
class AttentionWrapperTest(test.TestCase):
def assertAllCloseOrEqual(self, x, y, **kwargs):

View File

@ -30,7 +30,6 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import initializers
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest
@ -305,7 +304,10 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
attention_layer_size = attention_layer_size[0]
if attention_layer is not None:
attention_layer = attention_layer[0]
cell = rnn_cell.LSTMCell(cell_depth, initializer="ones")
cell = keras.layers.LSTMCell(cell_depth,
recurrent_activation="sigmoid",
kernel_initializer="ones",
recurrent_initializer="ones")
cell = wrapper.AttentionWrapper(
cell,
attention_mechanisms if is_multi else attention_mechanisms[0],
@ -321,7 +323,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sampler = sampler_py.TrainingSampler()
my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
initial_state = cell.zero_state(
initial_state = cell.get_initial_state(
dtype=dtypes.float32, batch_size=batch_size)
final_outputs, final_state, _ = my_decoder(
decoder_inputs,
@ -330,7 +332,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
expected_time = (
expected_final_state.time if context.executing_eagerly() else None)
@ -342,9 +343,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
self.assertEqual((batch_size, attention_depth),
tuple(final_state.attention.get_shape().as_list()))
self.assertEqual((batch_size, cell_depth),
tuple(final_state.cell_state.c.get_shape().as_list()))
tuple(final_state.cell_state[0].get_shape().as_list()))
self.assertEqual((batch_size, cell_depth),
tuple(final_state.cell_state.h.get_shape().as_list()))
tuple(final_state.cell_state[1].get_shape().as_list()))
if alignment_history:
if is_multi:
@ -395,8 +396,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
expected_final_alignment_history,
final_alignment_history_info)
@parameterized.parameters([np.float16, np.float32, np.float64])
def _testBahdanauNormalizedDType(self, dtype):
# TODO(b/126893309): reenable np.float16 once the bug is fixed.
@parameterized.parameters([np.float32, np.float64])
def testBahdanauNormalizedDType(self, dtype):
encoder_outputs = self.encoder_outputs.astype(dtype)
decoder_inputs = self.decoder_inputs.astype(dtype)
attention_mechanism = wrapper.BahdanauAttentionV2(
@ -405,7 +407,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
memory_sequence_length=self.encoder_sequence_length,
normalize=True,
dtype=dtype)
cell = rnn_cell.LSTMCell(self.units)
cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
cell = wrapper.AttentionWrapper(cell, attention_mechanism)
sampler = sampler_py.TrainingSampler()
@ -418,9 +420,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
self.assertEqual(final_outputs.rnn_output.dtype, dtype)
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
@parameterized.parameters([np.float16, np.float32, np.float64])
# TODO(b/126893309): reenable np.float16 once the bug is fixed.
@parameterized.parameters([np.float32, np.float64])
def testLuongScaledDType(self, dtype):
# Test case for GitHub issue 18099
encoder_outputs = self.encoder_outputs.astype(dtype)
@ -432,7 +434,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
scale=True,
dtype=dtype,
)
cell = rnn_cell.LSTMCell(self.units)
cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
cell = wrapper.AttentionWrapper(cell, attention_mechanism)
sampler = sampler_py.TrainingSampler()
@ -445,7 +447,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
self.assertEqual(final_outputs.rnn_output.dtype, dtype)
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
def testBahdanauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttentionV2
@ -455,11 +456,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324),
sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824),
ResultSummary(
shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569),
time=3,
@ -490,11 +491,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728),
time=3,
@ -520,11 +521,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=4.084631),
time=3,
@ -550,11 +551,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314),
time=3,
@ -581,11 +582,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002)],
attention=ResultSummary(
shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335),
time=3,
@ -613,11 +614,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186),
time=3,
@ -648,11 +649,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721),
time=3,
@ -682,11 +683,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
time=3,
@ -716,11 +717,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
time=3,

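A note on the expected_final_state rewrites above: the v1 LSTMStateTuple has named fields ordered (c, h), while the Keras cell keeps its state as a plain list ordered [h, c]. That is why each pair of ResultSummary entries swaps position when the tuple becomes a list (the h mean now comes first). A tiny sketch of the correspondence, with stand-in values instead of ResultSummary objects:

    from tensorflow.python.ops import rnn_cell

    c_value, h_value = "c-summary", "h-summary"                # stand-ins for ResultSummary values
    v1_state = rnn_cell.LSTMStateTuple(c=c_value, h=h_value)   # v1: named fields, c first
    v2_state = [h_value, c_value]                              # v2 Keras cell: list, h first
    assert v2_state[0] == v1_state.h and v2_state[1] == v1_state.c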
View File

@ -13,31 +13,30 @@
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
@test_util.run_v1_only
class BasicDecoderTest(test.TestCase):
def _testStepWithTrainingHelper(self, use_output_layer):

View File

@ -187,14 +187,23 @@ class TestArrayShapeChecks(test.TestCase):
shape=dynamic_shape)
batch_size = array_ops.constant(batch_size)
check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width) # pylint: disable=protected-access
with self.cached_session() as sess:
if is_valid:
sess.run(check_op)
def _test_body():
# pylint: disable=protected-access
if context.executing_eagerly():
beam_search_decoder._check_batch_beam(t, batch_size, beam_width)
else:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(check_op)
with self.cached_session():
check_op = beam_search_decoder._check_batch_beam(
t, batch_size, beam_width)
self.evaluate(check_op)
# pylint: enable=protected-access
if is_valid:
_test_body()
else:
with self.assertRaises(errors.InvalidArgumentError):
_test_body()
def test_array_shape_dynamic_checks(self):
self._test_array_shape_dynamic_checks(
@ -463,6 +472,7 @@ class TestLargeBeamStep(test.TestCase):
self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
@test_util.run_v1_only
class BeamSearchDecoderTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention,

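The _check_batch_beam shape-check test in this file is restructured because the check behaves differently per mode: under eager execution the check runs (and can raise) as soon as the function is called, while in graph mode it only builds an op that has to be evaluated before it can fail. A standalone sketch of that pattern using the public API, with tf.debugging.Assert standing in for the private check:

    import tensorflow as tf

    def run_check(t, expect_valid):
      def _check():
        # Stand-in for the private _check_batch_beam assertion.
        return tf.debugging.Assert(tf.reduce_all(t > 0), [t])
      def _body():
        if tf.executing_eagerly():
          _check()                         # executes, and may raise, right away
        else:
          with tf.compat.v1.Session() as sess:
            sess.run(_check())             # only raises once the op actually runs
      if expect_valid:
        _body()
      else:
        try:
          _body()
          raise AssertionError("expected InvalidArgumentError")
        except tf.errors.InvalidArgumentError:
          pass

    run_check(tf.constant([1.0, 2.0]), expect_valid=True)
    run_check(tf.constant([1.0, -2.0]), expect_valid=False)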
View File

@ -49,8 +49,8 @@ class GatherTreeTest(test.TestCase):
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
with self.session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
with self.cached_session(use_gpu=True):
self.assertAllEqual(expected_result, self.evaluate(beams))
def testBadParentValuesOnCPU(self):
# (batch_size = 1, max_time = 4, beams = 3)
@ -62,15 +62,14 @@ class GatherTreeTest(test.TestCase):
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
max_sequence_lengths = [3]
with ops.device("/cpu:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
with self.cached_session():
with self.assertRaisesOpError(
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
_ = beams.eval()
beams = beam_search_ops.gather_tree(
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
self.evaluate(beams)
def testBadParentValuesOnGPU(self):
# Only want to run this test on CUDA devices, as gather_tree is not
@ -93,8 +92,7 @@ class GatherTreeTest(test.TestCase):
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
with self.session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
self.assertAllEqual(expected_result, self.evaluate(beams))
def testGatherTreeBatch(self):
batch_size = 10
@ -103,7 +101,7 @@ class GatherTreeTest(test.TestCase):
max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
end_token = 5
with self.session(use_gpu=True):
with self.cached_session(use_gpu=True):
step_ids = np.random.randint(
0, high=end_token + 1, size=(max_time, batch_size, beam_width))
parent_ids = np.random.randint(
@ -116,7 +114,7 @@ class GatherTreeTest(test.TestCase):
end_token=end_token)
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
beams_value = beams.eval()
beams_value = self.evaluate(beams)
for b in range(batch_size):
# Past max_sequence_lengths[b], we emit all end tokens.
b_value = beams_value[max_sequence_lengths[b]:, b, :]

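The session and eval rewrites in this file follow the usual recipe for tests that must pass under both graph and eager execution: tensor.eval() needs an enclosing session, whereas TestCase.evaluate works in either mode. Roughly, self.evaluate behaves like the following simplified stand-in (not the real implementation, which reuses the cached session):

    import tensorflow as tf

    def evaluate(t):
      # Simplified stand-in for tf.test.TestCase.evaluate.
      if tf.executing_eagerly():
        return t.numpy()
      with tf.compat.v1.Session() as sess:
        return sess.run(t)

    beams = tf.constant([[1, 2], [3, 4]])
    print(evaluate(beams))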
View File

@ -13,26 +13,25 @@
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.seq2seq.decoder."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
@test_util.run_v1_only
class DynamicDecodeRNNTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):

View File

@ -31,7 +31,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class LossTest(test.TestCase):
def setUp(self):
def config_default_values(self):
self.batch_size = 2
self.sequence_length = 3
self.number_of_classes = 5
@ -56,7 +56,8 @@ class LossTest(test.TestCase):
self.expected_loss = 1.60944
def testSequenceLoss(self):
with self.test_session(use_gpu=True):
self.config_default_values()
with self.cached_session(use_gpu=True):
average_loss_per_example = loss.sequence_loss(
self.logits, self.targets, self.weights,
average_across_timesteps=True,
@ -90,7 +91,8 @@ class LossTest(test.TestCase):
self.assertAllClose(compare_total, res)
def testSequenceLossClass(self):
with self.test_session(use_gpu=True):
self.config_default_values()
with self.cached_session(use_gpu=True):
seq_loss = loss.SequenceLoss(average_across_timesteps=True,
average_across_batch=True,
sum_over_timesteps=False,
@ -132,7 +134,8 @@ class LossTest(test.TestCase):
self.assertAllClose(compare_total, res)
def testSumReduction(self):
with self.test_session(use_gpu=True):
self.config_default_values()
with self.cached_session(use_gpu=True):
seq_loss = loss.SequenceLoss(average_across_timesteps=False,
average_across_batch=False,
sum_over_timesteps=True,
@ -174,6 +177,7 @@ class LossTest(test.TestCase):
self.assertAllClose(compare_total, res)
def testWeightedSumReduction(self):
self.config_default_values()
weights = [
constant_op.constant(1.0, shape=[self.batch_size])
for _ in range(self.sequence_length)
@ -181,7 +185,7 @@ class LossTest(test.TestCase):
# Make the last element in the sequence to have zero weights.
weights[-1] = constant_op.constant(0.0, shape=[self.batch_size])
self.weights = array_ops.stack(weights, axis=1)
with self.test_session(use_gpu=True):
with self.cached_session(use_gpu=True):
seq_loss = loss.SequenceLoss(average_across_timesteps=False,
average_across_batch=False,
sum_over_timesteps=True,
@ -225,12 +229,13 @@ class LossTest(test.TestCase):
self.assertAllClose(compare_total, res)
def testZeroWeights(self):
self.config_default_values()
weights = [
constant_op.constant(0.0, shape=[self.batch_size])
for _ in range(self.sequence_length)
]
weights = array_ops.stack(weights, axis=1)
with self.test_session(use_gpu=True):
with self.cached_session(use_gpu=True):
average_loss_per_example = loss.sequence_loss(
self.logits, self.targets, weights,
average_across_timesteps=True,

View File

@ -2347,7 +2347,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
if self._initial_cell_state is not None:
cell_state = self._initial_cell_state
else:
cell_state = self._cell.zero_state(batch_size, dtype)
cell_state = self._cell.get_initial_state(batch_size=batch_size,
dtype=dtype)
error_message = (
"When calling zero_state of AttentionWrapper %s: " % self._base_name +
"Non-matching batch sizes between the memory "

View File

@ -218,7 +218,7 @@ def _check_batch_beam(t, batch_size, beam_width):
"incompatible with the dynamic shape of %s elements. "
"Consider setting reorder_tensor_arrays to False to disable "
"TensorArray reordering during the beam search."
% (t.name))
% (t if context.executing_eagerly() else t.name))
rank = t.shape.ndims
shape = array_ops.shape(t)
if rank == 2:
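The only functional change in this last hunk is the error-message argument: when executing eagerly the tensor itself is interpolated instead of t.name, on the assumption that eager tensors do not carry a usable graph name and accessing .name on them fails. A minimal illustration of the guard:

    import tensorflow as tf

    t = tf.constant([1.0, 2.0])
    # In graph mode t.name is something like "Const:0"; on an eager tensor the
    # attribute is not usable, so fall back to the tensor itself for the message.
    label = t if tf.executing_eagerly() else t.name
    print("shape of %s is incompatible with the beam layout" % (label,))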