Add BeamSearchDecoderV2, which can be used as a Keras layer.

PiperOrigin-RevId: 233982439
Scott Zhu 2019-02-14 10:28:30 -08:00 committed by TensorFlower Gardener
parent 03aa9d18f5
commit 3c9b46c245
3 changed files with 528 additions and 148 deletions

tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py

@@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import layers
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
@@ -530,11 +533,10 @@ class BeamSearchDecoderTest(test.TestCase):
return (shape[1], shape[0]) + shape[2:]
return shape
self.assertTrue(
isinstance(final_outputs,
beam_search_decoder.FinalBeamSearchDecoderOutput))
self.assertTrue(
isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))
self.assertIsInstance(
final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
self.assertIsInstance(
final_state, beam_search_decoder.BeamSearchDecoderState)
beam_search_decoder_output = final_outputs.beam_search_decoder_output
self.assertEqual(
@@ -574,5 +576,119 @@ class BeamSearchDecoderTest(test.TestCase):
with_alignment_history=True)
@test_util.run_all_in_graph_and_eager_modes
class BeamSearchDecoderV2Test(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention,
with_alignment_history=False):
encoder_sequence_length = np.array([3, 2, 3, 1, 1])
decoder_sequence_length = np.array([2, 0, 1, 2, 3])
batch_size = 5
decoder_max_time = 4
input_depth = 7
cell_depth = 9
attention_depth = 6
vocab_size = 20
end_token = vocab_size - 1
start_token = 0
embedding_dim = 50
max_out = max(decoder_sequence_length)
output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)
beam_width = 3
with self.cached_session():
batch_size_tensor = constant_op.constant(batch_size)
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = rnn_cell.LSTMCell(cell_depth)
initial_state = cell.zero_state(batch_size, dtypes.float32)
coverage_penalty_weight = 0.0
if has_attention:
coverage_penalty_weight = 0.2
inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time, input_depth).astype(
np.float32),
shape=(None, None, input_depth))
tiled_inputs = beam_search_decoder.tile_batch(
inputs, multiplier=beam_width)
tiled_sequence_length = beam_search_decoder.tile_batch(
encoder_sequence_length, multiplier=beam_width)
attention_mechanism = attention_wrapper.BahdanauAttention(
num_units=attention_depth,
memory=tiled_inputs,
memory_sequence_length=tiled_sequence_length)
initial_state = beam_search_decoder.tile_batch(
initial_state, multiplier=beam_width)
cell = attention_wrapper.AttentionWrapper(
cell=cell,
attention_mechanism=attention_mechanism,
attention_layer_size=attention_depth,
alignment_history=with_alignment_history)
cell_state = cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
if has_attention:
cell_state = cell_state.clone(cell_state=initial_state)
bsd = beam_search_decoder.BeamSearchDecoderV2(
cell=cell,
beam_width=beam_width,
output_layer=output_layer,
length_penalty_weight=0.0,
coverage_penalty_weight=coverage_penalty_weight,
output_time_major=time_major,
maximum_iterations=max_out)
final_outputs, final_state, final_sequence_lengths = bsd(
embedding,
start_tokens=array_ops.fill([batch_size_tensor], start_token),
end_token=end_token,
initial_state=cell_state)
def _t(shape):
if time_major:
return (shape[1], shape[0]) + shape[2:]
return shape
self.assertIsInstance(
final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
self.assertIsInstance(
final_state, beam_search_decoder.BeamSearchDecoderState)
beam_search_decoder_output = final_outputs.beam_search_decoder_output
expected_seq_length = 3 if context.executing_eagerly() else None
self.assertEqual(
_t((batch_size, expected_seq_length, beam_width)),
tuple(beam_search_decoder_output.scores.get_shape().as_list()))
self.assertEqual(
_t((batch_size, expected_seq_length, beam_width)),
tuple(final_outputs.predicted_ids.get_shape().as_list()))
self.evaluate(variables.global_variables_initializer())
eval_results = self.evaluate({
'final_outputs': final_outputs,
'final_sequence_lengths': final_sequence_lengths
})
max_sequence_length = np.max(eval_results['final_sequence_lengths'])
# A smoke test
self.assertEqual(
_t((batch_size, max_sequence_length, beam_width)),
eval_results['final_outputs'].beam_search_decoder_output.scores.shape)
self.assertEqual(
_t((batch_size, max_sequence_length, beam_width)), eval_results[
'final_outputs'].beam_search_decoder_output.predicted_ids.shape)
def testDynamicDecodeRNNBatchMajorNoAttention(self):
self._testDynamicDecodeRNN(time_major=False, has_attention=False)
def testDynamicDecodeRNNBatchMajorYesAttention(self):
self._testDynamicDecodeRNN(time_major=False, has_attention=True)
def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
self._testDynamicDecodeRNN(
time_major=False,
has_attention=True,
with_alignment_history=True)
if __name__ == '__main__':
test.main()
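The test above doubles as usage documentation for the new layer-style API. Below is a minimal sketch condensed from it (no attention; the `rnn_cell` import is assumed from the test's unchanged header, and the hyperparameters are illustrative):

import numpy as np
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn_cell

vocab_size, embedding_dim, cell_depth = 20, 50, 9
batch_size, beam_width = 5, 3

# Embedding params may be passed at call time instead of at construction.
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = rnn_cell.LSTMCell(cell_depth)
# The cell state must be sized for batch_size * beam_width beams.
initial_state = cell.zero_state(batch_size * beam_width, dtypes.float32)

bsd = beam_search_decoder.BeamSearchDecoderV2(
    cell=cell,
    beam_width=beam_width,
    output_layer=layers.Dense(vocab_size),
    maximum_iterations=4)

# Invoked like a Keras layer: configuration in __init__, tensors in __call__.
final_outputs, final_state, final_sequence_lengths = bsd(
    embedding,
    start_tokens=array_ops.fill([batch_size], 0),
    end_token=vocab_size - 1,
    initial_state=initial_state)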

tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py

@@ -25,6 +25,7 @@ import math
import numpy as np
from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -1919,7 +1920,15 @@ class AttentionWrapperState(
def with_same_shape(old, new):
"""Check and set new tensor's shape."""
if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
return tensor_util.with_same_shape(old, new)
if not context.executing_eagerly():
return tensor_util.with_same_shape(old, new)
else:
if old.shape.as_list() != new.shape.as_list():
raise ValueError("The shape of the AttentionWrapperState is "
"expected to be same as the one to clone. "
"self.shape: %s, input.shape: %s" %
(old.shape, new.shape))
return new
return new
return nest.map_structure(
@@ -2048,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state,
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, memory_size].
# we then squeeze out the singleton dim.
context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context = array_ops.squeeze(context, [1])
context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context_ = array_ops.squeeze(context_, [1])
if attention_layer is not None:
attention = attention_layer(array_ops.concat([cell_output, context], 1))
attention = attention_layer(array_ops.concat([cell_output, context_], 1))
else:
attention = context
attention = context_
return attention, alignments, next_attention_state
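Context for the rename above: this file now imports the `tensorflow.python.eager.context` module (see the import hunk), so a local variable named `context` inside `_compute_attention` would shadow it. A generic, self-contained illustration of the pitfall (plain Python with hypothetical functions, not TensorFlow code):

import math

def shadowed():
  math = 2 * 3.14159      # local name now hides the imported module
  return math.sqrt(math)  # AttributeError: 'float' object has no attribute 'sqrt'

def renamed():
  math_ = 2 * 3.14159     # trailing underscore keeps the module reachable
  return math.sqrt(math_)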

tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py

@@ -24,11 +24,12 @@ import numpy as np
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as layers_base
from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
@@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length):
return ordered
def _check_maybe(t):
def _check_ndims(t):
if t.shape.ndims is None:
raise ValueError(
"Expected tensor (%s) to have known rank, but ndims == None." % t)
def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
"""Raises an exception if dimensions are known statically and can not be
reshaped to [batch_size, beam_size, -1].
@@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
return False
return True
def _check_batch_beam(t, batch_size, beam_width):
"""Returns an Assert operation checking that the elements of the stacked
TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
@@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width):
return control_flow_ops.Assert(condition, [error_message])
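For intuition, the invariant that `_check_static_batch_beam_maybe` and `_check_batch_beam` enforce can be sketched in plain Python (hypothetical helper; the real code works on `TensorShape`s and emits an `Assert` op for the dynamic case):

def compatible_with_batch_beam(shape, batch_size, beam_width):
  """True if a stacked TensorArray element of `shape` can be reshaped
  to [batch_size, beam_width, -1]."""
  if not shape:  # rank-0 or unknown: nothing to check statically
    return True
  return (shape[0] == batch_size * beam_width or
          (len(shape) >= 2 and
           shape[0] == batch_size and shape[1] == beam_width))

assert compatible_with_batch_beam([15, 7], batch_size=5, beam_width=3)
assert compatible_with_batch_beam([5, 3, 7], batch_size=5, beam_width=3)
assert not compatible_with_batch_beam([14, 7], batch_size=5, beam_width=3)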
class BeamSearchDecoderMixin(object):
"""BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.
class BeamSearchDecoder(decoder.Decoder):
"""BeamSearch sampling decoder.
**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
`AttentionWrapper`, then you must ensure that:
- The encoder output has been tiled to `beam_width` via
`tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `zero_state` method of this
wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `zero_state` above contains a
`cell_state` value containing properly tiled final state from the
encoder.
An example:
```
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
num_units=attention_depth,
memory=tiled_inputs,
memory_sequence_length=tiled_sequence_length)
attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
decoder_initial_state = attention_cell.zero_state(
dtype, batch_size=true_batch_size * beam_width)
decoder_initial_state = decoder_initial_state.clone(
cell_state=tiled_encoder_final_state)
```
Meanwhile, with `AttentionWrapper`, use of the coverage penalty is suggested
when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
the translation to cover all inputs.
It is expected to be used as a base class for a concrete BeamSearchDecoder.
Since this is a mixin class, it is expected to be used together with another
class as the base.
"""
def __init__(self,
cell,
embedding,
start_tokens,
end_token,
initial_state,
beam_width,
output_layer=None,
length_penalty_weight=0.0,
coverage_penalty_weight=0.0,
reorder_tensor_arrays=True):
"""Initialize the BeamSearchDecoder.
reorder_tensor_arrays=True,
**kwargs):
"""Initialize the BeamSearchDecoderMixin.
Args:
cell: An `RNNCell` instance.
embedding: A callable that takes a vector tensor of `ids` (argmax ids),
or the `params` argument for `embedding_lookup`.
start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
end_token: `int32` scalar, the token that marks end of decoding.
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
beam_width: Python integer, the number of beams.
output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
`tf.keras.layers.Dense`. Optional layer to apply to the RNN output
prior to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
coverage_penalty_weight: Float weight to penalize the coverage of source
sentence. Disabled with 0.0.
@@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder):
Otherwise, the `TensorArray` will be returned as is. Set this flag to
`False` if the cell state contains `TensorArray`s that are not amenable
to reordering.
**kwargs: Dict, other keyword arguments for the parent class.
Raises:
TypeError: if `cell` is not an instance of `RNNCell`,
or `output_layer` is not an instance of `tf.layers.Layer`.
ValueError: If `start_tokens` is not a vector or
`end_token` is not a scalar.
or `output_layer` is not an instance of `tf.keras.layers.Layer`.
"""
rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access
if (output_layer is not None and
not isinstance(output_layer, layers_base.Layer)):
not isinstance(output_layer, layers.Layer)):
raise TypeError(
"output_layer must be a Layer, received: %s" % type(output_layer))
self._cell = cell
self._output_layer = output_layer
self._reorder_tensor_arrays = reorder_tensor_arrays
if callable(embedding):
self._embedding_fn = embedding
else:
self._embedding_fn = (
lambda ids: embedding_ops.embedding_lookup(embedding, ids))
self._start_tokens = ops.convert_to_tensor(
start_tokens, dtype=dtypes.int32, name="start_tokens")
if self._start_tokens.get_shape().ndims != 1:
raise ValueError("start_tokens must be a vector")
self._end_token = ops.convert_to_tensor(
end_token, dtype=dtypes.int32, name="end_token")
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._batch_size = array_ops.size(start_tokens)
self._start_tokens = None
self._end_token = None
self._batch_size = None
self._beam_width = beam_width
self._length_penalty_weight = length_penalty_weight
self._coverage_penalty_weight = coverage_penalty_weight
self._initial_cell_state = nest.map_structure(
self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
self._finished = array_ops.one_hot(
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
on_value=False,
off_value=True,
dtype=dtypes.bool)
super(BeamSearchDecoderMixin, self).__init__(**kwargs)
@property
def batch_size(self):
return self._batch_size
def _rnn_output_size(self):
"""Get the output shape from the RNN layer."""
size = self._cell.output_size
if self._output_layer is None:
return size
@@ -393,50 +332,6 @@ class BeamSearchDecoder(decoder.Decoder):
predicted_ids=tensor_shape.TensorShape([self._beam_width]),
parent_ids=tensor_shape.TensorShape([self._beam_width]))
@property
def output_dtype(self):
# Assume the score dtype matches the dtype of the initial cell state's
# first component. Return the output_size structure with that dtype,
# plus int32 (the ids).
dtype = nest.flatten(self._initial_cell_state)[0].dtype
return BeamSearchDecoderOutput(
scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
predicted_ids=dtypes.int32,
parent_ids=dtypes.int32)
def initialize(self, name=None):
"""Initialize the decoder.
Args:
name: Name scope for any created operations.
Returns:
`(finished, start_inputs, initial_state)`.
"""
finished, start_inputs = self._finished, self._start_inputs
dtype = nest.flatten(self._initial_cell_state)[0].dtype
log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
on_value=ops.convert_to_tensor(0.0, dtype=dtype),
off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
dtype=dtype)
init_attention_probs = get_attention_probs(
self._initial_cell_state, self._coverage_penalty_weight)
if init_attention_probs is None:
init_attention_probs = ()
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
[self._batch_size, self._beam_width], dtype=dtypes.int64),
accumulated_attention_probs=init_attention_probs)
return (finished, start_inputs, initial_state)
def finalize(self, outputs, final_state, sequence_lengths):
"""Finalize and return the predicted_ids.
@@ -562,7 +457,7 @@ class BeamSearchDecoder(decoder.Decoder):
"""
if isinstance(t, tensor_array_ops.TensorArray):
return t
_check_maybe(t)
_check_ndims(t)
if t.shape.ndims >= 1:
return self._split_batch_beams(t, s)
else:
@@ -586,7 +481,7 @@ class BeamSearchDecoder(decoder.Decoder):
"""
if isinstance(t, tensor_array_ops.TensorArray):
return t
_check_maybe(t)
_check_ndims(t)
if t.shape.ndims >= 2:
return self._merge_batch_beams(t, s)
else:
@@ -609,11 +504,18 @@ class BeamSearchDecoder(decoder.Decoder):
if not isinstance(t, tensor_array_ops.TensorArray):
return t
# pylint: disable=protected-access
if (not t._infer_shape or not t._element_shape
or t._element_shape[0].ndims is None
or t._element_shape[0].ndims < 1):
# This is a bad hack due to the implementation detail of eager/graph TA.
# TODO(b/124374427): Update this to use public property of TensorArray.
if context.executing_eagerly():
element_shape = t._element_shape
else:
element_shape = t._element_shape[0]
if (not t._infer_shape
or not t._element_shape
or element_shape.ndims is None
or element_shape.ndims < 1):
shape = (
t._element_shape[0] if t._infer_shape and t._element_shape
element_shape if t._infer_shape and t._element_shape
else tensor_shape.TensorShape(None))
tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
"sorting based on the beam search result. For a "
@@ -621,10 +523,10 @@ class BeamSearchDecoder(decoder.Decoder):
"defined and have at least a rank of 1, but saw shape: %s"
% (t.handle.name, shape))
return t
shape = t._element_shape[0]
# pylint: enable=protected-access
if not _check_static_batch_beam_maybe(
shape, tensor_util.constant_value(self._batch_size), self._beam_width):
element_shape, tensor_util.constant_value(self._batch_size),
self._beam_width):
return t
t = t.stack()
with ops.control_dependencies(
@@ -684,6 +586,359 @@ class BeamSearchDecoder(decoder.Decoder):
return (beam_search_output, beam_search_state, next_inputs, finished)
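Both concrete decoder classes below lean on Python's cooperative multiple inheritance: the mixin must be listed first so that `super().__init__()` chains through it into the other parent. A self-contained sketch of the pattern, with hypothetical class names:

class Mixin(object):

  def __init__(self, beam_width, **kwargs):
    self._beam_width = beam_width
    # Forward the remaining kwargs to the next class in the MRO.
    super(Mixin, self).__init__(**kwargs)


class Base(object):

  def __init__(self, name=None):
    self._name = name


class Concrete(Mixin, Base):  # Mixin first, as with the decoders below

  def __init__(self, beam_width, **kwargs):
    super(Concrete, self).__init__(beam_width, **kwargs)


d = Concrete(beam_width=3, name="decoder")
assert d._beam_width == 3 and d._name == "decoder"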
class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder):
# Note that the inheritance hierarchy is important here. The Mixin has to be
# the first parent class since we will use super().__init__(), and the Mixin,
# which derives directly from object, will properly invoke the __init__
# method of the other parent class.
"""BeamSearch sampling decoder.
**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
`AttentionWrapper`, then you must ensure that:
- The encoder output has been tiled to `beam_width` via
`tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `zero_state` method of this
wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `zero_state` above contains a
`cell_state` value containing properly tiled final state from the
encoder.
An example:
```
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
num_units=attention_depth,
memory=tiled_inputs,
memory_sequence_length=tiled_sequence_length)
attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
decoder_initial_state = attention_cell.zero_state(
dtype, batch_size=true_batch_size * beam_width)
decoder_initial_state = decoder_initial_state.clone(
cell_state=tiled_encoder_final_state)
```
Meanwhile, with `AttentionWrapper`, use of the coverage penalty is suggested
when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
the decoder to cover all inputs.
"""
def __init__(self,
cell,
embedding,
start_tokens,
end_token,
initial_state,
beam_width,
output_layer=None,
length_penalty_weight=0.0,
coverage_penalty_weight=0.0,
reorder_tensor_arrays=True):
"""Initialize the BeamSearchDecoder.
Args:
cell: An `RNNCell` instance.
embedding: A callable that takes a vector tensor of `ids` (argmax ids),
or the `params` argument for `embedding_lookup`.
start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
end_token: `int32` scalar, the token that marks end of decoding.
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
beam_width: Python integer, the number of beams.
output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
`tf.keras.layers.Dense`. Optional layer to apply to the RNN output
prior to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
coverage_penalty_weight: Float weight to penalize the coverage of source
sentence. Disabled with 0.0.
reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
state will be reordered according to the beam search path. If the
`TensorArray` can be reordered, the stacked form will be returned.
Otherwise, the `TensorArray` will be returned as is. Set this flag to
`False` if the cell state contains `TensorArray`s that are not amenable
to reordering.
Raises:
TypeError: if `cell` is not an instance of `RNNCell`,
or `output_layer` is not an instance of `tf.keras.layers.Layer`.
ValueError: If `start_tokens` is not a vector or
`end_token` is not a scalar.
"""
super(BeamSearchDecoder, self).__init__(
cell,
beam_width,
output_layer=output_layer,
length_penalty_weight=length_penalty_weight,
coverage_penalty_weight=coverage_penalty_weight,
reorder_tensor_arrays=reorder_tensor_arrays)
if callable(embedding):
self._embedding_fn = embedding
else:
self._embedding_fn = (
lambda ids: embedding_ops.embedding_lookup(embedding, ids))
self._start_tokens = ops.convert_to_tensor(
start_tokens, dtype=dtypes.int32, name="start_tokens")
if self._start_tokens.get_shape().ndims != 1:
raise ValueError("start_tokens must be a vector")
self._end_token = ops.convert_to_tensor(
end_token, dtype=dtypes.int32, name="end_token")
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._batch_size = array_ops.size(start_tokens)
self._initial_cell_state = nest.map_structure(
self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
self._finished = array_ops.one_hot(
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
on_value=False,
off_value=True,
dtype=dtypes.bool)
def initialize(self, name=None):
"""Initialize the decoder.
Args:
name: Name scope for any created operations.
Returns:
`(finished, start_inputs, initial_state)`.
"""
finished, start_inputs = self._finished, self._start_inputs
dtype = nest.flatten(self._initial_cell_state)[0].dtype
log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
on_value=ops.convert_to_tensor(0.0, dtype=dtype),
off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
dtype=dtype)
init_attention_probs = get_attention_probs(
self._initial_cell_state, self._coverage_penalty_weight)
if init_attention_probs is None:
init_attention_probs = ()
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
[self._batch_size, self._beam_width], dtype=dtypes.int64),
accumulated_attention_probs=init_attention_probs)
return (finished, start_inputs, initial_state)
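The `one_hot` call above seeds the search so that only beam 0 of each batch entry starts live: log-probability 0 for the first beam and -inf for the rest, so the initially identical beams cannot all win the first top-k step. A numpy sketch of the resulting tensor:

import numpy as np

batch_size, beam_width = 2, 3
log_probs = np.full((batch_size, beam_width), -np.inf)
log_probs[:, 0] = 0.0  # beam 0 live, all other beams impossible
# log_probs is now:
# [[  0., -inf, -inf],
#  [  0., -inf, -inf]]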
@property
def output_dtype(self):
# Assume the score dtype matches the dtype of the initial cell state's
# first component. Return the output_size structure with that dtype,
# plus int32 (the ids).
dtype = nest.flatten(self._initial_cell_state)[0].dtype
return BeamSearchDecoderOutput(
scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
predicted_ids=dtypes.int32,
parent_ids=dtypes.int32)
class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder):
# Note that the inheritance hierarchy is important here. The Mixin has to be
# the first parent class since we will use super().__init__(), and the Mixin,
# which derives directly from object, will properly invoke the __init__
# method of the other parent class.
"""BeamSearch sampling decoder.
**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
`AttentionWrapper`, then you must ensure that:
- The encoder output has been tiled to `beam_width` via
`tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `zero_state` method of this
wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `zero_state` above contains a
`cell_state` value containing properly tiled final state from the
encoder.
An example:
```
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
num_units=attention_depth,
memory=tiled_inputs,
memory_sequence_length=tiled_sequence_length)
attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
decoder_initial_state = attention_cell.zero_state(
dtype, batch_size=true_batch_size * beam_width)
decoder_initial_state = decoder_initial_state.clone(
cell_state=tiled_encoder_final_state)
```
Meanwhile, with `AttentionWrapper`, use of the coverage penalty is suggested
when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
the decoding to cover all inputs.
"""
def __init__(self,
cell,
beam_width,
embedding_fn=None,
output_layer=None,
length_penalty_weight=0.0,
coverage_penalty_weight=0.0,
reorder_tensor_arrays=True,
**kwargs):
"""Initialize the BeamSearchDecoderV2.
Args:
cell: An `RNNCell` instance.
beam_width: Python integer, the number of beams.
embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids).
output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
`tf.keras.layers.Dense`. Optional layer to apply to the RNN output
prior to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
coverage_penalty_weight: Float weight to penalize the coverage of source
sentence. Disabled with 0.0.
reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
state will be reordered according to the beam search path. If the
`TensorArray` can be reordered, the stacked form will be returned.
Otherwise, the `TensorArray` will be returned as is. Set this flag to
`False` if the cell state contains `TensorArray`s that are not amenable
to reordering.
**kwargs: Dict, other keyword arguments for initialization.
Raises:
TypeError: if `cell` is not an instance of `RNNCell`,
or `output_layer` is not an instance of `tf.keras.layers.Layer`.
"""
super(BeamSearchDecoderV2, self).__init__(
cell,
beam_width,
output_layer=output_layer,
length_penalty_weight=length_penalty_weight,
coverage_penalty_weight=coverage_penalty_weight,
reorder_tensor_arrays=reorder_tensor_arrays,
**kwargs)
if embedding_fn is None or callable(embedding_fn):
self._embedding_fn = embedding_fn
else:
raise ValueError("embedding_fn is expected to be a callable, got %s" %
type(embedding_fn))
def initialize(self,
embedding,
start_tokens,
end_token,
initial_state):
"""Initialize the decoder.
Args:
embedding: A tensor from the embedding layer output, which is the
`params` argument for `embedding_lookup`.
start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
end_token: `int32` scalar, the token that marks end of decoding.
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
Returns:
`(finished, start_inputs, initial_state)`.
Raises:
ValueError: If `start_tokens` is not a vector or `end_token` is not a
scalar.
"""
if embedding is not None and self._embedding_fn is not None:
raise ValueError(
"embedding and embedding_fn cannot be provided at same time")
elif embedding is not None:
self._embedding_fn = (
lambda ids: embedding_ops.embedding_lookup(embedding, ids))
self._start_tokens = ops.convert_to_tensor(
start_tokens, dtype=dtypes.int32, name="start_tokens")
if self._start_tokens.get_shape().ndims != 1:
raise ValueError("start_tokens must be a vector")
self._end_token = ops.convert_to_tensor(
end_token, dtype=dtypes.int32, name="end_token")
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._batch_size = array_ops.size(start_tokens)
self._initial_cell_state = nest.map_structure(
self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
self._finished = array_ops.one_hot(
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
on_value=False,
off_value=True,
dtype=dtypes.bool)
finished, start_inputs = self._finished, self._start_inputs
dtype = nest.flatten(self._initial_cell_state)[0].dtype
log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
on_value=ops.convert_to_tensor(0.0, dtype=dtype),
off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
dtype=dtype)
init_attention_probs = get_attention_probs(
self._initial_cell_state, self._coverage_penalty_weight)
if init_attention_probs is None:
init_attention_probs = ()
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
[self._batch_size, self._beam_width], dtype=dtypes.int64),
accumulated_attention_probs=init_attention_probs)
return (finished, start_inputs, initial_state)
@property
def output_dtype(self):
# Assume the score dtype matches the dtype of the initial cell state's
# first component. Return the output_size structure with that dtype,
# plus int32 (the ids).
dtype = nest.flatten(self._initial_cell_state)[0].dtype
return BeamSearchDecoderOutput(
scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
predicted_ids=dtypes.int32,
parent_ids=dtypes.int32)
def call(self, embedding, start_tokens, end_token, initial_state, **kwargs):
init_kwargs = kwargs
init_kwargs["start_tokens"] = start_tokens
init_kwargs["end_token"] = end_token
init_kwargs["initial_state"] = initial_state
return decoder.dynamic_decode(self,
output_time_major=self.output_time_major,
impute_finished=self.impute_finished,
maximum_iterations=self.maximum_iterations,
parallel_iterations=self.parallel_iterations,
swap_memory=self.swap_memory,
decoder_init_input=embedding,
decoder_init_kwargs=init_kwargs)
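The `embedding`/`embedding_fn` hand-off performed by `initialize()` reduces to a small dispatch, sketched here in plain Python (hypothetical helper and toy lookup; the real code uses `embedding_ops.embedding_lookup`):

def resolve_embedding_fn(embedding, embedding_fn, lookup):
  """Exactly one of `embedding` (params passed at call time) and
  `embedding_fn` (callable bound at construction) may be supplied."""
  if embedding is not None and embedding_fn is not None:
    raise ValueError(
        "embedding and embedding_fn cannot be provided at the same time")
  if embedding is not None:
    return lambda ids: lookup(embedding, ids)
  return embedding_fn

fn = resolve_embedding_fn([[0.1], [0.2]], None,
                          lambda params, ids: [params[i] for i in ids])
assert fn([1, 0]) == [[0.2], [0.1]]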
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
beam_width, end_token, length_penalty_weight,
coverage_penalty_weight):
@@ -1068,7 +1323,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
"""
if isinstance(gather_from, tensor_array_ops.TensorArray):
return gather_from
_check_maybe(gather_from)
_check_ndims(gather_from)
if gather_from.shape.ndims >= len(gather_shape):
return _tensor_gather_helper(
gather_indices=gather_indices,