Allow the next input in ScheduledOutputTrainingHelper to be computed by any callable by removing requirement that it be a Layer.

PiperOrigin-RevId: 165042108
This commit is contained in:
Adam Roberts 2017-08-11 16:47:52 -07:00 committed by TensorFlower Gardener
parent 189e594b59
commit f01a3191a8
2 changed files with 31 additions and 36 deletions

View File

@ -358,14 +358,12 @@ class BasicDecoderTest(test.TestCase):
np.squeeze(inputs[batch_where_not_sampling, 1])) np.squeeze(inputs[batch_where_not_sampling, 1]))
def _testStepWithScheduledOutputTrainingHelper( def _testStepWithScheduledOutputTrainingHelper(
self, sampling_probability, use_next_input_layer, use_auxiliary_inputs): self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
sequence_length = [3, 4, 3, 1, 0] sequence_length = [3, 4, 3, 1, 0]
batch_size = 5 batch_size = 5
max_time = 8 max_time = 8
input_depth = 7 input_depth = 7
cell_depth = input_depth cell_depth = input_depth
if use_next_input_layer:
cell_depth = 6
if use_auxiliary_inputs: if use_auxiliary_inputs:
auxiliary_input_depth = 4 auxiliary_input_depth = 4
auxiliary_inputs = np.random.randn( auxiliary_inputs = np.random.randn(
@ -379,16 +377,20 @@ class BasicDecoderTest(test.TestCase):
cell = rnn_cell.LSTMCell(cell_depth) cell = rnn_cell.LSTMCell(cell_depth)
sampling_probability = constant_op.constant(sampling_probability) sampling_probability = constant_op.constant(sampling_probability)
next_input_layer = None if use_next_inputs_fn:
if use_next_input_layer: def next_inputs_fn(outputs):
next_input_layer = layers_core.Dense(input_depth, use_bias=False) # Use deterministic function for test.
samples = math_ops.argmax(outputs, axis=1)
return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32)
else:
next_inputs_fn = None
helper = helper_py.ScheduledOutputTrainingHelper( helper = helper_py.ScheduledOutputTrainingHelper(
inputs=inputs, inputs=inputs,
sequence_length=sequence_length, sequence_length=sequence_length,
sampling_probability=sampling_probability, sampling_probability=sampling_probability,
time_major=False, time_major=False,
next_input_layer=next_input_layer, next_inputs_fn=next_inputs_fn,
auxiliary_inputs=auxiliary_inputs) auxiliary_inputs=auxiliary_inputs)
my_decoder = basic_decoder.BasicDecoder( my_decoder = basic_decoder.BasicDecoder(
@ -412,9 +414,8 @@ class BasicDecoderTest(test.TestCase):
step_finished) = my_decoder.step( step_finished) = my_decoder.step(
constant_op.constant(0), first_inputs, first_state) constant_op.constant(0), first_inputs, first_state)
if use_next_input_layer: if use_next_inputs_fn:
output_after_next_input_layer = next_input_layer( output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)
step_outputs.rnn_output)
batch_size_t = my_decoder.batch_size batch_size_t = my_decoder.batch_size
@ -441,8 +442,8 @@ class BasicDecoderTest(test.TestCase):
"step_next_inputs": step_next_inputs, "step_next_inputs": step_next_inputs,
"step_finished": step_finished "step_finished": step_finished
} }
if use_next_input_layer: if use_next_inputs_fn:
fetches["output_after_next_input_layer"] = output_after_next_input_layer fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn
sess_results = sess.run(fetches) sess_results = sess.run(fetches)
@ -461,8 +462,8 @@ class BasicDecoderTest(test.TestCase):
np.array([]).reshape(batch_size, 0).astype(np.float32)) np.array([]).reshape(batch_size, 0).astype(np.float32))
expected_next_sampling_inputs = np.concatenate( expected_next_sampling_inputs = np.concatenate(
(sess_results["output_after_next_input_layer"][batch_where_sampling] (sess_results["output_after_next_inputs_fn"][batch_where_sampling]
if use_next_input_layer else if use_next_inputs_fn else
sess_results["step_outputs"].rnn_output[batch_where_sampling], sess_results["step_outputs"].rnn_output[batch_where_sampling],
auxiliary_inputs_to_concat[batch_where_sampling]), auxiliary_inputs_to_concat[batch_where_sampling]),
axis=-1) axis=-1)
@ -477,32 +478,31 @@ class BasicDecoderTest(test.TestCase):
auxiliary_inputs_to_concat[batch_where_not_sampling]), auxiliary_inputs_to_concat[batch_where_not_sampling]),
axis=-1)) axis=-1))
def testStepWithScheduledOutputTrainingHelperWithoutNextInputLayerOrAuxInputs( def testStepWithScheduledOutputTrainingHelperWithoutNextInputsFnOrAuxInputs(
self): self):
self._testStepWithScheduledOutputTrainingHelper( self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.5, use_next_input_layer=False, sampling_probability=0.5, use_next_inputs_fn=False,
use_auxiliary_inputs=False) use_auxiliary_inputs=False)
def testStepWithScheduledOutputTrainingHelperWithNextInputLayer(self): def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self):
self._testStepWithScheduledOutputTrainingHelper( self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.5, use_next_input_layer=True, sampling_probability=0.5, use_next_inputs_fn=True,
use_auxiliary_inputs=False) use_auxiliary_inputs=False)
def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self): def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self):
self._testStepWithScheduledOutputTrainingHelper( self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.5, use_next_input_layer=False, sampling_probability=0.5, use_next_inputs_fn=False,
use_auxiliary_inputs=True) use_auxiliary_inputs=True)
def testStepWithScheduledOutputTrainingHelperWithNextInputLayerAndAuxInputs( def testStepWithScheduledOutputTrainingHelperWithNextInputsFnAndAuxInputs(
self): self):
self._testStepWithScheduledOutputTrainingHelper( self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.5, use_next_input_layer=True, sampling_probability=0.5, use_next_inputs_fn=True,
use_auxiliary_inputs=True) use_auxiliary_inputs=True)
def testStepWithScheduledOutputTrainingHelperWithNoSampling( def testStepWithScheduledOutputTrainingHelperWithNoSampling(self):
self):
self._testStepWithScheduledOutputTrainingHelper( self._testStepWithScheduledOutputTrainingHelper(
sampling_probability=0.0, use_next_input_layer=True, sampling_probability=0.0, use_next_inputs_fn=True,
use_auxiliary_inputs=True) use_auxiliary_inputs=True)
def testStepWithInferenceHelperCategorical(self): def testStepWithInferenceHelperCategorical(self):

