Allow the next input in ScheduledOutputTrainingHelper to be computed by any callable, removing the requirement that it be a Layer.
PiperOrigin-RevId: 165042108
parent 189e594b59
commit f01a3191a8
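For context, a minimal usage sketch of the relaxed API (illustrative only, not part of this commit). It assumes the TF 1.x contrib endpoint tf.contrib.seq2seq.ScheduledOutputTrainingHelper and reuses the deterministic callable from the test below; any callable mapping RNN outputs to the next decoder inputs is now accepted in place of a tf.layers.Layer.

    import numpy as np
    import tensorflow as tf

    batch_size, max_time, depth = 5, 8, 7
    inputs = tf.constant(
        np.random.randn(batch_size, max_time, depth), dtype=tf.float32)

    def next_inputs_fn(outputs):
      # Deterministic callable: feed back a one-hot of the argmax of the outputs.
      samples = tf.argmax(outputs, axis=1)
      return tf.one_hot(samples, depth, dtype=tf.float32)

    helper = tf.contrib.seq2seq.ScheduledOutputTrainingHelper(
        inputs=inputs,
        sequence_length=[3, 4, 3, 1, 0],
        sampling_probability=0.5,
        next_inputs_fn=next_inputs_fn)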
tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py

@@ -358,14 +358,12 @@ class BasicDecoderTest(test.TestCase):
           np.squeeze(inputs[batch_where_not_sampling, 1]))
 
   def _testStepWithScheduledOutputTrainingHelper(
-      self, sampling_probability, use_next_input_layer, use_auxiliary_inputs):
+      self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
     sequence_length = [3, 4, 3, 1, 0]
     batch_size = 5
     max_time = 8
     input_depth = 7
     cell_depth = input_depth
-    if use_next_input_layer:
-      cell_depth = 6
     if use_auxiliary_inputs:
       auxiliary_input_depth = 4
       auxiliary_inputs = np.random.randn(
@@ -379,16 +377,20 @@ class BasicDecoderTest(test.TestCase):
     cell = rnn_cell.LSTMCell(cell_depth)
     sampling_probability = constant_op.constant(sampling_probability)
 
-    next_input_layer = None
-    if use_next_input_layer:
-      next_input_layer = layers_core.Dense(input_depth, use_bias=False)
+    if use_next_inputs_fn:
+      def next_inputs_fn(outputs):
+        # Use deterministic function for test.
+        samples = math_ops.argmax(outputs, axis=1)
+        return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32)
+    else:
+      next_inputs_fn = None
 
     helper = helper_py.ScheduledOutputTrainingHelper(
         inputs=inputs,
         sequence_length=sequence_length,
         sampling_probability=sampling_probability,
         time_major=False,
-        next_input_layer=next_input_layer,
+        next_inputs_fn=next_inputs_fn,
         auxiliary_inputs=auxiliary_inputs)
 
     my_decoder = basic_decoder.BasicDecoder(
@@ -412,9 +414,8 @@ class BasicDecoderTest(test.TestCase):
      step_finished) = my_decoder.step(
          constant_op.constant(0), first_inputs, first_state)
 
-    if use_next_input_layer:
-      output_after_next_input_layer = next_input_layer(
-          step_outputs.rnn_output)
+    if use_next_inputs_fn:
+      output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)
 
     batch_size_t = my_decoder.batch_size
 
@@ -441,8 +442,8 @@ class BasicDecoderTest(test.TestCase):
           "step_next_inputs": step_next_inputs,
           "step_finished": step_finished
       }
-      if use_next_input_layer:
-        fetches["output_after_next_input_layer"] = output_after_next_input_layer
+      if use_next_inputs_fn:
+        fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn
 
       sess_results = sess.run(fetches)
 
@@ -461,8 +462,8 @@ class BasicDecoderTest(test.TestCase):
           np.array([]).reshape(batch_size, 0).astype(np.float32))
 
       expected_next_sampling_inputs = np.concatenate(
-          (sess_results["output_after_next_input_layer"][batch_where_sampling]
-           if use_next_input_layer else
+          (sess_results["output_after_next_inputs_fn"][batch_where_sampling]
+           if use_next_inputs_fn else
            sess_results["step_outputs"].rnn_output[batch_where_sampling],
            auxiliary_inputs_to_concat[batch_where_sampling]),
           axis=-1)
@@ -477,32 +478,31 @@ class BasicDecoderTest(test.TestCase):
            auxiliary_inputs_to_concat[batch_where_not_sampling]),
           axis=-1))
 
-  def testStepWithScheduledOutputTrainingHelperWithoutNextInputLayerOrAuxInputs(
+  def testStepWithScheduledOutputTrainingHelperWithoutNextInputsFnOrAuxInputs(
      self):
     self._testStepWithScheduledOutputTrainingHelper(
-        sampling_probability=0.5, use_next_input_layer=False,
+        sampling_probability=0.5, use_next_inputs_fn=False,
         use_auxiliary_inputs=False)
 
-  def testStepWithScheduledOutputTrainingHelperWithNextInputLayer(self):
+  def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self):
     self._testStepWithScheduledOutputTrainingHelper(
-        sampling_probability=0.5, use_next_input_layer=True,
+        sampling_probability=0.5, use_next_inputs_fn=True,
         use_auxiliary_inputs=False)
 
   def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self):
     self._testStepWithScheduledOutputTrainingHelper(
-        sampling_probability=0.5, use_next_input_layer=False,
+        sampling_probability=0.5, use_next_inputs_fn=False,
         use_auxiliary_inputs=True)
 
-  def testStepWithScheduledOutputTrainingHelperWithNextInputLayerAndAuxInputs(
+  def testStepWithScheduledOutputTrainingHelperWithNextInputsFnAndAuxInputs(
      self):
     self._testStepWithScheduledOutputTrainingHelper(
-        sampling_probability=0.5, use_next_input_layer=True,
+        sampling_probability=0.5, use_next_inputs_fn=True,
         use_auxiliary_inputs=True)
 
-  def testStepWithScheduledOutputTrainingHelperWithNoSampling(
-      self):
+  def testStepWithScheduledOutputTrainingHelperWithNoSampling(self):
     self._testStepWithScheduledOutputTrainingHelper(
-        sampling_probability=0.0, use_next_input_layer=True,
+        sampling_probability=0.0, use_next_inputs_fn=True,
         use_auxiliary_inputs=True)
 
   def testStepWithInferenceHelperCategorical(self):
tensorflow/contrib/seq2seq/python/ops/helper.py

@@ -27,7 +27,6 @@ from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.layers import base as layers_base
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
@@ -351,7 +350,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
   """
 
   def __init__(self, inputs, sequence_length, sampling_probability,
-               time_major=False, seed=None, next_input_layer=None,
+               time_major=False, seed=None, next_inputs_fn=None,
                auxiliary_inputs=None, name=None):
     """Initializer.
 
@@ -363,9 +362,9 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
       time_major: Python bool. Whether the tensors in `inputs` are time major.
         If `False` (default), they are assumed to be batch major.
       seed: The sampling seed.
-      next_input_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
-        `tf.layers.Dense`. Optional layer to apply to the RNN output to create
-        the next input.
+      next_inputs_fn: (Optional) callable to apply to the RNN outputs to create
+        the next input when sampling. If `None` (default), the RNN outputs will
+        be used as the next inputs.
       auxiliary_inputs: An optional (structure of) auxiliary input tensors with
         a shape that matches `inputs` in all but (potentially) the final
         dimension. These tensors will be concatenated to the sampled output or
@@ -403,11 +402,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
 
       self._seed = seed
 
-      if (next_input_layer is not None and not isinstance(next_input_layer,
-                                                          layers_base.Layer)):
-        raise TypeError("next_input_layer must be a Layer, received: %s" %
-                        type(next_input_layer))
-      self._next_input_layer = next_input_layer
+      self._next_inputs_fn = next_inputs_fn
 
       super(ScheduledOutputTrainingHelper, self).__init__(
           inputs=maybe_concatenated_inputs,
@@ -453,7 +448,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
             lambda x, y: array_ops.concat((x, y), -1),
             outputs_, auxiliary_inputs)
 
-      if self._next_input_layer is None:
+      if self._next_inputs_fn is None:
         return array_ops.where(
             sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
             base_next_inputs)
@@ -466,7 +461,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
       inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                 where_not_sampling)
       sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
-          self._next_input_layer(outputs_sampling), where_sampling)
+          self._next_inputs_fn(outputs_sampling), where_sampling)
 
       base_shape = array_ops.shape(base_next_inputs)
       return (array_ops.scatter_nd(indices=where_sampling,
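A possible migration note (an assumption, not stated in this commit): a tf.layers Layer instance is itself callable, so code that previously passed next_input_layer=layer should generally be able to pass the same object as next_inputs_fn.

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.layers import core as layers_core

    inputs = tf.constant(np.random.randn(5, 8, 7), dtype=tf.float32)

    # The Dense instance is callable, so the old `next_input_layer` object can be
    # passed unchanged as `next_inputs_fn`.
    projection = layers_core.Dense(7, use_bias=False)

    helper = tf.contrib.seq2seq.ScheduledOutputTrainingHelper(
        inputs=inputs,
        sequence_length=[3, 4, 3, 1, 0],
        sampling_probability=0.5,
        next_inputs_fn=projection)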