diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 1a5692f7b5b..98e54db4584 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -13,11 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for contrib.seq2seq.python.ops.attention_wrapper."""
-# pylint: disable=unused-import,g-bad-import-order
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-# pylint: enable=unused-import
 
 import collections
 import functools
@@ -30,6 +28,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
 from tensorflow.contrib.seq2seq.python.ops import basic_decoder
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.layers import core as layers_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -66,6 +65,7 @@ def get_result_summary(x):
   return x
 
 
+@test_util.run_v1_only
 class AttentionWrapperTest(test.TestCase):
 
   def assertAllCloseOrEqual(self, x, y, **kwargs):
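[Review note, not part of the patch] The hunk above gates the whole suite on TF1 behavior. A minimal sketch of the same gating applied to a hypothetical suite, mirroring the bare decorator usage the patch introduces:

```python
# Sketch only; the suite and test names below are hypothetical.
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_v1_only  # skipped when the runner executes with TF2 behavior
class GraphModeOnlyTest(test.TestCase):

  def testSomethingGraphy(self):
    self.assertEqual(2, 1 + 1)


if __name__ == "__main__":
  test.main()
```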
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py
index 5ee01f66f16..4943a1574ce 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py
@@ -30,7 +30,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import initializers
-from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -305,7 +304,10 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
       attention_layer_size = attention_layer_size[0]
     if attention_layer is not None:
       attention_layer = attention_layer[0]
-    cell = rnn_cell.LSTMCell(cell_depth, initializer="ones")
+    cell = keras.layers.LSTMCell(cell_depth,
+                                 recurrent_activation="sigmoid",
+                                 kernel_initializer="ones",
+                                 recurrent_initializer="ones")
     cell = wrapper.AttentionWrapper(
         cell,
         attention_mechanisms if is_multi else attention_mechanisms[0],
@@ -321,7 +323,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
     sampler = sampler_py.TrainingSampler()
     my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
 
-    initial_state = cell.zero_state(
+    initial_state = cell.get_initial_state(
         dtype=dtypes.float32, batch_size=batch_size)
     final_outputs, final_state, _ = my_decoder(
         decoder_inputs,
@@ -330,7 +332,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
 
     self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
     self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
-    self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
 
     expected_time = (
         expected_final_state.time if context.executing_eagerly() else None)
@@ -342,9 +343,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
     self.assertEqual((batch_size, attention_depth),
                      tuple(final_state.attention.get_shape().as_list()))
     self.assertEqual((batch_size, cell_depth),
-                     tuple(final_state.cell_state.c.get_shape().as_list()))
+                     tuple(final_state.cell_state[0].get_shape().as_list()))
     self.assertEqual((batch_size, cell_depth),
-                     tuple(final_state.cell_state.h.get_shape().as_list()))
+                     tuple(final_state.cell_state[1].get_shape().as_list()))
 
     if alignment_history:
       if is_multi:
@@ -395,8 +396,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
           expected_final_alignment_history,
           final_alignment_history_info)
 
-  @parameterized.parameters([np.float16, np.float32, np.float64])
-  def _testBahdanauNormalizedDType(self, dtype):
+  # TODO(b/126893309): reenable np.float16 once the bug is fixed.
+  @parameterized.parameters([np.float32, np.float64])
+  def testBahdanauNormalizedDType(self, dtype):
     encoder_outputs = self.encoder_outputs.astype(dtype)
     decoder_inputs = self.decoder_inputs.astype(dtype)
     attention_mechanism = wrapper.BahdanauAttentionV2(
@@ -405,7 +407,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         memory_sequence_length=self.encoder_sequence_length,
         normalize=True,
         dtype=dtype)
-    cell = rnn_cell.LSTMCell(self.units)
+    cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
     cell = wrapper.AttentionWrapper(cell, attention_mechanism)
 
     sampler = sampler_py.TrainingSampler()
@@ -418,9 +420,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
     self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
     self.assertEqual(final_outputs.rnn_output.dtype, dtype)
     self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
-    self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
 
-  @parameterized.parameters([np.float16, np.float32, np.float64])
+  # TODO(b/126893309): reenable np.float16 once the bug is fixed.
+  @parameterized.parameters([np.float32, np.float64])
   def testLuongScaledDType(self, dtype):
     # Test case for GitHub issue 18099
     encoder_outputs = self.encoder_outputs.astype(dtype)
@@ -432,7 +434,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         scale=True,
         dtype=dtype,
     )
-    cell = rnn_cell.LSTMCell(self.units)
+    cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
     cell = wrapper.AttentionWrapper(cell, attention_mechanism)
 
     sampler = sampler_py.TrainingSampler()
@@ -445,7 +447,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
     self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
     self.assertEqual(final_outputs.rnn_output.dtype, dtype)
     self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
-    self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
 
   def testBahdanauNotNormalized(self):
     create_attention_mechanism = wrapper.BahdanauAttentionV2
@@ -455,11 +456,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
             shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324),
         sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636)],
         attention=ResultSummary(
             shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569),
         time=3,
@@ -490,11 +491,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         sample_id=ResultSummary(
             shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209)],
         attention=ResultSummary(
             shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728),
         time=3,
@@ -520,11 +521,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         sample_id=ResultSummary(
             shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
         attention=ResultSummary(
             shape=(5, 6), dtype=np.dtype("float32"), mean=4.084631),
         time=3,
@@ -550,11 +551,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         sample_id=ResultSummary(
             shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
         attention=ResultSummary(
             shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314),
         time=3,
@@ -581,11 +582,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         sample_id=ResultSummary(
             shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002)],
         attention=ResultSummary(
             shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335),
         time=3,
@@ -613,11 +614,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         sample_id=ResultSummary(
             shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492)],
         attention=ResultSummary(
             shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186),
         time=3,
@@ -648,11 +649,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         sample_id=ResultSummary(
             shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473)],
         attention=ResultSummary(
             shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721),
         time=3,
@@ -682,11 +683,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         sample_id=ResultSummary(
             shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
         attention=ResultSummary(
             shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
         time=3,
@@ -716,11 +717,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
         sample_id=ResultSummary(
             shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
     expected_final_state = wrapper.AttentionWrapperState(
-        cell_state=rnn_cell.LSTMStateTuple(
-            c=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
-            h=ResultSummary(
-                shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
+        cell_state=[
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
+            ResultSummary(
+                shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
         attention=ResultSummary(
             shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
         time=3,
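[Review note, not part of the patch] The cell_state edits throughout this file all follow from one layout change: keras.layers.LSTMCell keeps its state as a list [h, c], while the v1 rnn_cell.LSTMCell exposed LSTMStateTuple(c=..., h=...), which is why the expected c/h means swap positions in the new lists. The recurrent_activation="sigmoid" argument appearing in these hunks presumably pins the Keras cell to the v1 gate nonlinearity (the Keras default at the time was hard_sigmoid) so the numeric summaries stay comparable. A minimal sketch of the layout difference, with local names and TF 1.13-era modules:

```python
# Sketch only; not part of the patch.
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import layers as keras_layers
from tensorflow.python.ops import array_ops

cell = keras_layers.LSTMCell(4, recurrent_activation="sigmoid")
inputs = array_ops.ones([2, 3])
states = cell.get_initial_state(batch_size=2, dtype=dtypes.float32)
outputs, new_states = cell(inputs, states)
h, c = new_states  # Keras ordering is [h, c]; v1 LSTMStateTuple was (c, h)
```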
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
index abcf71c61b6..599abf5a361 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
@@ -13,31 +13,30 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for contrib.seq2seq.python.seq2seq.basic_decoder."""
-# pylint: disable=unused-import,g-bad-import-order
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-# pylint: enable=unused-import
 
 import numpy as np
 
-from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
 from tensorflow.contrib.seq2seq.python.ops import basic_decoder
+from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
 from tensorflow.python.layers import core as layers_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn_cell
-from tensorflow.python.ops import variables
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
-# pylint: enable=g-import-not-at-top
 
 
+@test_util.run_v1_only
 class BasicDecoderTest(test.TestCase):
 
   def _testStepWithTrainingHelper(self, use_output_layer):
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 56f2a0acc9f..8c84cd13588 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -187,14 +187,23 @@ class TestArrayShapeChecks(test.TestCase):
           shape=dynamic_shape)
       batch_size = array_ops.constant(batch_size)
 
-      check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width)  # pylint: disable=protected-access
-      with self.cached_session() as sess:
-        if is_valid:
-          sess.run(check_op)
+      def _test_body():
+        # pylint: disable=protected-access
+        if context.executing_eagerly():
+          beam_search_decoder._check_batch_beam(t, batch_size, beam_width)
         else:
-          with self.assertRaises(errors.InvalidArgumentError):
-            sess.run(check_op)
+          with self.cached_session():
+            check_op = beam_search_decoder._check_batch_beam(
+                t, batch_size, beam_width)
+            self.evaluate(check_op)
+        # pylint: enable=protected-access
+
+      if is_valid:
+        _test_body()
+      else:
+        with self.assertRaises(errors.InvalidArgumentError):
+          _test_body()
 
   def test_array_shape_dynamic_checks(self):
     self._test_array_shape_dynamic_checks(
@@ -463,6 +472,7 @@ class TestLargeBeamStep(test.TestCase):
     self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
 
 
+@test_util.run_v1_only
class BeamSearchDecoderTest(test.TestCase):
 
   def _testDynamicDecodeRNN(self, time_major, has_attention,
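[Review note, not part of the patch] The _test_body refactor above is the general pattern for asserting op errors under both execution modes: in eager mode an invalid op raises as soon as it runs, while in graph mode it raises only when the returned op is evaluated, so the assertRaises wrapper has to sit outside the mode branch. A sketch with local helper names:

```python
# Sketch only; helper names are hypothetical.
from tensorflow.python.eager import context
from tensorflow.python.framework import errors


def assert_op_validity(test_case, make_check_op, is_valid):
  def _test_body():
    if context.executing_eagerly():
      make_check_op()  # raises immediately if invalid
    else:
      with test_case.cached_session():
        test_case.evaluate(make_check_op())  # raises at evaluation time

  if is_valid:
    _test_body()
  else:
    with test_case.assertRaises(errors.InvalidArgumentError):
      _test_body()
```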
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
index b41734d214e..5506aa8b8ee 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
@@ -49,8 +49,8 @@ class GatherTreeTest(test.TestCase):
         parent_ids=parent_ids,
         max_sequence_lengths=max_sequence_lengths,
         end_token=end_token)
-    with self.session(use_gpu=True):
-      self.assertAllEqual(expected_result, beams.eval())
+    with self.cached_session(use_gpu=True):
+      self.assertAllEqual(expected_result, self.evaluate(beams))
 
   def testBadParentValuesOnCPU(self):
     # (batch_size = 1, max_time = 4, beams = 3)
@@ -62,15 +62,14 @@ class GatherTreeTest(test.TestCase):
         [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
     max_sequence_lengths = [3]
     with ops.device("/cpu:0"):
-      beams = beam_search_ops.gather_tree(
-          step_ids=step_ids,
-          parent_ids=parent_ids,
-          max_sequence_lengths=max_sequence_lengths,
-          end_token=end_token)
-    with self.cached_session():
       with self.assertRaisesOpError(
           r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
-        _ = beams.eval()
+        beams = beam_search_ops.gather_tree(
+            step_ids=step_ids,
+            parent_ids=parent_ids,
+            max_sequence_lengths=max_sequence_lengths,
+            end_token=end_token)
+        self.evaluate(beams)
 
   def testBadParentValuesOnGPU(self):
     # Only want to run this test on CUDA devices, as gather_tree is not
@@ -93,8 +92,7 @@ class GatherTreeTest(test.TestCase):
           parent_ids=parent_ids,
           max_sequence_lengths=max_sequence_lengths,
           end_token=end_token)
-      with self.session(use_gpu=True):
-        self.assertAllEqual(expected_result, beams.eval())
+      self.assertAllEqual(expected_result, self.evaluate(beams))
 
   def testGatherTreeBatch(self):
     batch_size = 10
@@ -103,7 +101,7 @@ class GatherTreeTest(test.TestCase):
     max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
     end_token = 5
 
-    with self.session(use_gpu=True):
+    with self.cached_session(use_gpu=True):
       step_ids = np.random.randint(
           0, high=end_token + 1, size=(max_time, batch_size, beam_width))
       parent_ids = np.random.randint(
@@ -116,7 +114,7 @@ class GatherTreeTest(test.TestCase):
           end_token=end_token)
 
       self.assertEqual((max_time, batch_size, beam_width), beams.shape)
-      beams_value = beams.eval()
+      beams_value = self.evaluate(beams)
      for b in range(batch_size):
        # Past max_sequence_lengths[b], we emit all end tokens.
        b_value = beams_value[max_sequence_lengths[b]:, b, :]
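[Review note, not part of the patch] The beams.eval() to self.evaluate(beams) rewrites above follow the standard migration: Tensor.eval() needs an enclosing graph session, while TensorFlowTestCase.evaluate() resolves tensors under both eager and graph execution. A minimal sketch:

```python
# Sketch only; the test class below is hypothetical.
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test


class EvaluateExample(test.TestCase):

  def testEvaluate(self):
    x = constant_op.constant([1, 2, 3])
    # Works eagerly or inside a graph session; .eval() would need the latter.
    self.assertAllEqual([1, 2, 3], self.evaluate(x))
```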
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index 4c25489fade..4a420221e27 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -13,26 +13,25 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for contrib.seq2seq.python.seq2seq.decoder."""
-# pylint: disable=unused-import,g-bad-import-order
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-# pylint: enable=unused-import
 
 import numpy as np
 
+from tensorflow.contrib.seq2seq.python.ops import basic_decoder
 from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
-from tensorflow.contrib.seq2seq.python.ops import basic_decoder
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.platform import test
-# pylint: enable=g-import-not-at-top
 
 
+@test_util.run_v1_only
 class DynamicDecodeRNNTest(test.TestCase):
 
   def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
index 41b2a53ca5b..7eb544a921c 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import test
 @test_util.run_all_in_graph_and_eager_modes
 class LossTest(test.TestCase):
 
-  def setUp(self):
+  def config_default_values(self):
     self.batch_size = 2
     self.sequence_length = 3
     self.number_of_classes = 5
@@ -56,7 +56,8 @@ class LossTest(test.TestCase):
     self.expected_loss = 1.60944
 
   def testSequenceLoss(self):
-    with self.test_session(use_gpu=True):
+    self.config_default_values()
+    with self.cached_session(use_gpu=True):
       average_loss_per_example = loss.sequence_loss(
           self.logits, self.targets, self.weights,
           average_across_timesteps=True,
@@ -90,7 +91,8 @@ class LossTest(test.TestCase):
     self.assertAllClose(compare_total, res)
 
   def testSequenceLossClass(self):
-    with self.test_session(use_gpu=True):
+    self.config_default_values()
+    with self.cached_session(use_gpu=True):
       seq_loss = loss.SequenceLoss(average_across_timesteps=True,
                                    average_across_batch=True,
                                    sum_over_timesteps=False,
@@ -132,7 +134,8 @@ class LossTest(test.TestCase):
     self.assertAllClose(compare_total, res)
 
   def testSumReduction(self):
-    with self.test_session(use_gpu=True):
+    self.config_default_values()
+    with self.cached_session(use_gpu=True):
       seq_loss = loss.SequenceLoss(average_across_timesteps=False,
                                    average_across_batch=False,
                                    sum_over_timesteps=True,
@@ -174,6 +177,7 @@ class LossTest(test.TestCase):
     self.assertAllClose(compare_total, res)
 
   def testWeightedSumReduction(self):
+    self.config_default_values()
     weights = [
         constant_op.constant(1.0, shape=[self.batch_size])
         for _ in range(self.sequence_length)
     ]
@@ -181,7 +185,7 @@ class LossTest(test.TestCase):
     # Make the last element in the sequence to have zero weights.
     weights[-1] = constant_op.constant(0.0, shape=[self.batch_size])
     self.weights = array_ops.stack(weights, axis=1)
-    with self.test_session(use_gpu=True):
+    with self.cached_session(use_gpu=True):
       seq_loss = loss.SequenceLoss(average_across_timesteps=False,
                                    average_across_batch=False,
                                    sum_over_timesteps=True,
@@ -225,12 +229,13 @@ class LossTest(test.TestCase):
     self.assertAllClose(compare_total, res)
 
   def testZeroWeights(self):
+    self.config_default_values()
     weights = [
         constant_op.constant(0.0, shape=[self.batch_size])
         for _ in range(self.sequence_length)
     ]
     weights = array_ops.stack(weights, axis=1)
-    with self.test_session(use_gpu=True):
+    with self.cached_session(use_gpu=True):
       average_loss_per_example = loss.sequence_loss(
           self.logits, self.targets, weights,
           average_across_timesteps=True,
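[Review note, not part of the patch] The setUp to config_default_values rename looks deliberate: under @test_util.run_all_in_graph_and_eager_modes, setUp runs outside the per-test mode context, so tensors built there can land in the wrong graph. Calling the fixture builder from inside each test keeps every constant in the active mode. A sketch of the pattern, on a hypothetical test class:

```python
# Sketch only; not part of the patch.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class FixtureExample(test.TestCase):

  def config_default_values(self):
    # Built inside the test body, so the constant lives in whichever
    # graph/eager context the mode decorator has activated.
    self.ones = constant_op.constant([1.0, 1.0])

  def testUsesFixture(self):
    self.config_default_values()
    self.assertAllClose([1.0, 1.0], self.evaluate(self.ones))
```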
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 79c2ac2f500..577a3efbd7d 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -2347,7 +2347,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
       if self._initial_cell_state is not None:
         cell_state = self._initial_cell_state
       else:
-        cell_state = self._cell.zero_state(batch_size, dtype)
+        cell_state = self._cell.get_initial_state(batch_size=batch_size,
+                                                  dtype=dtype)
       error_message = (
           "When calling zero_state of AttentionWrapper %s: " % self._base_name +
           "Non-matching batch sizes between the memory "
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 1d773a44989..e67e5c0d9c5 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -218,7 +218,7 @@ def _check_batch_beam(t, batch_size, beam_width):
       "incompatible with the dynamic shape of %s elements. "
       "Consider setting reorder_tensor_arrays to False to disable "
       "TensorArray reordering during the beam search."
-      % (t.name))
+      % (t if context.executing_eagerly() else t.name))
   rank = t.shape.ndims
   shape = array_ops.shape(t)
   if rank == 2:
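[Review note, not part of the patch] For reference, the two initial-state APIs involved in the attention_wrapper.py hunk: zero_state(batch_size, dtype) is the v1 RNNCell method, while get_initial_state(inputs=None, batch_size=None, dtype=None) is the Keras-style method, which v1 cells also expose in this TF version; that shared method is what lets AttentionWrapper stop caring which flavor of cell it wraps. A sketch of both calls:

```python
# Sketch only; module paths are the TF 1.13-era internal ones.
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import layers as keras_layers
from tensorflow.python.ops import rnn_cell

keras_cell = keras_layers.LSTMCell(8)
keras_state = keras_cell.get_initial_state(batch_size=4, dtype=dtypes.float32)

v1_cell = rnn_cell.LSTMCell(8)
v1_state = v1_cell.zero_state(batch_size=4, dtype=dtypes.float32)
# v1 cells in this TF version also expose get_initial_state, which is what
# allows AttentionWrapper to call one method regardless of the cell flavor.
```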