Allow the next input in ScheduledOutputTrainingHelper to be computed by any callable by removing requirement that it be a Layer.
PiperOrigin-RevId: 165042108
This commit is contained in:
parent
189e594b59
commit
f01a3191a8
@ -358,14 +358,12 @@ class BasicDecoderTest(test.TestCase):
|
||||
np.squeeze(inputs[batch_where_not_sampling, 1]))
|
||||
|
||||
def _testStepWithScheduledOutputTrainingHelper(
|
||||
self, sampling_probability, use_next_input_layer, use_auxiliary_inputs):
|
||||
self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
|
||||
sequence_length = [3, 4, 3, 1, 0]
|
||||
batch_size = 5
|
||||
max_time = 8
|
||||
input_depth = 7
|
||||
cell_depth = input_depth
|
||||
if use_next_input_layer:
|
||||
cell_depth = 6
|
||||
if use_auxiliary_inputs:
|
||||
auxiliary_input_depth = 4
|
||||
auxiliary_inputs = np.random.randn(
|
||||
@ -379,16 +377,20 @@ class BasicDecoderTest(test.TestCase):
|
||||
cell = rnn_cell.LSTMCell(cell_depth)
|
||||
sampling_probability = constant_op.constant(sampling_probability)
|
||||
|
||||
next_input_layer = None
|
||||
if use_next_input_layer:
|
||||
next_input_layer = layers_core.Dense(input_depth, use_bias=False)
|
||||
if use_next_inputs_fn:
|
||||
def next_inputs_fn(outputs):
|
||||
# Use deterministic function for test.
|
||||
samples = math_ops.argmax(outputs, axis=1)
|
||||
return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32)
|
||||
else:
|
||||
next_inputs_fn = None
|
||||
|
||||
helper = helper_py.ScheduledOutputTrainingHelper(
|
||||
inputs=inputs,
|
||||
sequence_length=sequence_length,
|
||||
sampling_probability=sampling_probability,
|
||||
time_major=False,
|
||||
next_input_layer=next_input_layer,
|
||||
next_inputs_fn=next_inputs_fn,
|
||||
auxiliary_inputs=auxiliary_inputs)
|
||||
|
||||
my_decoder = basic_decoder.BasicDecoder(
|
||||
@ -412,9 +414,8 @@ class BasicDecoderTest(test.TestCase):
|
||||
step_finished) = my_decoder.step(
|
||||
constant_op.constant(0), first_inputs, first_state)
|
||||
|
||||
if use_next_input_layer:
|
||||
output_after_next_input_layer = next_input_layer(
|
||||
step_outputs.rnn_output)
|
||||
if use_next_inputs_fn:
|
||||
output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)
|
||||
|
||||
batch_size_t = my_decoder.batch_size
|
||||
|
||||
@ -441,8 +442,8 @@ class BasicDecoderTest(test.TestCase):
|
||||
"step_next_inputs": step_next_inputs,
|
||||
"step_finished": step_finished
|
||||
}
|
||||
if use_next_input_layer:
|
||||
fetches["output_after_next_input_layer"] = output_after_next_input_layer
|
||||
if use_next_inputs_fn:
|
||||
fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn
|
||||
|
||||
sess_results = sess.run(fetches)
|
||||
|
||||
@ -461,8 +462,8 @@ class BasicDecoderTest(test.TestCase):
|
||||
np.array([]).reshape(batch_size, 0).astype(np.float32))
|
||||
|
||||
expected_next_sampling_inputs = np.concatenate(
|
||||
(sess_results["output_after_next_input_layer"][batch_where_sampling]
|
||||
if use_next_input_layer else
|
||||
(sess_results["output_after_next_inputs_fn"][batch_where_sampling]
|
||||
if use_next_inputs_fn else
|
||||
sess_results["step_outputs"].rnn_output[batch_where_sampling],
|
||||
auxiliary_inputs_to_concat[batch_where_sampling]),
|
||||
axis=-1)
|
||||
@ -477,32 +478,31 @@ class BasicDecoderTest(test.TestCase):
|
||||
auxiliary_inputs_to_concat[batch_where_not_sampling]),
|
||||
axis=-1))
|
||||
|
||||
def testStepWithScheduledOutputTrainingHelperWithoutNextInputLayerOrAuxInputs(
|
||||
def testStepWithScheduledOutputTrainingHelperWithoutNextInputsFnOrAuxInputs(
|
||||
self):
|
||||
self._testStepWithScheduledOutputTrainingHelper(
|
||||
sampling_probability=0.5, use_next_input_layer=False,
|
||||
sampling_probability=0.5, use_next_inputs_fn=False,
|
||||
use_auxiliary_inputs=False)
|
||||
|
||||
def testStepWithScheduledOutputTrainingHelperWithNextInputLayer(self):
|
||||
def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self):
|
||||
self._testStepWithScheduledOutputTrainingHelper(
|
||||
sampling_probability=0.5, use_next_input_layer=True,
|
||||
sampling_probability=0.5, use_next_inputs_fn=True,
|
||||
use_auxiliary_inputs=False)
|
||||
|
||||
def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self):
|
||||
self._testStepWithScheduledOutputTrainingHelper(
|
||||
sampling_probability=0.5, use_next_input_layer=False,
|
||||
sampling_probability=0.5, use_next_inputs_fn=False,
|
||||
use_auxiliary_inputs=True)
|
||||
|
||||
def testStepWithScheduledOutputTrainingHelperWithNextInputLayerAndAuxInputs(
|
||||
def testStepWithScheduledOutputTrainingHelperWithNextInputsFnAndAuxInputs(
|
||||
self):
|
||||
self._testStepWithScheduledOutputTrainingHelper(
|
||||
sampling_probability=0.5, use_next_input_layer=True,
|
||||
sampling_probability=0.5, use_next_inputs_fn=True,
|
||||
use_auxiliary_inputs=True)
|
||||
|
||||
def testStepWithScheduledOutputTrainingHelperWithNoSampling(
|
||||
self):
|
||||
def testStepWithScheduledOutputTrainingHelperWithNoSampling(self):
|
||||
self._testStepWithScheduledOutputTrainingHelper(
|
||||
sampling_probability=0.0, use_next_input_layer=True,
|
||||
sampling_probability=0.0, use_next_inputs_fn=True,
|
||||
use_auxiliary_inputs=True)
|
||||
|
||||
def testStepWithInferenceHelperCategorical(self):
|
||||
|
@ -27,7 +27,6 @@ from tensorflow.contrib.seq2seq.python.ops import decoder
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import base as layers_base
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
@ -351,7 +350,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
|
||||
"""
|
||||
|
||||
def __init__(self, inputs, sequence_length, sampling_probability,
|
||||
time_major=False, seed=None, next_input_layer=None,
|
||||
time_major=False, seed=None, next_inputs_fn=None,
|
||||
auxiliary_inputs=None, name=None):
|
||||
"""Initializer.
|
||||
|
||||
@ -363,9 +362,9 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
|
||||
time_major: Python bool. Whether the tensors in `inputs` are time major.
|
||||
If `False` (default), they are assumed to be batch major.
|
||||
seed: The sampling seed.
|
||||
next_input_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
|
||||
`tf.layers.Dense`. Optional layer to apply to the RNN output to create
|
||||
the next input.
|
||||
next_inputs_fn: (Optional) callable to apply to the RNN outputs to create
|
||||
the next input when sampling. If `None` (default), the RNN outputs will
|
||||
be used as the next inputs.
|
||||
auxiliary_inputs: An optional (structure of) auxiliary input tensors with
|
||||
a shape that matches `inputs` in all but (potentially) the final
|
||||
dimension. These tensors will be concatenated to the sampled output or
|
||||
@ -403,11 +402,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
|
||||
|
||||
self._seed = seed
|
||||
|
||||
if (next_input_layer is not None and not isinstance(next_input_layer,
|
||||
layers_base.Layer)):
|
||||
raise TypeError("next_input_layer must be a Layer, received: %s" %
|
||||
type(next_input_layer))
|
||||
self._next_input_layer = next_input_layer
|
||||
self._next_inputs_fn = next_inputs_fn
|
||||
|
||||
super(ScheduledOutputTrainingHelper, self).__init__(
|
||||
inputs=maybe_concatenated_inputs,
|
||||
@ -453,7 +448,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
|
||||
lambda x, y: array_ops.concat((x, y), -1),
|
||||
outputs_, auxiliary_inputs)
|
||||
|
||||
if self._next_input_layer is None:
|
||||
if self._next_inputs_fn is None:
|
||||
return array_ops.where(
|
||||
sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
|
||||
base_next_inputs)
|
||||
@ -466,7 +461,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
|
||||
inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
|
||||
where_not_sampling)
|
||||
sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
|
||||
self._next_input_layer(outputs_sampling), where_sampling)
|
||||
self._next_inputs_fn(outputs_sampling), where_sampling)
|
||||
|
||||
base_shape = array_ops.shape(base_next_inputs)
|
||||
return (array_ops.scatter_nd(indices=where_sampling,
|
||||
|
Loading…
Reference in New Issue
Block a user