diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index 2cd2726a6fa..fa3f074c67c 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -358,14 +358,12 @@ class BasicDecoderTest(test.TestCase): np.squeeze(inputs[batch_where_not_sampling, 1])) def _testStepWithScheduledOutputTrainingHelper( - self, sampling_probability, use_next_input_layer, use_auxiliary_inputs): + self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs): sequence_length = [3, 4, 3, 1, 0] batch_size = 5 max_time = 8 input_depth = 7 cell_depth = input_depth - if use_next_input_layer: - cell_depth = 6 if use_auxiliary_inputs: auxiliary_input_depth = 4 auxiliary_inputs = np.random.randn( @@ -379,16 +377,20 @@ class BasicDecoderTest(test.TestCase): cell = rnn_cell.LSTMCell(cell_depth) sampling_probability = constant_op.constant(sampling_probability) - next_input_layer = None - if use_next_input_layer: - next_input_layer = layers_core.Dense(input_depth, use_bias=False) + if use_next_inputs_fn: + def next_inputs_fn(outputs): + # Use deterministic function for test. + samples = math_ops.argmax(outputs, axis=1) + return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32) + else: + next_inputs_fn = None helper = helper_py.ScheduledOutputTrainingHelper( inputs=inputs, sequence_length=sequence_length, sampling_probability=sampling_probability, time_major=False, - next_input_layer=next_input_layer, + next_inputs_fn=next_inputs_fn, auxiliary_inputs=auxiliary_inputs) my_decoder = basic_decoder.BasicDecoder( @@ -412,9 +414,8 @@ class BasicDecoderTest(test.TestCase): step_finished) = my_decoder.step( constant_op.constant(0), first_inputs, first_state) - if use_next_input_layer: - output_after_next_input_layer = next_input_layer( - step_outputs.rnn_output) + if use_next_inputs_fn: + output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output) batch_size_t = my_decoder.batch_size @@ -441,8 +442,8 @@ class BasicDecoderTest(test.TestCase): "step_next_inputs": step_next_inputs, "step_finished": step_finished } - if use_next_input_layer: - fetches["output_after_next_input_layer"] = output_after_next_input_layer + if use_next_inputs_fn: + fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn sess_results = sess.run(fetches) @@ -461,8 +462,8 @@ class BasicDecoderTest(test.TestCase): np.array([]).reshape(batch_size, 0).astype(np.float32)) expected_next_sampling_inputs = np.concatenate( - (sess_results["output_after_next_input_layer"][batch_where_sampling] - if use_next_input_layer else + (sess_results["output_after_next_inputs_fn"][batch_where_sampling] + if use_next_inputs_fn else sess_results["step_outputs"].rnn_output[batch_where_sampling], auxiliary_inputs_to_concat[batch_where_sampling]), axis=-1) @@ -477,32 +478,31 @@ class BasicDecoderTest(test.TestCase): auxiliary_inputs_to_concat[batch_where_not_sampling]), axis=-1)) - def testStepWithScheduledOutputTrainingHelperWithoutNextInputLayerOrAuxInputs( + def testStepWithScheduledOutputTrainingHelperWithoutNextInputsFnOrAuxInputs( self): self._testStepWithScheduledOutputTrainingHelper( - sampling_probability=0.5, use_next_input_layer=False, + sampling_probability=0.5, use_next_inputs_fn=False, use_auxiliary_inputs=False) - def testStepWithScheduledOutputTrainingHelperWithNextInputLayer(self): + def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self): self._testStepWithScheduledOutputTrainingHelper( - sampling_probability=0.5, use_next_input_layer=True, + sampling_probability=0.5, use_next_inputs_fn=True, use_auxiliary_inputs=False) def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self): self._testStepWithScheduledOutputTrainingHelper( - sampling_probability=0.5, use_next_input_layer=False, + sampling_probability=0.5, use_next_inputs_fn=False, use_auxiliary_inputs=True) - def testStepWithScheduledOutputTrainingHelperWithNextInputLayerAndAuxInputs( + def testStepWithScheduledOutputTrainingHelperWithNextInputsFnAndAuxInputs( self): self._testStepWithScheduledOutputTrainingHelper( - sampling_probability=0.5, use_next_input_layer=True, + sampling_probability=0.5, use_next_inputs_fn=True, use_auxiliary_inputs=True) - def testStepWithScheduledOutputTrainingHelperWithNoSampling( - self): + def testStepWithScheduledOutputTrainingHelperWithNoSampling(self): self._testStepWithScheduledOutputTrainingHelper( - sampling_probability=0.0, use_next_input_layer=True, + sampling_probability=0.0, use_next_inputs_fn=True, use_auxiliary_inputs=True) def testStepWithInferenceHelperCategorical(self): diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index c1682de0411..64e00c21c70 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -27,7 +27,6 @@ from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops @@ -351,7 +350,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper): """ def __init__(self, inputs, sequence_length, sampling_probability, - time_major=False, seed=None, next_input_layer=None, + time_major=False, seed=None, next_inputs_fn=None, auxiliary_inputs=None, name=None): """Initializer. @@ -363,9 +362,9 @@ class ScheduledOutputTrainingHelper(TrainingHelper): time_major: Python bool. Whether the tensors in `inputs` are time major. If `False` (default), they are assumed to be batch major. seed: The sampling seed. - next_input_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output to create - the next input. + next_inputs_fn: (Optional) callable to apply to the RNN outputs to create + the next input when sampling. If `None` (default), the RNN outputs will + be used as the next inputs. auxiliary_inputs: An optional (structure of) auxiliary input tensors with a shape that matches `inputs` in all but (potentially) the final dimension. These tensors will be concatenated to the sampled output or @@ -403,11 +402,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper): self._seed = seed - if (next_input_layer is not None and not isinstance(next_input_layer, - layers_base.Layer)): - raise TypeError("next_input_layer must be a Layer, received: %s" % - type(next_input_layer)) - self._next_input_layer = next_input_layer + self._next_inputs_fn = next_inputs_fn super(ScheduledOutputTrainingHelper, self).__init__( inputs=maybe_concatenated_inputs, @@ -453,7 +448,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper): lambda x, y: array_ops.concat((x, y), -1), outputs_, auxiliary_inputs) - if self._next_input_layer is None: + if self._next_inputs_fn is None: return array_ops.where( sample_ids, maybe_concatenate_auxiliary_inputs(outputs), base_next_inputs) @@ -466,7 +461,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper): inputs_not_sampling = array_ops.gather_nd(base_next_inputs, where_not_sampling) sampled_next_inputs = maybe_concatenate_auxiliary_inputs( - self._next_input_layer(outputs_sampling), where_sampling) + self._next_inputs_fn(outputs_sampling), where_sampling) base_shape = array_ops.shape(base_next_inputs) return (array_ops.scatter_nd(indices=where_sampling,