Allow the next input in ScheduledOutputTrainingHelper to be computed by any callable by removing requirement that it be a Layer.

PiperOrigin-RevId: 165042108
This commit is contained in:
Adam Roberts 2017-08-11 16:47:52 -07:00 committed by TensorFlower Gardener
parent 189e594b59
commit f01a3191a8
2 changed files with 31 additions and 36 deletions

View File

@ -358,14 +358,12 @@ class BasicDecoderTest(test.TestCase):
np.squeeze(inputs[batch_where_not_sampling, 1]))
def _testStepWithScheduledOutputTrainingHelper(
self, sampling_probability, use_next_input_layer, use_auxiliary_inputs):
self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
sequence_length = [3, 4, 3, 1, 0]
batch_size = 5
max_time = 8
input_depth = 7
cell_depth = input_depth
if use_next_input_layer:
cell_depth = 6
if use_auxiliary_inputs:
auxiliary_input_depth = 4
auxiliary_inputs = np.random.randn(
@ -379,16 +377,20 @@ class BasicDecoderTest(test.TestCase):
cell = rnn_cell.LSTMCell(cell_depth)
sampling_probability = constant_op.constant(sampling_probability)
next_input_layer = None
if use_next_input_layer:
next_input_layer = layers_core.Dense(input_depth, use_bias=False)
if use_next_inputs_fn:
def next_inputs_fn(outputs):
# Use deterministic function for test.
samples = math_ops.argmax(outputs, axis=1)
return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32)
else:
next_inputs_fn = None
helper = helper_py.ScheduledOutputTrainingHelper(
inputs=inputs,
sequence_length=sequence_length,
sampling_probability=sampling_probability,
time_major=False,
next_input_layer=next_input_layer,
next_inputs_fn=next_inputs_fn,
auxiliary_inputs=auxiliary_inputs)
my_decoder = basic_decoder.BasicDecoder(
@ -412,9 +414,8 @@ class BasicDecoderTest(test.TestCase):
step_finished) = my_decoder.step(
constant_op.constant(0), first_inputs, first_state)
if use_next_input_layer:
output_after_next_input_layer = next_input_layer(
step_outputs.rnn_output)
if use_next_inputs_fn:
output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)
batch_size_t = my_decoder.batch_size
@ -441,8 +442,8 @@ class BasicDecoderTest(test.TestCase):
"step_next_inputs": step_next_inputs,
"step_finished": step_finished
}
if use_next_input_layer:
fetches["output_after_next_input_layer"] = output_after_next_input_layer
if use_next_inputs_fn:
fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn
sess_results = sess.run(fetches)
@ -461,8 +462,8 @@ class BasicDecoderTest(test.TestCase):
np.array([]).reshape(batch_size, 0).astype(np.float32))
expected_next_sampling_inputs = np.concatenate(
(sess_results["output_after_next_input_layer"][batch_where_sampling]
if use_next_input_layer else
(sess_results["output_after_next_inputs_fn"][batch_where_sampling]
if use_next_inputs_fn else
sess_results["step_outputs"].rnn_output[batch_where_sampling],
auxiliary_inputs_to_concat[batch_where_sampling]),
axis=-1)
@ -477,32 +478,31 @@ class BasicDecoderTest(test.TestCase):
auxiliary_inputs_to_concat[batch_where_not_sampling]),
axis=-1))
def testStepWithScheduledOutputTrainingHelperWithoutNextInputLayerOrAuxInputs(
def testStepWithScheduledOutputTrainingHelperWithoutNextInputsFnOrAuxInputs(
self):
self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.5, use_next_input_layer=False,
sampling_probability=0.5, use_next_inputs_fn=False,
use_auxiliary_inputs=False)
def testStepWithScheduledOutputTrainingHelperWithNextInputLayer(self):
def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self):
self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.5, use_next_input_layer=True,
sampling_probability=0.5, use_next_inputs_fn=True,
use_auxiliary_inputs=False)
def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self):
self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.5, use_next_input_layer=False,
sampling_probability=0.5, use_next_inputs_fn=False,
use_auxiliary_inputs=True)
def testStepWithScheduledOutputTrainingHelperWithNextInputLayerAndAuxInputs(
def testStepWithScheduledOutputTrainingHelperWithNextInputsFnAndAuxInputs(
self):
self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.5, use_next_input_layer=True,
sampling_probability=0.5, use_next_inputs_fn=True,
use_auxiliary_inputs=True)
def testStepWithScheduledOutputTrainingHelperWithNoSampling(
self):
def testStepWithScheduledOutputTrainingHelperWithNoSampling(self):
self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.0, use_next_input_layer=True,
sampling_probability=0.0, use_next_inputs_fn=True,
use_auxiliary_inputs=True)
def testStepWithInferenceHelperCategorical(self):

View File

@ -27,7 +27,6 @@ from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
@ -351,7 +350,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
"""
def __init__(self, inputs, sequence_length, sampling_probability,
time_major=False, seed=None, next_input_layer=None,
time_major=False, seed=None, next_inputs_fn=None,
auxiliary_inputs=None, name=None):
"""Initializer.
@ -363,9 +362,9 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
time_major: Python bool. Whether the tensors in `inputs` are time major.
If `False` (default), they are assumed to be batch major.
seed: The sampling seed.
next_input_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
`tf.layers.Dense`. Optional layer to apply to the RNN output to create
the next input.
next_inputs_fn: (Optional) callable to apply to the RNN outputs to create
the next input when sampling. If `None` (default), the RNN outputs will
be used as the next inputs.
auxiliary_inputs: An optional (structure of) auxiliary input tensors with
a shape that matches `inputs` in all but (potentially) the final
dimension. These tensors will be concatenated to the sampled output or
@ -403,11 +402,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
self._seed = seed
if (next_input_layer is not None and not isinstance(next_input_layer,
layers_base.Layer)):
raise TypeError("next_input_layer must be a Layer, received: %s" %
type(next_input_layer))
self._next_input_layer = next_input_layer
self._next_inputs_fn = next_inputs_fn
super(ScheduledOutputTrainingHelper, self).__init__(
inputs=maybe_concatenated_inputs,
@ -453,7 +448,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
lambda x, y: array_ops.concat((x, y), -1),
outputs_, auxiliary_inputs)
if self._next_input_layer is None:
if self._next_inputs_fn is None:
return array_ops.where(
sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
base_next_inputs)
@ -466,7 +461,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
where_not_sampling)
sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
self._next_input_layer(outputs_sampling), where_sampling)
self._next_inputs_fn(outputs_sampling), where_sampling)
base_shape = array_ops.shape(base_next_inputs)
return (array_ops.scatter_nd(indices=where_sampling,