Add BeamSearchDecoderV2 which can be used as a keras layer.

PiperOrigin-RevId: 233982439
Scott Zhu 2019-02-14 10:28:30 -08:00 committed by TensorFlower Gardener
parent 03aa9d18f5
commit 3c9b46c245
3 changed files with 528 additions and 148 deletions
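
In contrast to the existing `BeamSearchDecoder`, whose constructor consumes the batch-dependent tensors, the V2 decoder follows the Keras layer convention: construction takes only network-level configuration, and the per-batch tensors are supplied when the layer is called. A minimal usage sketch, adapted from the new test below (`cell`, `embedding`, `start_tokens`, `end_token`, and `initial_state` are placeholders the caller must build, e.g. as in `BeamSearchDecoderV2Test`):

```
bsd = beam_search_decoder.BeamSearchDecoderV2(
    cell=cell,
    beam_width=beam_width,
    output_layer=layers.Dense(vocab_size),
    maximum_iterations=maximum_iterations)
# Calling the decoder like a layer runs dynamic_decode internally.
final_outputs, final_state, final_sequence_lengths = bsd(
    embedding,
    start_tokens=start_tokens,
    end_token=end_token,
    initial_state=initial_state)
```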

tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py

@@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import layers
 from tensorflow.python.layers import core as layers_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn_ops
@@ -530,11 +533,10 @@ class BeamSearchDecoderTest(test.TestCase):
           return (shape[1], shape[0]) + shape[2:]
         return shape
 
-      self.assertTrue(
-          isinstance(final_outputs,
-                     beam_search_decoder.FinalBeamSearchDecoderOutput))
-      self.assertTrue(
-          isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))
+      self.assertIsInstance(
+          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
+      self.assertIsInstance(
+          final_state, beam_search_decoder.BeamSearchDecoderState)
 
       beam_search_decoder_output = final_outputs.beam_search_decoder_output
       self.assertEqual(
@@ -574,5 +576,119 @@ class BeamSearchDecoderTest(test.TestCase):
         with_alignment_history=True)
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class BeamSearchDecoderV2Test(test.TestCase):
+
+  def _testDynamicDecodeRNN(self, time_major, has_attention,
+                            with_alignment_history=False):
+    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
+    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
+    batch_size = 5
+    decoder_max_time = 4
+    input_depth = 7
+    cell_depth = 9
+    attention_depth = 6
+    vocab_size = 20
+    end_token = vocab_size - 1
+    start_token = 0
+    embedding_dim = 50
+    max_out = max(decoder_sequence_length)
+    output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)
+    beam_width = 3
+
+    with self.cached_session():
+      batch_size_tensor = constant_op.constant(batch_size)
+      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
+      cell = rnn_cell.LSTMCell(cell_depth)
+      initial_state = cell.zero_state(batch_size, dtypes.float32)
+      coverage_penalty_weight = 0.0
+      if has_attention:
+        coverage_penalty_weight = 0.2
+        inputs = array_ops.placeholder_with_default(
+            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
+                np.float32),
+            shape=(None, None, input_depth))
+        tiled_inputs = beam_search_decoder.tile_batch(
+            inputs, multiplier=beam_width)
+        tiled_sequence_length = beam_search_decoder.tile_batch(
+            encoder_sequence_length, multiplier=beam_width)
+        attention_mechanism = attention_wrapper.BahdanauAttention(
+            num_units=attention_depth,
+            memory=tiled_inputs,
+            memory_sequence_length=tiled_sequence_length)
+        initial_state = beam_search_decoder.tile_batch(
+            initial_state, multiplier=beam_width)
+        cell = attention_wrapper.AttentionWrapper(
+            cell=cell,
+            attention_mechanism=attention_mechanism,
+            attention_layer_size=attention_depth,
+            alignment_history=with_alignment_history)
+      cell_state = cell.zero_state(
+          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
+      if has_attention:
+        cell_state = cell_state.clone(cell_state=initial_state)
+      bsd = beam_search_decoder.BeamSearchDecoderV2(
+          cell=cell,
+          beam_width=beam_width,
+          output_layer=output_layer,
+          length_penalty_weight=0.0,
+          coverage_penalty_weight=coverage_penalty_weight,
+          output_time_major=time_major,
+          maximum_iterations=max_out)
+
+      final_outputs, final_state, final_sequence_lengths = bsd(
+          embedding,
+          start_tokens=array_ops.fill([batch_size_tensor], start_token),
+          end_token=end_token,
+          initial_state=cell_state)
+
+      def _t(shape):
+        if time_major:
+          return (shape[1], shape[0]) + shape[2:]
+        return shape
+
+      self.assertIsInstance(
+          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
+      self.assertIsInstance(
+          final_state, beam_search_decoder.BeamSearchDecoderState)
+
+      beam_search_decoder_output = final_outputs.beam_search_decoder_output
+      expected_seq_length = 3 if context.executing_eagerly() else None
+      self.assertEqual(
+          _t((batch_size, expected_seq_length, beam_width)),
+          tuple(beam_search_decoder_output.scores.get_shape().as_list()))
+      self.assertEqual(
+          _t((batch_size, expected_seq_length, beam_width)),
+          tuple(final_outputs.predicted_ids.get_shape().as_list()))
+
+      self.evaluate(variables.global_variables_initializer())
+      eval_results = self.evaluate({
+          'final_outputs': final_outputs,
+          'final_sequence_lengths': final_sequence_lengths
+      })
+
+      max_sequence_length = np.max(eval_results['final_sequence_lengths'])
+
+      # A smoke test
+      self.assertEqual(
+          _t((batch_size, max_sequence_length, beam_width)),
+          eval_results['final_outputs'].beam_search_decoder_output.scores.shape)
+      self.assertEqual(
+          _t((batch_size, max_sequence_length, beam_width)), eval_results[
+              'final_outputs'].beam_search_decoder_output.predicted_ids.shape)
+
+  def testDynamicDecodeRNNBatchMajorNoAttention(self):
+    self._testDynamicDecodeRNN(time_major=False, has_attention=False)
+
+  def testDynamicDecodeRNNBatchMajorYesAttention(self):
+    self._testDynamicDecodeRNN(time_major=False, has_attention=True)
+
+  def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
+    self._testDynamicDecodeRNN(
+        time_major=False,
+        has_attention=True,
+        with_alignment_history=True)
+
+
 if __name__ == '__main__':
   test.main()

tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py

@@ -25,6 +25,7 @@ import math
 import numpy as np
 
 from tensorflow.contrib.framework.python.framework import tensor_util
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -1919,7 +1920,15 @@ class AttentionWrapperState(
     def with_same_shape(old, new):
       """Check and set new tensor's shape."""
       if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
-        return tensor_util.with_same_shape(old, new)
+        if not context.executing_eagerly():
+          return tensor_util.with_same_shape(old, new)
+        else:
+          if old.shape.as_list() != new.shape.as_list():
+            raise ValueError("The shape of the AttentionWrapperState is "
+                             "expected to be same as the one to clone. "
+                             "self.shape: %s, input.shape: %s" %
+                             (old.shape, new.shape))
+          return new
       return new
 
     return nest.map_structure(
@@ -2048,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state,
   # the batched matmul is over memory_time, so the output shape is
   #   [batch_size, 1, memory_size].
   # we then squeeze out the singleton dim.
-  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
-  context = array_ops.squeeze(context, [1])
+  context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values)
+  context_ = array_ops.squeeze(context_, [1])
 
   if attention_layer is not None:
-    attention = attention_layer(array_ops.concat([cell_output, context], 1))
+    attention = attention_layer(array_ops.concat([cell_output, context_], 1))
   else:
-    attention = context
+    attention = context_
 
   return attention, alignments, next_attention_state
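
The `with_same_shape` change above is needed because `tensor_util.with_same_shape` emits a graph-mode shape assertion, which has no equivalent when executing eagerly; eagerly, shapes are already concrete, so a direct comparison is performed instead. (The `context` to `context_` rename in `_compute_attention` avoids shadowing the newly imported `tensorflow.python.eager.context` module.) This path is exercised by `AttentionWrapperState.clone`; a hedged sketch, reusing the placeholder names from the decoder docstrings below (`attention_cell`, `tiled_encoder_final_state`):

```
decoder_initial_state = attention_cell.zero_state(
    dtype, batch_size=true_batch_size * beam_width)
# In eager mode this now raises ValueError immediately if the tiled encoder
# state does not match the shape of the zero state's cell_state; in graph
# mode it builds the same shape assertion as before.
decoder_initial_state = decoder_initial_state.clone(
    cell_state=tiled_encoder_final_state)
```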

tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py

@@ -24,11 +24,12 @@ import numpy as np
 from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.layers import base as layers_base
+from tensorflow.python.keras import layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length):
return ordered return ordered
def _check_maybe(t): def _check_ndims(t):
if t.shape.ndims is None: if t.shape.ndims is None:
raise ValueError( raise ValueError(
"Expected tensor (%s) to have known rank, but ndims == None." % t) "Expected tensor (%s) to have known rank, but ndims == None." % t)
def _check_static_batch_beam_maybe(shape, batch_size, beam_width): def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
"""Raises an exception if dimensions are known statically and can not be """Raises an exception if dimensions are known statically and can not be
reshaped to [batch_size, beam_size, -1]. reshaped to [batch_size, beam_size, -1].
@@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
     return False
   return True
 
+
 def _check_batch_beam(t, batch_size, beam_width):
   """Returns an Assert operation checking that the elements of the stacked
   TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
@@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width):
   return control_flow_ops.Assert(condition, [error_message])
 
 
-class BeamSearchDecoder(decoder.Decoder):
-  """BeamSearch sampling decoder.
+class BeamSearchDecoderMixin(object):
+  """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.
 
-  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
-  `AttentionWrapper`, then you must ensure that:
-
-  - The encoder output has been tiled to `beam_width` via
-    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
-  - The `batch_size` argument passed to the `zero_state` method of this
-    wrapper is equal to `true_batch_size * beam_width`.
-  - The initial state created with `zero_state` above contains a
-    `cell_state` value containing properly tiled final state from the
-    encoder.
-
-  An example:
-
-  ```
-  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
-      encoder_outputs, multiplier=beam_width)
-  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
-      encoder_final_state, multiplier=beam_width)
-  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
-      sequence_length, multiplier=beam_width)
-  attention_mechanism = MyFavoriteAttentionMechanism(
-      num_units=attention_depth,
-      memory=tiled_inputs,
-      memory_sequence_length=tiled_sequence_length)
-  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
-  decoder_initial_state = attention_cell.zero_state(
-      dtype, batch_size=true_batch_size * beam_width)
-  decoder_initial_state = decoder_initial_state.clone(
-      cell_state=tiled_encoder_final_state)
-  ```
-
-  Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use
-  when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
-  the translation to cover all inputs.
+  It is expected to be used as a base class for a concrete BeamSearchDecoder.
+  Since this is a mixin class, it is expected to be used together with another
+  class as the base.
   """
 
   def __init__(self,
                cell,
-               embedding,
-               start_tokens,
-               end_token,
-               initial_state,
                beam_width,
                output_layer=None,
                length_penalty_weight=0.0,
                coverage_penalty_weight=0.0,
-               reorder_tensor_arrays=True):
-    """Initialize the BeamSearchDecoder.
+               reorder_tensor_arrays=True,
+               **kwargs):
+    """Initialize the BeamSearchDecoderMixin.
 
     Args:
       cell: An `RNNCell` instance.
-      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
-        or the `params` argument for `embedding_lookup`.
-      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
-      end_token: `int32` scalar, the token that marks end of decoding.
-      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
       beam_width: Python integer, the number of beams.
-      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
-        `tf.layers.Dense`. Optional layer to apply to the RNN output prior
-        to storing the result or sampling.
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
       length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
       coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
@@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder):
         Otherwise, the `TensorArray` will be returned as is. Set this flag to
         `False` if the cell state contains `TensorArray`s that are not amenable
         to reordering.
+      **kwargs: Dict, other keyword arguments for parent class.
 
     Raises:
       TypeError: if `cell` is not an instance of `RNNCell`,
-        or `output_layer` is not an instance of `tf.layers.Layer`.
-      ValueError: If `start_tokens` is not a vector or
-        `end_token` is not a scalar.
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
     """
     rnn_cell_impl.assert_like_rnncell("cell", cell)  # pylint: disable=protected-access
     if (output_layer is not None and
-        not isinstance(output_layer, layers_base.Layer)):
+        not isinstance(output_layer, layers.Layer)):
       raise TypeError(
           "output_layer must be a Layer, received: %s" % type(output_layer))
     self._cell = cell
     self._output_layer = output_layer
     self._reorder_tensor_arrays = reorder_tensor_arrays
 
-    if callable(embedding):
-      self._embedding_fn = embedding
-    else:
-      self._embedding_fn = (
-          lambda ids: embedding_ops.embedding_lookup(embedding, ids))
-
-    self._start_tokens = ops.convert_to_tensor(
-        start_tokens, dtype=dtypes.int32, name="start_tokens")
-    if self._start_tokens.get_shape().ndims != 1:
-      raise ValueError("start_tokens must be a vector")
-    self._end_token = ops.convert_to_tensor(
-        end_token, dtype=dtypes.int32, name="end_token")
-    if self._end_token.get_shape().ndims != 0:
-      raise ValueError("end_token must be a scalar")
-
-    self._batch_size = array_ops.size(start_tokens)
+    self._start_tokens = None
+    self._end_token = None
+    self._batch_size = None
     self._beam_width = beam_width
     self._length_penalty_weight = length_penalty_weight
     self._coverage_penalty_weight = coverage_penalty_weight
-    self._initial_cell_state = nest.map_structure(
-        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
-    self._start_tokens = array_ops.tile(
-        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
-    self._start_inputs = self._embedding_fn(self._start_tokens)
-    self._finished = array_ops.one_hot(
-        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
-        depth=self._beam_width,
-        on_value=False,
-        off_value=True,
-        dtype=dtypes.bool)
+    super(BeamSearchDecoderMixin, self).__init__(**kwargs)
 
   @property
   def batch_size(self):
     return self._batch_size
 
   def _rnn_output_size(self):
+    """Get the output shape from the RNN layer."""
     size = self._cell.output_size
     if self._output_layer is None:
       return size
@@ -393,50 +332,6 @@
         predicted_ids=tensor_shape.TensorShape([self._beam_width]),
         parent_ids=tensor_shape.TensorShape([self._beam_width]))
 
-  @property
-  def output_dtype(self):
-    # Assume the dtype of the cell is the output_size structure
-    # containing the input_state's first component's dtype.
-    # Return that structure and int32 (the id)
-    dtype = nest.flatten(self._initial_cell_state)[0].dtype
-    return BeamSearchDecoderOutput(
-        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
-        predicted_ids=dtypes.int32,
-        parent_ids=dtypes.int32)
-
-  def initialize(self, name=None):
-    """Initialize the decoder.
-
-    Args:
-      name: Name scope for any created operations.
-
-    Returns:
-      `(finished, start_inputs, initial_state)`.
-    """
-    finished, start_inputs = self._finished, self._start_inputs
-
-    dtype = nest.flatten(self._initial_cell_state)[0].dtype
-    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
-        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
-        depth=self._beam_width,
-        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
-        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
-        dtype=dtype)
-    init_attention_probs = get_attention_probs(
-        self._initial_cell_state, self._coverage_penalty_weight)
-    if init_attention_probs is None:
-      init_attention_probs = ()
-
-    initial_state = BeamSearchDecoderState(
-        cell_state=self._initial_cell_state,
-        log_probs=log_probs,
-        finished=finished,
-        lengths=array_ops.zeros(
-            [self._batch_size, self._beam_width], dtype=dtypes.int64),
-        accumulated_attention_probs=init_attention_probs)
-
-    return (finished, start_inputs, initial_state)
-
   def finalize(self, outputs, final_state, sequence_lengths):
     """Finalize and return the predicted_ids.
@@ -562,7 +457,7 @@
     """
     if isinstance(t, tensor_array_ops.TensorArray):
       return t
-    _check_maybe(t)
+    _check_ndims(t)
     if t.shape.ndims >= 1:
       return self._split_batch_beams(t, s)
     else:
@@ -586,7 +481,7 @@
     """
     if isinstance(t, tensor_array_ops.TensorArray):
      return t
-    _check_maybe(t)
+    _check_ndims(t)
     if t.shape.ndims >= 2:
       return self._merge_batch_beams(t, s)
     else:
@@ -609,11 +504,18 @@
     if not isinstance(t, tensor_array_ops.TensorArray):
       return t
     # pylint: disable=protected-access
-    if (not t._infer_shape or not t._element_shape
-        or t._element_shape[0].ndims is None
-        or t._element_shape[0].ndims < 1):
+    # This is a bad hack due to the implementation detail of eager/graph TA.
+    # TODO(b/124374427): Update this to use public property of TensorArray.
+    if context.executing_eagerly():
+      element_shape = t._element_shape
+    else:
+      element_shape = t._element_shape[0]
+    if (not t._infer_shape
+        or not t._element_shape
+        or element_shape.ndims is None
+        or element_shape.ndims < 1):
       shape = (
-          t._element_shape[0] if t._infer_shape and t._element_shape
+          element_shape if t._infer_shape and t._element_shape
           else tensor_shape.TensorShape(None))
       tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
                       "sorting based on the beam search result. For a "
@@ -621,10 +523,10 @@
                       "defined and have at least a rank of 1, but saw shape: %s"
                       % (t.handle.name, shape))
       return t
-    shape = t._element_shape[0]
     # pylint: enable=protected-access
     if not _check_static_batch_beam_maybe(
-        shape, tensor_util.constant_value(self._batch_size), self._beam_width):
+        element_shape, tensor_util.constant_value(self._batch_size),
+        self._beam_width):
       return t
     t = t.stack()
     with ops.control_dependencies(
@@ -684,6 +586,359 @@
     return (beam_search_output, beam_search_state, next_inputs, finished)
 
 
+class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder):
+  # Note that the inheritance hierarchy is important here. The Mixin has to be
+  # the first parent class since we will use super().__init__(), and the
+  # Mixin, whose own base is plain `object`, will then properly invoke the
+  # __init__ method of the other parent class.
+  """BeamSearch sampling decoder.
+
+  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
+  `AttentionWrapper`, then you must ensure that:
+
+  - The encoder output has been tiled to `beam_width` via
+    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
+  - The `batch_size` argument passed to the `zero_state` method of this
+    wrapper is equal to `true_batch_size * beam_width`.
+  - The initial state created with `zero_state` above contains a
+    `cell_state` value containing properly tiled final state from the
+    encoder.
+
+  An example:
+
+  ```
+  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
+      encoder_outputs, multiplier=beam_width)
+  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
+      encoder_final_state, multiplier=beam_width)
+  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
+      sequence_length, multiplier=beam_width)
+  attention_mechanism = MyFavoriteAttentionMechanism(
+      num_units=attention_depth,
+      memory=tiled_inputs,
+      memory_sequence_length=tiled_sequence_length)
+  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
+  decoder_initial_state = attention_cell.zero_state(
+      dtype, batch_size=true_batch_size * beam_width)
+  decoder_initial_state = decoder_initial_state.clone(
+      cell_state=tiled_encoder_final_state)
+  ```
+
+  Meanwhile, with `AttentionWrapper`, using a coverage penalty when computing
+  scores is suggested (https://arxiv.org/pdf/1609.08144.pdf). It encourages
+  the translation to cover all inputs.
+  """
+
+  def __init__(self,
+               cell,
+               embedding,
+               start_tokens,
+               end_token,
+               initial_state,
+               beam_width,
+               output_layer=None,
+               length_penalty_weight=0.0,
+               coverage_penalty_weight=0.0,
+               reorder_tensor_arrays=True):
+    """Initialize the BeamSearchDecoder.
+
+    Args:
+      cell: An `RNNCell` instance.
+      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
+        or the `params` argument for `embedding_lookup`.
+      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
+      end_token: `int32` scalar, the token that marks end of decoding.
+      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
+      beam_width: Python integer, the number of beams.
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
+      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+      coverage_penalty_weight: Float weight to penalize the coverage of source
+        sentence. Disabled with 0.0.
+      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
+        state will be reordered according to the beam search path. If the
+        `TensorArray` can be reordered, the stacked form will be returned.
+        Otherwise, the `TensorArray` will be returned as is. Set this flag to
+        `False` if the cell state contains `TensorArray`s that are not amenable
+        to reordering.
+
+    Raises:
+      TypeError: if `cell` is not an instance of `RNNCell`,
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
+      ValueError: If `start_tokens` is not a vector or
+        `end_token` is not a scalar.
+    """
+    super(BeamSearchDecoder, self).__init__(
+        cell,
+        beam_width,
+        output_layer=output_layer,
+        length_penalty_weight=length_penalty_weight,
+        coverage_penalty_weight=coverage_penalty_weight,
+        reorder_tensor_arrays=reorder_tensor_arrays)
+
+    if callable(embedding):
+      self._embedding_fn = embedding
+    else:
+      self._embedding_fn = (
+          lambda ids: embedding_ops.embedding_lookup(embedding, ids))
+
+    self._start_tokens = ops.convert_to_tensor(
+        start_tokens, dtype=dtypes.int32, name="start_tokens")
+    if self._start_tokens.get_shape().ndims != 1:
+      raise ValueError("start_tokens must be a vector")
+    self._end_token = ops.convert_to_tensor(
+        end_token, dtype=dtypes.int32, name="end_token")
+    if self._end_token.get_shape().ndims != 0:
+      raise ValueError("end_token must be a scalar")
+
+    self._batch_size = array_ops.size(start_tokens)
+    self._initial_cell_state = nest.map_structure(
+        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
+    self._start_tokens = array_ops.tile(
+        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
+    self._start_inputs = self._embedding_fn(self._start_tokens)
+
+    self._finished = array_ops.one_hot(
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width,
+        on_value=False,
+        off_value=True,
+        dtype=dtypes.bool)
+
+  def initialize(self, name=None):
+    """Initialize the decoder.
+
+    Args:
+      name: Name scope for any created operations.
+
+    Returns:
+      `(finished, start_inputs, initial_state)`.
+    """
+    finished, start_inputs = self._finished, self._start_inputs
+
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
+    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width,
+        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
+        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
+        dtype=dtype)
+    init_attention_probs = get_attention_probs(
+        self._initial_cell_state, self._coverage_penalty_weight)
+    if init_attention_probs is None:
+      init_attention_probs = ()
+
+    initial_state = BeamSearchDecoderState(
+        cell_state=self._initial_cell_state,
+        log_probs=log_probs,
+        finished=finished,
+        lengths=array_ops.zeros(
+            [self._batch_size, self._beam_width], dtype=dtypes.int64),
+        accumulated_attention_probs=init_attention_probs)
+
+    return (finished, start_inputs, initial_state)
+
+  @property
+  def output_dtype(self):
+    # Assume the dtype of the cell is the output_size structure
+    # containing the input_state's first component's dtype.
+    # Return that structure and int32 (the id)
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
+    return BeamSearchDecoderOutput(
+        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
+        predicted_ids=dtypes.int32,
+        parent_ids=dtypes.int32)
+
+
+class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder):
+  # Note that the inheritance hierarchy is important here. The Mixin has to be
+  # the first parent class since we will use super().__init__(), and the
+  # Mixin, whose own base is plain `object`, will then properly invoke the
+  # __init__ method of the other parent class.
+  """BeamSearch sampling decoder.
+
+  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
+  `AttentionWrapper`, then you must ensure that:
+
+  - The encoder output has been tiled to `beam_width` via
+    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
+  - The `batch_size` argument passed to the `zero_state` method of this
+    wrapper is equal to `true_batch_size * beam_width`.
+  - The initial state created with `zero_state` above contains a
+    `cell_state` value containing properly tiled final state from the
+    encoder.
+
+  An example:
+
+  ```
+  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
+      encoder_outputs, multiplier=beam_width)
+  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
+      encoder_final_state, multiplier=beam_width)
+  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
+      sequence_length, multiplier=beam_width)
+  attention_mechanism = MyFavoriteAttentionMechanism(
+      num_units=attention_depth,
+      memory=tiled_inputs,
+      memory_sequence_length=tiled_sequence_length)
+  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
+  decoder_initial_state = attention_cell.zero_state(
+      dtype, batch_size=true_batch_size * beam_width)
+  decoder_initial_state = decoder_initial_state.clone(
+      cell_state=tiled_encoder_final_state)
+  ```
+
+  Meanwhile, with `AttentionWrapper`, using a coverage penalty when computing
+  scores is suggested (https://arxiv.org/pdf/1609.08144.pdf). It encourages
+  the decoding to cover all inputs.
+  """
+
+  def __init__(self,
+               cell,
+               beam_width,
+               embedding_fn=None,
+               output_layer=None,
+               length_penalty_weight=0.0,
+               coverage_penalty_weight=0.0,
+               reorder_tensor_arrays=True,
+               **kwargs):
+    """Initialize the BeamSearchDecoderV2.
+
+    Args:
+      cell: An `RNNCell` instance.
+      beam_width: Python integer, the number of beams.
+      embedding_fn: A callable that takes a vector tensor of `ids` (argmax
+        ids).
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
+      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+      coverage_penalty_weight: Float weight to penalize the coverage of source
+        sentence. Disabled with 0.0.
+      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
+        state will be reordered according to the beam search path. If the
+        `TensorArray` can be reordered, the stacked form will be returned.
+        Otherwise, the `TensorArray` will be returned as is. Set this flag to
+        `False` if the cell state contains `TensorArray`s that are not amenable
+        to reordering.
+      **kwargs: Dict, other keyword arguments for initialization.
+
+    Raises:
+      TypeError: if `cell` is not an instance of `RNNCell`,
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
+    """
+    super(BeamSearchDecoderV2, self).__init__(
+        cell,
+        beam_width,
+        output_layer=output_layer,
+        length_penalty_weight=length_penalty_weight,
+        coverage_penalty_weight=coverage_penalty_weight,
+        reorder_tensor_arrays=reorder_tensor_arrays,
+        **kwargs)
+
+    if embedding_fn is None or callable(embedding_fn):
+      self._embedding_fn = embedding_fn
+    else:
+      raise ValueError("embedding_fn is expected to be a callable, got %s" %
+                       type(embedding_fn))
+
+  def initialize(self,
+                 embedding,
+                 start_tokens,
+                 end_token,
+                 initial_state):
+    """Initialize the decoder.
+
+    Args:
+      embedding: A tensor from the embedding layer output, which is the
+        `params` argument for `embedding_lookup`.
+      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
+      end_token: `int32` scalar, the token that marks end of decoding.
+      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
+
+    Returns:
+      `(finished, start_inputs, initial_state)`.
+
+    Raises:
+      ValueError: If `start_tokens` is not a vector or `end_token` is not a
+        scalar.
+    """
+    if embedding is not None and self._embedding_fn is not None:
+      raise ValueError(
+          "embedding and embedding_fn cannot be provided at same time")
+    elif embedding is not None:
+      self._embedding_fn = (
+          lambda ids: embedding_ops.embedding_lookup(embedding, ids))
+
+    self._start_tokens = ops.convert_to_tensor(
+        start_tokens, dtype=dtypes.int32, name="start_tokens")
+    if self._start_tokens.get_shape().ndims != 1:
+      raise ValueError("start_tokens must be a vector")
+    self._end_token = ops.convert_to_tensor(
+        end_token, dtype=dtypes.int32, name="end_token")
+    if self._end_token.get_shape().ndims != 0:
+      raise ValueError("end_token must be a scalar")
+
+    self._batch_size = array_ops.size(start_tokens)
+    self._initial_cell_state = nest.map_structure(
+        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
+    self._start_tokens = array_ops.tile(
+        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
+    self._start_inputs = self._embedding_fn(self._start_tokens)
+
+    self._finished = array_ops.one_hot(
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width,
+        on_value=False,
+        off_value=True,
+        dtype=dtypes.bool)
+
+    finished, start_inputs = self._finished, self._start_inputs
+
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
+    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width,
+        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
+        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
+        dtype=dtype)
+    init_attention_probs = get_attention_probs(
+        self._initial_cell_state, self._coverage_penalty_weight)
+    if init_attention_probs is None:
+      init_attention_probs = ()
+
+    initial_state = BeamSearchDecoderState(
+        cell_state=self._initial_cell_state,
+        log_probs=log_probs,
+        finished=finished,
+        lengths=array_ops.zeros(
+            [self._batch_size, self._beam_width], dtype=dtypes.int64),
+        accumulated_attention_probs=init_attention_probs)
+
+    return (finished, start_inputs, initial_state)
+
+  @property
+  def output_dtype(self):
+    # Assume the dtype of the cell is the output_size structure
+    # containing the input_state's first component's dtype.
+    # Return that structure and int32 (the id)
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
+    return BeamSearchDecoderOutput(
+        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
+        predicted_ids=dtypes.int32,
+        parent_ids=dtypes.int32)
+
+  def call(self, embedding, start_tokens, end_token, initial_state, **kwargs):
+    init_kwargs = kwargs
+    init_kwargs["start_tokens"] = start_tokens
+    init_kwargs["end_token"] = end_token
+    init_kwargs["initial_state"] = initial_state
+    return decoder.dynamic_decode(self,
+                                  output_time_major=self.output_time_major,
+                                  impute_finished=self.impute_finished,
+                                  maximum_iterations=self.maximum_iterations,
+                                  parallel_iterations=self.parallel_iterations,
+                                  swap_memory=self.swap_memory,
+                                  decoder_init_input=embedding,
+                                  decoder_init_kwargs=init_kwargs)
 
 
 def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                       beam_width, end_token, length_penalty_weight,
                       coverage_penalty_weight):
@@ -1068,7 +1323,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
   """
   if isinstance(gather_from, tensor_array_ops.TensorArray):
     return gather_from
-  _check_maybe(gather_from)
+  _check_ndims(gather_from)
   if gather_from.shape.ndims >= len(gather_shape):
     return _tensor_gather_helper(
         gather_indices=gather_indices,
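
Taken together, the refactoring leaves the beam search math untouched and only moves where the batch-dependent inputs enter. For comparison, a sketch of the two call patterns (placeholder tensors as in the class docstrings above):

```
# V1: batch-dependent tensors are constructor arguments; decoding is driven
# by an explicit dynamic_decode call.
bsd = BeamSearchDecoder(cell, embedding, start_tokens, end_token,
                        decoder_initial_state, beam_width=beam_width)
outputs, state, lengths = dynamic_decode(bsd, maximum_iterations=max_out)

# V2: the decoder is configured once and then invoked like a Keras layer;
# BeamSearchDecoderV2.call() forwards to dynamic_decode internally.
bsd_v2 = BeamSearchDecoderV2(cell, beam_width=beam_width,
                             maximum_iterations=max_out)
outputs, state, lengths = bsd_v2(embedding, start_tokens=start_tokens,
                                 end_token=end_token,
                                 initial_state=decoder_initial_state)
```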