Allow specification of sample_id_shape and sample_id_dtype in seq2seq.BasicDecoder and add a new InferenceHelper.
PiperOrigin-RevId: 165019969
This commit is contained in:
parent
37c54be0ce
commit
e9a8d75bc4
@ -47,6 +47,7 @@ _allowed_symbols = [
|
||||
"FinalBeamSearchDecoderOutput",
|
||||
"gather_tree",
|
||||
"GreedyEmbeddingHelper",
|
||||
"InferenceHelper",
|
||||
"SampleEmbeddingHelper",
|
||||
"ScheduledEmbeddingTrainingHelper",
|
||||
"ScheduledOutputTrainingHelper",
|
||||
|
@ -23,14 +23,19 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.distributions import bernoulli
|
||||
from tensorflow.python.ops.distributions import categorical
|
||||
from tensorflow.python.platform import test
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
@ -500,5 +505,166 @@ class BasicDecoderTest(test.TestCase):
|
||||
sampling_probability=0.0, use_next_input_layer=True,
|
||||
use_auxiliary_inputs=True)
|
||||
|
||||
def testStepWithInferenceHelperCategorical(self):
  """Runs one `BasicDecoder` step driven by a categorical `InferenceHelper`.

  The helper samples one scalar int32 label per batch entry from the cell
  output (used as logits), feeds the label back as a one-hot vector, and
  marks an entry finished when the sampled label equals `end_token`.
  """
  batch_size = 5
  vocabulary_size = 7
  cell_depth = vocabulary_size
  start_token = 0
  end_token = 6
  state_shape = (batch_size, cell_depth)

  start_inputs = array_ops.one_hot(
      np.ones(batch_size) * start_token, vocabulary_size)

  def sample_fn(logits):
    # Draw one label per batch entry from the categorical over the logits.
    return categorical.Categorical(logits=logits).sample()

  def next_inputs_fn(labels):
    # The next inputs are a one-hot float encoding of the sampled labels.
    return array_ops.one_hot(labels, vocabulary_size, dtype=dtypes.float32)

  def end_fn(sample_ids):
    # An entry is finished once it samples the end token.
    return math_ops.equal(sample_ids, end_token)

  with self.test_session(use_gpu=True) as sess:
    scope_initializer = init_ops.constant_initializer(0.01)
    with variable_scope.variable_scope(
        "testStepWithInferenceHelper", initializer=scope_initializer):
      cell = rnn_cell.LSTMCell(vocabulary_size)
      helper = helper_py.InferenceHelper(
          sample_fn,
          sample_shape=(),
          sample_dtype=dtypes.int32,
          start_inputs=start_inputs,
          end_fn=end_fn,
          next_inputs_fn=next_inputs_fn)
      zero_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell, helper=helper, initial_state=zero_state)

      # The decoder must report the sample shape/dtype given to the helper.
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      expected_size = basic_decoder.BasicDecoderOutput(
          cell_depth, tensor_shape.TensorShape([]))
      expected_dtype = basic_decoder.BasicDecoderOutput(
          dtypes.float32, dtypes.int32)
      self.assertEqual(expected_size, output_size)
      self.assertEqual(expected_dtype, output_dtype)

      first_finished, first_inputs, first_state = my_decoder.initialize()
      step_result = my_decoder.step(
          constant_op.constant(0), first_inputs, first_state)
      step_outputs, step_state, step_next_inputs, step_finished = step_result
      batch_size_t = my_decoder.batch_size

      # Static structure and shape checks on the graph tensors.
      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual(state_shape, step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      for lstm_state in (first_state, step_state):
        for index in (0, 1):
          self.assertEqual(state_shape, lstm_state[index].get_shape())

      sess.run(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      sess_results = sess.run(fetches)

      # Sampling is random, so expectations are derived from the ids drawn.
      sample_ids = sess_results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      expected_step_finished = (sample_ids == end_token)
      expected_step_next_inputs = np.zeros((batch_size, vocabulary_size))
      expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0
      self.assertAllEqual(
          expected_step_finished, sess_results["step_finished"])
      self.assertAllEqual(
          expected_step_next_inputs, sess_results["step_next_inputs"])
|
||||
|
||||
def testStepWithInferenceHelperMultilabel(self):
  """Runs one `BasicDecoder` step driven by a multilabel `InferenceHelper`.

  The helper samples a boolean vector of length `cell_depth` per batch entry
  (independent Bernoullis over the cell output used as logits), feeds that
  vector back cast to float, and marks an entry finished when the
  `end_token` position of its sample is set.
  """
  batch_size = 5
  vocabulary_size = 7
  cell_depth = vocabulary_size
  start_token = 0
  end_token = 6
  state_shape = (batch_size, cell_depth)

  start_inputs = array_ops.one_hot(
      np.ones(batch_size) * start_token, vocabulary_size)

  def sample_fn(logits):
    # Sample independent bernoullis from the logits, one per position.
    return bernoulli.Bernoulli(logits=logits, dtype=dtypes.bool).sample()

  # The next inputs are the sampled boolean labels cast to float.
  next_inputs_fn = math_ops.to_float

  def end_fn(sample_ids):
    # An entry is finished once the end-token position of its sample is set.
    return sample_ids[:, end_token]

  with self.test_session(use_gpu=True) as sess:
    scope_initializer = init_ops.constant_initializer(0.01)
    with variable_scope.variable_scope(
        "testStepWithInferenceHelper", initializer=scope_initializer):
      cell = rnn_cell.LSTMCell(vocabulary_size)
      helper = helper_py.InferenceHelper(
          sample_fn,
          sample_shape=[cell_depth],
          sample_dtype=dtypes.bool,
          start_inputs=start_inputs,
          end_fn=end_fn,
          next_inputs_fn=next_inputs_fn)
      zero_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell, helper=helper, initial_state=zero_state)

      # The decoder must report the sample shape/dtype given to the helper.
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      expected_size = basic_decoder.BasicDecoderOutput(
          cell_depth, cell_depth)
      expected_dtype = basic_decoder.BasicDecoderOutput(
          dtypes.float32, dtypes.bool)
      self.assertEqual(expected_size, output_size)
      self.assertEqual(expected_dtype, output_dtype)

      first_finished, first_inputs, first_state = my_decoder.initialize()
      step_result = my_decoder.step(
          constant_op.constant(0), first_inputs, first_state)
      step_outputs, step_state, step_next_inputs, step_finished = step_result
      batch_size_t = my_decoder.batch_size

      # Static structure and shape checks on the graph tensors.
      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual(state_shape, step_outputs[0].get_shape())
      self.assertEqual(state_shape, step_outputs[1].get_shape())
      for lstm_state in (first_state, step_state):
        for index in (0, 1):
          self.assertEqual(state_shape, lstm_state[index].get_shape())

      sess.run(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      sess_results = sess.run(fetches)

      # Sampling is random, so expectations are derived from the ids drawn.
      sample_ids = sess_results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      expected_step_finished = sample_ids[:, end_token]
      expected_step_next_inputs = sample_ids.astype(np.float32)
      self.assertAllEqual(
          expected_step_finished, sess_results["step_finished"])
      self.assertAllEqual(
          expected_step_next_inputs, sess_results["step_next_inputs"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -23,7 +23,6 @@ import collections
|
||||
|
||||
from tensorflow.contrib.seq2seq.python.ops import decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import base as layers_base
|
||||
@ -54,7 +53,7 @@ class BasicDecoder(decoder.Decoder):
|
||||
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
|
||||
The initial state of the RNNCell.
|
||||
output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
|
||||
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
|
||||
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
|
||||
to storing the result or sampling.
|
||||
|
||||
Raises:
|
||||
@ -100,17 +99,17 @@ class BasicDecoder(decoder.Decoder):
|
||||
# Return the cell output and the id
|
||||
return BasicDecoderOutput(
|
||||
rnn_output=self._rnn_output_size(),
|
||||
sample_id=tensor_shape.TensorShape([]))
|
||||
sample_id=self._helper.sample_ids_shape)
|
||||
|
||||
@property
|
||||
def output_dtype(self):
|
||||
# Assume the dtype of the cell is the output_size structure
|
||||
# containing the input_state's first component's dtype.
|
||||
# Return that structure and int32 (the id)
|
||||
# Return that structure and the sample_ids_dtype from the helper.
|
||||
dtype = nest.flatten(self._initial_state)[0].dtype
|
||||
return BasicDecoderOutput(
|
||||
nest.map_structure(lambda _: dtype, self._rnn_output_size()),
|
||||
dtypes.int32)
|
||||
self._helper.sample_ids_dtype)
|
||||
|
||||
def initialize(self, name=None):
|
||||
"""Initialize the decoder.
|
||||
|
@ -26,6 +26,7 @@ import six
|
||||
from tensorflow.contrib.seq2seq.python.ops import decoder
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import base as layers_base
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -45,6 +46,7 @@ __all__ = [
|
||||
"CustomHelper",
|
||||
"ScheduledEmbeddingTrainingHelper",
|
||||
"ScheduledOutputTrainingHelper",
|
||||
"InferenceHelper",
|
||||
]
|
||||
|
||||
_transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access
|
||||
@ -71,6 +73,22 @@ class Helper(object):
|
||||
"""
|
||||
raise NotImplementedError("batch_size has not been implemented")
|
||||
|
||||
@abc.abstractproperty
def sample_ids_shape(self):
  """Shape of the tensor returned by `sample`, without the batch dimension.

  Subclasses must override this property.

  Returns:
    A `TensorShape`.

  Raises:
    NotImplementedError: Always, in this base implementation.
  """
  message = "sample_ids_shape has not been implemented"
  raise NotImplementedError(message)
|
||||
|
||||
@abc.abstractproperty
def sample_ids_dtype(self):
  """DType of the tensor returned by `sample`.

  Subclasses must override this property.

  Returns:
    A DType.

  Raises:
    NotImplementedError: Always, in this base implementation.
  """
  message = "sample_ids_dtype has not been implemented"
  raise NotImplementedError(message)
|
||||
|
||||
@abc.abstractmethod
|
||||
def initialize(self, name=None):
|
||||
"""Returns `(initial_finished, initial_inputs)`."""
|
||||
@ -90,7 +108,8 @@ class Helper(object):
|
||||
class CustomHelper(Helper):
|
||||
"""Base abstract class that allows the user to customize sampling."""
|
||||
|
||||
def __init__(self, initialize_fn, sample_fn, next_inputs_fn):
|
||||
def __init__(self, initialize_fn, sample_fn, next_inputs_fn,
|
||||
sample_ids_shape=None, sample_ids_dtype=None):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
@ -100,11 +119,17 @@ class CustomHelper(Helper):
|
||||
and emits tensor `sample_ids`.
|
||||
next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)`
|
||||
and emits `(finished, next_inputs, next_state)`.
|
||||
sample_ids_shape: Either a list of integers, or a 1-D Tensor of type
|
||||
`int32`, the shape of each value in the `sample_ids` batch. Defaults to
|
||||
a scalar.
|
||||
sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to int32.
|
||||
"""
|
||||
self._initialize_fn = initialize_fn
|
||||
self._sample_fn = sample_fn
|
||||
self._next_inputs_fn = next_inputs_fn
|
||||
self._batch_size = None
|
||||
self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or [])
|
||||
self._sample_ids_dtype = sample_ids_dtype or dtypes.int32
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
@ -112,6 +137,14 @@ class CustomHelper(Helper):
|
||||
raise ValueError("batch_size accessed before initialize was called")
|
||||
return self._batch_size
|
||||
|
||||
@property
def sample_ids_shape(self):
  """Returns the sample-id shape stored on this helper at construction."""
  stored_shape = self._sample_ids_shape
  return stored_shape
|
||||
|
||||
@property
def sample_ids_dtype(self):
  """Returns the sample-id dtype stored on this helper at construction."""
  stored_dtype = self._sample_ids_dtype
  return stored_dtype
|
||||
|
||||
def initialize(self, name=None):
|
||||
with ops.name_scope(name, "%sInitialize" % type(self).__name__):
|
||||
(finished, next_inputs) = self._initialize_fn()
|
||||
@ -172,6 +205,14 @@ class TrainingHelper(Helper):
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
def sample_ids_shape(self):
  """Returns a scalar `TensorShape`: one id per batch entry."""
  scalar_shape = tensor_shape.TensorShape([])
  return scalar_shape
|
||||
|
||||
@property
def sample_ids_dtype(self):
  """Returns `dtypes.int32`, the dtype this helper reports for sample ids."""
  id_dtype = dtypes.int32
  return id_dtype
|
||||
|
||||
def initialize(self, name=None):
|
||||
with ops.name_scope(name, "TrainingHelperInitialize"):
|
||||
finished = math_ops.equal(0, self._sequence_length)
|
||||
@ -485,6 +526,14 @@ class GreedyEmbeddingHelper(Helper):
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
def sample_ids_shape(self):
  """Returns a scalar `TensorShape`: one id per batch entry."""
  scalar_shape = tensor_shape.TensorShape([])
  return scalar_shape
|
||||
|
||||
@property
def sample_ids_dtype(self):
  """Returns `dtypes.int32`, the dtype this helper reports for sample ids."""
  id_dtype = dtypes.int32
  return id_dtype
|
||||
|
||||
def initialize(self, name=None):
|
||||
finished = array_ops.tile([False], [self._batch_size])
|
||||
return (finished, self._start_inputs)
|
||||
@ -562,3 +611,61 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
|
||||
sample_ids = sample_id_sampler.sample(seed=self._seed)
|
||||
|
||||
return sample_ids
|
||||
|
||||
|
||||
class InferenceHelper(Helper):
  """A helper to use during inference with a custom sampling function."""

  def __init__(self, sample_fn, sample_shape, sample_dtype,
               start_inputs, end_fn, next_inputs_fn=None):
    """Initializer.

    Args:
      sample_fn: A callable that takes `outputs` and emits tensor `sample_ids`.
      sample_shape: The shape of each sample in the batch returned by
        `sample_fn`, i.e. excluding the batch dimension. It is passed to
        `tensor_shape.TensorShape`, e.g. `[]` for scalar ids or `[depth]` for
        vector-valued samples.
        NOTE(review): earlier docs also allowed a 1-D `int32` Tensor here;
        confirm `TensorShape` accepts that before relying on it.
      sample_dtype: the dtype of the sample returned by `sample_fn`.
      start_inputs: The initial batch of inputs.
      end_fn: A callable that takes `sample_ids` and emits a `bool` vector
        shaped `[batch_size]` indicating whether each sample is an end token.
      next_inputs_fn: (Optional) A callable that takes `sample_ids` and returns
        the next batch of inputs. If not provided, `sample_ids` is used as the
        next batch of inputs.
    """
    self._sample_fn = sample_fn
    self._end_fn = end_fn
    self._sample_shape = tensor_shape.TensorShape(sample_shape)
    self._sample_dtype = sample_dtype
    self._next_inputs_fn = next_inputs_fn
    # Convert first and read the batch size off the converted tensor, so a
    # non-Tensor `start_inputs` is converted to a tensor only once.
    self._start_inputs = ops.convert_to_tensor(
        start_inputs, name="start_inputs")
    self._batch_size = array_ops.shape(self._start_inputs)[0]

  @property
  def batch_size(self):
    """Scalar tensor: size of the first dimension of `start_inputs`."""
    return self._batch_size

  @property
  def sample_ids_shape(self):
    """`TensorShape` built from the `sample_shape` constructor argument."""
    return self._sample_shape

  @property
  def sample_ids_dtype(self):
    """The `sample_dtype` given to the constructor."""
    return self._sample_dtype

  def initialize(self, name=None):
    """Returns `(initial_finished, initial_inputs)`.

    No batch entry starts out finished, and the initial inputs are the
    `start_inputs` given to the constructor. `name` is currently unused.
    """
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

  def sample(self, time, outputs, state, name=None):
    """Returns `sample_fn(outputs)`; `time`, `state` and `name` are unused."""
    del time, state  # unused by sample
    return self._sample_fn(outputs)

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Returns `(finished, next_inputs, next_state)`.

    `finished` is `end_fn(sample_ids)`, `next_inputs` is
    `next_inputs_fn(sample_ids)` (or `sample_ids` itself when no
    `next_inputs_fn` was given), and `state` is passed through unchanged.
    """
    del time, outputs  # unused by next_inputs
    if self._next_inputs_fn is None:
      next_inputs = sample_ids
    else:
      next_inputs = self._next_inputs_fn(sample_ids)
    finished = self._end_fn(sample_ids)
    return (finished, next_inputs, state)
|
||||
|
Loading…
Reference in New Issue
Block a user