View File

@ -27,7 +27,6 @@ from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import embedding_ops
@ -351,7 +350,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
""" """
def __init__(self, inputs, sequence_length, sampling_probability, def __init__(self, inputs, sequence_length, sampling_probability,
time_major=False, seed=None, next_input_layer=None, time_major=False, seed=None, next_inputs_fn=None,
auxiliary_inputs=None, name=None): auxiliary_inputs=None, name=None):
"""Initializer. """Initializer.
@ -363,9 +362,9 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
time_major: Python bool. Whether the tensors in `inputs` are time major. time_major: Python bool. Whether the tensors in `inputs` are time major.
If `False` (default), they are assumed to be batch major. If `False` (default), they are assumed to be batch major.
seed: The sampling seed. seed: The sampling seed.
next_input_layer: (Optional) An instance of `tf.layers.Layer`, i.e., next_inputs_fn: (Optional) callable to apply to the RNN outputs to create
`tf.layers.Dense`. Optional layer to apply to the RNN output to create the next input when sampling. If `None` (default), the RNN outputs will
the next input. be used as the next inputs.
auxiliary_inputs: An optional (structure of) auxiliary input tensors with auxiliary_inputs: An optional (structure of) auxiliary input tensors with
a shape that matches `inputs` in all but (potentially) the final a shape that matches `inputs` in all but (potentially) the final
dimension. These tensors will be concatenated to the sampled output or dimension. These tensors will be concatenated to the sampled output or
@ -403,11 +402,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
self._seed = seed self._seed = seed
if (next_input_layer is not None and not isinstance(next_input_layer, self._next_inputs_fn = next_inputs_fn
layers_base.Layer)):
raise TypeError("next_input_layer must be a Layer, received: %s" %
type(next_input_layer))
self._next_input_layer = next_input_layer
super(ScheduledOutputTrainingHelper, self).__init__( super(ScheduledOutputTrainingHelper, self).__init__(
inputs=maybe_concatenated_inputs, inputs=maybe_concatenated_inputs,
@ -453,7 +448,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
lambda x, y: array_ops.concat((x, y), -1), lambda x, y: array_ops.concat((x, y), -1),
outputs_, auxiliary_inputs) outputs_, auxiliary_inputs)
if self._next_input_layer is None: if self._next_inputs_fn is None:
return array_ops.where( return array_ops.where(
sample_ids, maybe_concatenate_auxiliary_inputs(outputs), sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
base_next_inputs) base_next_inputs)
@ -466,7 +461,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
inputs_not_sampling = array_ops.gather_nd(base_next_inputs, inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
where_not_sampling) where_not_sampling)
sampled_next_inputs = maybe_concatenate_auxiliary_inputs( sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
self._next_input_layer(outputs_sampling), where_sampling) self._next_inputs_fn(outputs_sampling), where_sampling)
base_shape = array_ops.shape(base_next_inputs) base_shape = array_ops.shape(base_next_inputs)
return (array_ops.scatter_nd(indices=where_sampling, return (array_ops.scatter_nd(indices=where_sampling,