Update seq2seq test to be 2.0 compatible.
Previously it was only tested under eager mode, and a new issue was discovered when running with "TF2_BEHAVIOR=1". Also, the v2 tests have been updated to use keras.LSTMCell to ensure correctness in v2. PiperOrigin-RevId: 236470595
This commit is contained in:
parent
7d58fd5675
commit
f1bbf1d83e
@ -13,11 +13,9 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for contrib.seq2seq.python.ops.attention_wrapper."""
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import collections
|
||||
import functools
|
||||
@ -30,6 +28,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -66,6 +65,7 @@ def get_result_summary(x):
|
||||
return x
|
||||
|
||||
|
||||
@test_util.run_v1_only
|
||||
class AttentionWrapperTest(test.TestCase):
|
||||
|
||||
def assertAllCloseOrEqual(self, x, y, **kwargs):
|
||||
|
@ -30,7 +30,6 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
@ -305,7 +304,10 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
attention_layer_size = attention_layer_size[0]
|
||||
if attention_layer is not None:
|
||||
attention_layer = attention_layer[0]
|
||||
cell = rnn_cell.LSTMCell(cell_depth, initializer="ones")
|
||||
cell = keras.layers.LSTMCell(cell_depth,
|
||||
recurrent_activation="sigmoid",
|
||||
kernel_initializer="ones",
|
||||
recurrent_initializer="ones")
|
||||
cell = wrapper.AttentionWrapper(
|
||||
cell,
|
||||
attention_mechanisms if is_multi else attention_mechanisms[0],
|
||||
@ -321,7 +323,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
sampler = sampler_py.TrainingSampler()
|
||||
my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
|
||||
initial_state = cell.zero_state(
|
||||
initial_state = cell.get_initial_state(
|
||||
dtype=dtypes.float32, batch_size=batch_size)
|
||||
final_outputs, final_state, _ = my_decoder(
|
||||
decoder_inputs,
|
||||
@ -330,7 +332,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
|
||||
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
|
||||
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
|
||||
|
||||
expected_time = (
|
||||
expected_final_state.time if context.executing_eagerly() else None)
|
||||
@ -342,9 +343,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual((batch_size, attention_depth),
|
||||
tuple(final_state.attention.get_shape().as_list()))
|
||||
self.assertEqual((batch_size, cell_depth),
|
||||
tuple(final_state.cell_state.c.get_shape().as_list()))
|
||||
tuple(final_state.cell_state[0].get_shape().as_list()))
|
||||
self.assertEqual((batch_size, cell_depth),
|
||||
tuple(final_state.cell_state.h.get_shape().as_list()))
|
||||
tuple(final_state.cell_state[1].get_shape().as_list()))
|
||||
|
||||
if alignment_history:
|
||||
if is_multi:
|
||||
@ -395,8 +396,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
expected_final_alignment_history,
|
||||
final_alignment_history_info)
|
||||
|
||||
@parameterized.parameters([np.float16, np.float32, np.float64])
|
||||
def _testBahdanauNormalizedDType(self, dtype):
|
||||
# TODO(b/126893309): reenable np.float16 once the bug is fixed.
|
||||
@parameterized.parameters([np.float32, np.float64])
|
||||
def testBahdanauNormalizedDType(self, dtype):
|
||||
encoder_outputs = self.encoder_outputs.astype(dtype)
|
||||
decoder_inputs = self.decoder_inputs.astype(dtype)
|
||||
attention_mechanism = wrapper.BahdanauAttentionV2(
|
||||
@ -405,7 +407,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
memory_sequence_length=self.encoder_sequence_length,
|
||||
normalize=True,
|
||||
dtype=dtype)
|
||||
cell = rnn_cell.LSTMCell(self.units)
|
||||
cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
|
||||
cell = wrapper.AttentionWrapper(cell, attention_mechanism)
|
||||
|
||||
sampler = sampler_py.TrainingSampler()
|
||||
@ -418,9 +420,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
|
||||
self.assertEqual(final_outputs.rnn_output.dtype, dtype)
|
||||
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
|
||||
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
|
||||
|
||||
@parameterized.parameters([np.float16, np.float32, np.float64])
|
||||
# TODO(b/126893309): reenable np.float16 once the bug is fixed.
|
||||
@parameterized.parameters([np.float32, np.float64])
|
||||
def testLuongScaledDType(self, dtype):
|
||||
# Test case for GitHub issue 18099
|
||||
encoder_outputs = self.encoder_outputs.astype(dtype)
|
||||
@ -432,7 +434,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
scale=True,
|
||||
dtype=dtype,
|
||||
)
|
||||
cell = rnn_cell.LSTMCell(self.units)
|
||||
cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
|
||||
cell = wrapper.AttentionWrapper(cell, attention_mechanism)
|
||||
|
||||
sampler = sampler_py.TrainingSampler()
|
||||
@ -445,7 +447,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
|
||||
self.assertEqual(final_outputs.rnn_output.dtype, dtype)
|
||||
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
|
||||
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
|
||||
|
||||
def testBahdanauNotNormalized(self):
|
||||
create_attention_mechanism = wrapper.BahdanauAttentionV2
|
||||
@ -455,11 +456,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324),
|
||||
sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569),
|
||||
time=3,
|
||||
@ -490,11 +491,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728),
|
||||
time=3,
|
||||
@ -520,11 +521,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=4.084631),
|
||||
time=3,
|
||||
@ -550,11 +551,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314),
|
||||
time=3,
|
||||
@ -581,11 +582,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335),
|
||||
time=3,
|
||||
@ -613,11 +614,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186),
|
||||
time=3,
|
||||
@ -648,11 +649,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721),
|
||||
time=3,
|
||||
@ -682,11 +683,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
|
||||
time=3,
|
||||
@ -716,11 +717,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
|
||||
time=3,
|
||||
|
@ -13,31 +13,30 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder."""
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
@test_util.run_v1_only
|
||||
class BasicDecoderTest(test.TestCase):
|
||||
|
||||
def _testStepWithTrainingHelper(self, use_output_layer):
|
||||
|
@ -187,14 +187,23 @@ class TestArrayShapeChecks(test.TestCase):
|
||||
shape=dynamic_shape)
|
||||
|
||||
batch_size = array_ops.constant(batch_size)
|
||||
check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width) # pylint: disable=protected-access
|
||||
|
||||
with self.cached_session() as sess:
|
||||
if is_valid:
|
||||
sess.run(check_op)
|
||||
def _test_body():
|
||||
# pylint: disable=protected-access
|
||||
if context.executing_eagerly():
|
||||
beam_search_decoder._check_batch_beam(t, batch_size, beam_width)
|
||||
else:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(check_op)
|
||||
with self.cached_session():
|
||||
check_op = beam_search_decoder._check_batch_beam(
|
||||
t, batch_size, beam_width)
|
||||
self.evaluate(check_op)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if is_valid:
|
||||
_test_body()
|
||||
else:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
_test_body()
|
||||
|
||||
def test_array_shape_dynamic_checks(self):
|
||||
self._test_array_shape_dynamic_checks(
|
||||
@ -463,6 +472,7 @@ class TestLargeBeamStep(test.TestCase):
|
||||
self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
|
||||
|
||||
|
||||
@test_util.run_v1_only
|
||||
class BeamSearchDecoderTest(test.TestCase):
|
||||
|
||||
def _testDynamicDecodeRNN(self, time_major, has_attention,
|
||||
|
@ -49,8 +49,8 @@ class GatherTreeTest(test.TestCase):
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
with self.cached_session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, self.evaluate(beams))
|
||||
|
||||
def testBadParentValuesOnCPU(self):
|
||||
# (batch_size = 1, max_time = 4, beams = 3)
|
||||
@ -62,15 +62,14 @@ class GatherTreeTest(test.TestCase):
|
||||
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
max_sequence_lengths = [3]
|
||||
with ops.device("/cpu:0"):
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
with self.cached_session():
|
||||
with self.assertRaisesOpError(
|
||||
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
|
||||
_ = beams.eval()
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
self.evaluate(beams)
|
||||
|
||||
def testBadParentValuesOnGPU(self):
|
||||
# Only want to run this test on CUDA devices, as gather_tree is not
|
||||
@ -93,8 +92,7 @@ class GatherTreeTest(test.TestCase):
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
self.assertAllEqual(expected_result, self.evaluate(beams))
|
||||
|
||||
def testGatherTreeBatch(self):
|
||||
batch_size = 10
|
||||
@ -103,7 +101,7 @@ class GatherTreeTest(test.TestCase):
|
||||
max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
|
||||
end_token = 5
|
||||
|
||||
with self.session(use_gpu=True):
|
||||
with self.cached_session(use_gpu=True):
|
||||
step_ids = np.random.randint(
|
||||
0, high=end_token + 1, size=(max_time, batch_size, beam_width))
|
||||
parent_ids = np.random.randint(
|
||||
@ -116,7 +114,7 @@ class GatherTreeTest(test.TestCase):
|
||||
end_token=end_token)
|
||||
|
||||
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
|
||||
beams_value = beams.eval()
|
||||
beams_value = self.evaluate(beams)
|
||||
for b in range(batch_size):
|
||||
# Past max_sequence_lengths[b], we emit all end tokens.
|
||||
b_value = beams_value[max_sequence_lengths[b]:, b, :]
|
||||
|
@ -13,26 +13,25 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for contrib.seq2seq.python.seq2seq.decoder."""
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import test
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
@test_util.run_v1_only
|
||||
class DynamicDecodeRNNTest(test.TestCase):
|
||||
|
||||
def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):
|
||||
|
@ -31,7 +31,7 @@ from tensorflow.python.platform import test
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LossTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
def config_default_values(self):
|
||||
self.batch_size = 2
|
||||
self.sequence_length = 3
|
||||
self.number_of_classes = 5
|
||||
@ -56,7 +56,8 @@ class LossTest(test.TestCase):
|
||||
self.expected_loss = 1.60944
|
||||
|
||||
def testSequenceLoss(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
self.config_default_values()
|
||||
with self.cached_session(use_gpu=True):
|
||||
average_loss_per_example = loss.sequence_loss(
|
||||
self.logits, self.targets, self.weights,
|
||||
average_across_timesteps=True,
|
||||
@ -90,7 +91,8 @@ class LossTest(test.TestCase):
|
||||
self.assertAllClose(compare_total, res)
|
||||
|
||||
def testSequenceLossClass(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
self.config_default_values()
|
||||
with self.cached_session(use_gpu=True):
|
||||
seq_loss = loss.SequenceLoss(average_across_timesteps=True,
|
||||
average_across_batch=True,
|
||||
sum_over_timesteps=False,
|
||||
@ -132,7 +134,8 @@ class LossTest(test.TestCase):
|
||||
self.assertAllClose(compare_total, res)
|
||||
|
||||
def testSumReduction(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
self.config_default_values()
|
||||
with self.cached_session(use_gpu=True):
|
||||
seq_loss = loss.SequenceLoss(average_across_timesteps=False,
|
||||
average_across_batch=False,
|
||||
sum_over_timesteps=True,
|
||||
@ -174,6 +177,7 @@ class LossTest(test.TestCase):
|
||||
self.assertAllClose(compare_total, res)
|
||||
|
||||
def testWeightedSumReduction(self):
|
||||
self.config_default_values()
|
||||
weights = [
|
||||
constant_op.constant(1.0, shape=[self.batch_size])
|
||||
for _ in range(self.sequence_length)
|
||||
@ -181,7 +185,7 @@ class LossTest(test.TestCase):
|
||||
# Make the last element in the sequence to have zero weights.
|
||||
weights[-1] = constant_op.constant(0.0, shape=[self.batch_size])
|
||||
self.weights = array_ops.stack(weights, axis=1)
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.cached_session(use_gpu=True):
|
||||
seq_loss = loss.SequenceLoss(average_across_timesteps=False,
|
||||
average_across_batch=False,
|
||||
sum_over_timesteps=True,
|
||||
@ -225,12 +229,13 @@ class LossTest(test.TestCase):
|
||||
self.assertAllClose(compare_total, res)
|
||||
|
||||
def testZeroWeights(self):
|
||||
self.config_default_values()
|
||||
weights = [
|
||||
constant_op.constant(0.0, shape=[self.batch_size])
|
||||
for _ in range(self.sequence_length)
|
||||
]
|
||||
weights = array_ops.stack(weights, axis=1)
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.cached_session(use_gpu=True):
|
||||
average_loss_per_example = loss.sequence_loss(
|
||||
self.logits, self.targets, weights,
|
||||
average_across_timesteps=True,
|
||||
|
@ -2347,7 +2347,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
|
||||
if self._initial_cell_state is not None:
|
||||
cell_state = self._initial_cell_state
|
||||
else:
|
||||
cell_state = self._cell.zero_state(batch_size, dtype)
|
||||
cell_state = self._cell.get_initial_state(batch_size=batch_size,
|
||||
dtype=dtype)
|
||||
error_message = (
|
||||
"When calling zero_state of AttentionWrapper %s: " % self._base_name +
|
||||
"Non-matching batch sizes between the memory "
|
||||
|
@ -218,7 +218,7 @@ def _check_batch_beam(t, batch_size, beam_width):
|
||||
"incompatible with the dynamic shape of %s elements. "
|
||||
"Consider setting reorder_tensor_arrays to False to disable "
|
||||
"TensorArray reordering during the beam search."
|
||||
% (t.name))
|
||||
% (t if context.executing_eagerly() else t.name))
|
||||
rank = t.shape.ndims
|
||||
shape = array_ops.shape(t)
|
||||
if rank == 2:
|
||||
|
Loading…
Reference in New Issue
Block a user