From 3c9b46c245a57df746946403042cca71a94622d2 Mon Sep 17 00:00:00 2001
From: Scott Zhu
Date: Thu, 14 Feb 2019 10:28:30 -0800
Subject: [PATCH] Add BeamSearchDecoderV2, which can be used as a Keras layer.

PiperOrigin-RevId: 233982439
---
 .../kernel_tests/beam_search_decoder_test.py  | 126 ++++-
 .../seq2seq/python/ops/attention_wrapper.py   |  19 +-
 .../seq2seq/python/ops/beam_search_decoder.py | 531 +++++++++++++-----
 3 files changed, 528 insertions(+), 148 deletions(-)

diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 5e28e651c66..56f2a0acc9f 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import layers
 from tensorflow.python.layers import core as layers_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn_ops
@@ -530,11 +533,10 @@ class BeamSearchDecoderTest(test.TestCase):
         return (shape[1], shape[0]) + shape[2:]
       return shape
 
-    self.assertTrue(
-        isinstance(final_outputs,
-                   beam_search_decoder.FinalBeamSearchDecoderOutput))
-    self.assertTrue(
-        isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))
+    self.assertIsInstance(
+        final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
+    self.assertIsInstance(
+        final_state, beam_search_decoder.BeamSearchDecoderState)
 
     beam_search_decoder_output = final_outputs.beam_search_decoder_output
     self.assertEqual(
@@ -574,5 +576,119 @@ class BeamSearchDecoderTest(test.TestCase):
         with_alignment_history=True)
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class BeamSearchDecoderV2Test(test.TestCase):
+
+  def _testDynamicDecodeRNN(self, time_major, has_attention,
+                            with_alignment_history=False):
+    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
+    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
+    batch_size = 5
+    decoder_max_time = 4
+    input_depth = 7
+    cell_depth = 9
+    attention_depth = 6
+    vocab_size = 20
+    end_token = vocab_size - 1
+    start_token = 0
+    embedding_dim = 50
+    max_out = max(decoder_sequence_length)
+    output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)
+    beam_width = 3
+
+    with self.cached_session():
+      batch_size_tensor = constant_op.constant(batch_size)
+      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
+      cell = rnn_cell.LSTMCell(cell_depth)
+      initial_state = cell.zero_state(batch_size, dtypes.float32)
+      coverage_penalty_weight = 0.0
+      if has_attention:
+        coverage_penalty_weight = 0.2
+        inputs = array_ops.placeholder_with_default(
+            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
+                np.float32),
+            shape=(None, None, input_depth))
+        tiled_inputs = beam_search_decoder.tile_batch(
+            inputs, multiplier=beam_width)
+        tiled_sequence_length = beam_search_decoder.tile_batch(
+            encoder_sequence_length, multiplier=beam_width)
+        attention_mechanism = attention_wrapper.BahdanauAttention(
+            num_units=attention_depth,
+            memory=tiled_inputs,
+            memory_sequence_length=tiled_sequence_length)
+        initial_state = beam_search_decoder.tile_batch(
+            initial_state, multiplier=beam_width)
+        cell = attention_wrapper.AttentionWrapper(
+            cell=cell,
+            attention_mechanism=attention_mechanism,
+            attention_layer_size=attention_depth,
+            alignment_history=with_alignment_history)
+      cell_state = cell.zero_state(
+          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
+      if has_attention:
+        cell_state = cell_state.clone(cell_state=initial_state)
+      bsd = beam_search_decoder.BeamSearchDecoderV2(
+          cell=cell,
+          beam_width=beam_width,
+          output_layer=output_layer,
+          length_penalty_weight=0.0,
+          coverage_penalty_weight=coverage_penalty_weight,
+          output_time_major=time_major,
+          maximum_iterations=max_out)
+
+      final_outputs, final_state, final_sequence_lengths = bsd(
+          embedding,
+          start_tokens=array_ops.fill([batch_size_tensor], start_token),
+          end_token=end_token,
+          initial_state=cell_state)
+
+      def _t(shape):
+        if time_major:
+          return (shape[1], shape[0]) + shape[2:]
+        return shape
+
+      self.assertIsInstance(
+          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
+      self.assertIsInstance(
+          final_state, beam_search_decoder.BeamSearchDecoderState)
+
+      beam_search_decoder_output = final_outputs.beam_search_decoder_output
+      expected_seq_length = 3 if context.executing_eagerly() else None
+      self.assertEqual(
+          _t((batch_size, expected_seq_length, beam_width)),
+          tuple(beam_search_decoder_output.scores.get_shape().as_list()))
+      self.assertEqual(
+          _t((batch_size, expected_seq_length, beam_width)),
+          tuple(final_outputs.predicted_ids.get_shape().as_list()))
+
+      self.evaluate(variables.global_variables_initializer())
+      eval_results = self.evaluate({
+          'final_outputs': final_outputs,
+          'final_sequence_lengths': final_sequence_lengths
+      })
+
+      max_sequence_length = np.max(eval_results['final_sequence_lengths'])
+
+      # A smoke test
+      self.assertEqual(
+          _t((batch_size, max_sequence_length, beam_width)),
+          eval_results['final_outputs'].beam_search_decoder_output.scores.shape)
+      self.assertEqual(
+          _t((batch_size, max_sequence_length, beam_width)), eval_results[
+              'final_outputs'].beam_search_decoder_output.predicted_ids.shape)
+
+  def testDynamicDecodeRNNBatchMajorNoAttention(self):
+    self._testDynamicDecodeRNN(time_major=False, has_attention=False)
+
+  def testDynamicDecodeRNNBatchMajorYesAttention(self):
+    self._testDynamicDecodeRNN(time_major=False, has_attention=True)
+
+  def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
+    self._testDynamicDecodeRNN(
+        time_major=False,
+        has_attention=True,
+        with_alignment_history=True)
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 5bcf0af8897..79c2ac2f500 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -25,6 +25,7 @@ import math
 import numpy as np
 
 from tensorflow.contrib.framework.python.framework import tensor_util
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -1919,7 +1920,15 @@ class AttentionWrapperState(
     def with_same_shape(old, new):
       """Check and set new tensor's shape."""
      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
-        return tensor_util.with_same_shape(old, new)
+        if not context.executing_eagerly():
+          return tensor_util.with_same_shape(old, new)
+        else:
+          if old.shape.as_list() != new.shape.as_list():
+            raise ValueError("The shape of the AttentionWrapperState is "
+                             "expected to be the same as the one to clone. "
+                             "self.shape: %s, input.shape: %s" %
+                             (old.shape, new.shape))
+          return new
       return new
 
     return nest.map_structure(
@@ -2048,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state,
   # the batched matmul is over memory_time, so the output shape is
   #   [batch_size, 1, memory_size].
   # we then squeeze out the singleton dim.
-  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
-  context = array_ops.squeeze(context, [1])
+  context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values)
+  context_ = array_ops.squeeze(context_, [1])
 
   if attention_layer is not None:
-    attention = attention_layer(array_ops.concat([cell_output, context], 1))
+    attention = attention_layer(array_ops.concat([cell_output, context_], 1))
   else:
-    attention = context
+    attention = context_
 
   return attention, alignments, next_attention_state
 
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 8f8f0577029..1d773a44989 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -24,11 +24,12 @@ import numpy as np
 from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.layers import base as layers_base
+from tensorflow.python.keras import layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
@@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length):
   return ordered
 
 
-def _check_maybe(t):
+def _check_ndims(t):
   if t.shape.ndims is None:
     raise ValueError(
         "Expected tensor (%s) to have known rank, but ndims == None." % t)
 
+
 def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
   """Raises an exception if dimensions are known statically and can not be
   reshaped to [batch_size, beam_size, -1].
@@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
     return False
   return True
 
+
 def _check_batch_beam(t, batch_size, beam_width):
   """Returns an Assert operation checking that the elements of the stacked
   TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
@@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width):
   return control_flow_ops.Assert(condition, [error_message])
 
 
+class BeamSearchDecoderMixin(object):
+  """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.
 
-class BeamSearchDecoder(decoder.Decoder):
-  """BeamSearch sampling decoder.
-
-  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
-  `AttentionWrapper`, then you must ensure that:
-
-  - The encoder output has been tiled to `beam_width` via
-    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
-  - The `batch_size` argument passed to the `zero_state` method of this
-    wrapper is equal to `true_batch_size * beam_width`.
-  - The initial state created with `zero_state` above contains a
-    `cell_state` value containing properly tiled final state from the
-    encoder.
-
-  An example:
-
-  ```
-  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
-      encoder_outputs, multiplier=beam_width)
-  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
-      encoder_final_state, multiplier=beam_width)
-  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
-      sequence_length, multiplier=beam_width)
-  attention_mechanism = MyFavoriteAttentionMechanism(
-      num_units=attention_depth,
-      memory=tiled_inputs,
-      memory_sequence_length=tiled_sequence_length)
-  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
-  decoder_initial_state = attention_cell.zero_state(
-      dtype, batch_size=true_batch_size * beam_width)
-  decoder_initial_state = decoder_initial_state.clone(
-      cell_state=tiled_encoder_final_state)
-  ```
-
-  Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use
-  when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages
-  the translation to cover all inputs.
+  It is expected to be used as a base class for a concrete BeamSearchDecoder.
+  Since this is a mixin class, it is expected to be used together with another
+  class as the base.
   """
 
   def __init__(self,
                cell,
-               embedding,
-               start_tokens,
-               end_token,
-               initial_state,
                beam_width,
                output_layer=None,
                length_penalty_weight=0.0,
                coverage_penalty_weight=0.0,
-               reorder_tensor_arrays=True):
-    """Initialize the BeamSearchDecoder.
+               reorder_tensor_arrays=True,
+               **kwargs):
+    """Initialize the BeamSearchDecoderMixin.
 
     Args:
       cell: An `RNNCell` instance.
-      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
-        or the `params` argument for `embedding_lookup`.
-      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
-      end_token: `int32` scalar, the token that marks end of decoding.
-      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
       beam_width: Python integer, the number of beams.
-      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
-        `tf.layers.Dense`. Optional layer to apply to the RNN output prior
-        to storing the result or sampling.
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
       length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
@@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder):
         Otherwise, the `TensorArray` will be returned as is. Set this flag to
         `False` if the cell state contains `TensorArray`s that are not
         amenable to reordering.
+      **kwargs: Dict, other keyword arguments for parent class.
 
     Raises:
       TypeError: if `cell` is not an instance of `RNNCell`,
-        or `output_layer` is not an instance of `tf.layers.Layer`.
-      ValueError: If `start_tokens` is not a vector or
-        `end_token` is not a scalar.
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
""" rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access if (output_layer is not None and - not isinstance(output_layer, layers_base.Layer)): + not isinstance(output_layer, layers.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer self._reorder_tensor_arrays = reorder_tensor_arrays - if callable(embedding): - self._embedding_fn = embedding - else: - self._embedding_fn = ( - lambda ids: embedding_ops.embedding_lookup(embedding, ids)) - - self._start_tokens = ops.convert_to_tensor( - start_tokens, dtype=dtypes.int32, name="start_tokens") - if self._start_tokens.get_shape().ndims != 1: - raise ValueError("start_tokens must be a vector") - self._end_token = ops.convert_to_tensor( - end_token, dtype=dtypes.int32, name="end_token") - if self._end_token.get_shape().ndims != 0: - raise ValueError("end_token must be a scalar") - - self._batch_size = array_ops.size(start_tokens) + self._start_tokens = None + self._end_token = None + self._batch_size = None self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._coverage_penalty_weight = coverage_penalty_weight - self._initial_cell_state = nest.map_structure( - self._maybe_split_batch_beams, initial_state, self._cell.state_size) - self._start_tokens = array_ops.tile( - array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) - self._start_inputs = self._embedding_fn(self._start_tokens) - - self._finished = array_ops.one_hot( - array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=False, - off_value=True, - dtype=dtypes.bool) + super(BeamSearchDecoderMixin, self).__init__(**kwargs) @property def batch_size(self): return self._batch_size def _rnn_output_size(self): + """Get the output shape from the RNN layer.""" size = self._cell.output_size if self._output_layer is None: return size @@ -393,50 +332,6 @@ class BeamSearchDecoder(decoder.Decoder): predicted_ids=tensor_shape.TensorShape([self._beam_width]), parent_ids=tensor_shape.TensorShape([self._beam_width])) - @property - def output_dtype(self): - # Assume the dtype of the cell is the output_size structure - # containing the input_state's first component's dtype. - # Return that structure and int32 (the id) - dtype = nest.flatten(self._initial_cell_state)[0].dtype - return BeamSearchDecoderOutput( - scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), - predicted_ids=dtypes.int32, - parent_ids=dtypes.int32) - - def initialize(self, name=None): - """Initialize the decoder. - - Args: - name: Name scope for any created operations. - - Returns: - `(finished, start_inputs, initial_state)`. 
-    """
-    finished, start_inputs = self._finished, self._start_inputs
-
-    dtype = nest.flatten(self._initial_cell_state)[0].dtype
-    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
-        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
-        depth=self._beam_width,
-        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
-        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
-        dtype=dtype)
-    init_attention_probs = get_attention_probs(
-        self._initial_cell_state, self._coverage_penalty_weight)
-    if init_attention_probs is None:
-      init_attention_probs = ()
-
-    initial_state = BeamSearchDecoderState(
-        cell_state=self._initial_cell_state,
-        log_probs=log_probs,
-        finished=finished,
-        lengths=array_ops.zeros(
-            [self._batch_size, self._beam_width], dtype=dtypes.int64),
-        accumulated_attention_probs=init_attention_probs)
-
-    return (finished, start_inputs, initial_state)
-
   def finalize(self, outputs, final_state, sequence_lengths):
     """Finalize and return the predicted_ids.
@@ -562,7 +457,7 @@ class BeamSearchDecoder(decoder.Decoder):
     """
     if isinstance(t, tensor_array_ops.TensorArray):
       return t
-    _check_maybe(t)
+    _check_ndims(t)
     if t.shape.ndims >= 1:
       return self._split_batch_beams(t, s)
     else:
@@ -586,7 +481,7 @@ class BeamSearchDecoder(decoder.Decoder):
     """
     if isinstance(t, tensor_array_ops.TensorArray):
       return t
-    _check_maybe(t)
+    _check_ndims(t)
     if t.shape.ndims >= 2:
       return self._merge_batch_beams(t, s)
     else:
@@ -609,11 +504,18 @@ class BeamSearchDecoder(decoder.Decoder):
     if not isinstance(t, tensor_array_ops.TensorArray):
       return t
     # pylint: disable=protected-access
-    if (not t._infer_shape or not t._element_shape
-        or t._element_shape[0].ndims is None
-        or t._element_shape[0].ndims < 1):
+    # This is a hack around the differing eager/graph TensorArray internals.
+    # TODO(b/124374427): Update this to use public property of TensorArray.
+    if context.executing_eagerly():
+      element_shape = t._element_shape
+    else:
+      element_shape = t._element_shape[0]
+    if (not t._infer_shape
+        or not t._element_shape
+        or element_shape.ndims is None
+        or element_shape.ndims < 1):
       shape = (
-          t._element_shape[0] if t._infer_shape and t._element_shape
+          element_shape if t._infer_shape and t._element_shape
           else tensor_shape.TensorShape(None))
       tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
                       "sorting based on the beam search result. For a "
@@ -621,10 +523,10 @@ class BeamSearchDecoder(decoder.Decoder):
                       "defined and have at least a rank of 1, but saw shape: %s"
                       % (t.handle.name, shape))
       return t
-    shape = t._element_shape[0]
     # pylint: enable=protected-access
     if not _check_static_batch_beam_maybe(
-        shape, tensor_util.constant_value(self._batch_size), self._beam_width):
+        element_shape, tensor_util.constant_value(self._batch_size),
+        self._beam_width):
       return t
     t = t.stack()
     with ops.control_dependencies(
@@ -684,6 +586,359 @@ class BeamSearchDecoder(decoder.Decoder):
     return (beam_search_output, beam_search_state, next_inputs, finished)
 
 
+class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder):
+  # Note that the inheritance hierarchy is important here. The Mixin has to be
+  # the first parent class since we will use super().__init__(), and the Mixin
+  # will then properly invoke the __init__ method of the other parent class.
+  """BeamSearch sampling decoder.
+
+  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
+  `AttentionWrapper`, then you must ensure that:
+
+  - The encoder output has been tiled to `beam_width` via
+    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
+  - The `batch_size` argument passed to the `zero_state` method of this
+    wrapper is equal to `true_batch_size * beam_width`.
+  - The initial state created with `zero_state` above contains a
+    `cell_state` value containing properly tiled final state from the
+    encoder.
+
+  An example:
+
+  ```
+  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
+      encoder_outputs, multiplier=beam_width)
+  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
+      encoder_final_state, multiplier=beam_width)
+  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
+      sequence_length, multiplier=beam_width)
+  attention_mechanism = MyFavoriteAttentionMechanism(
+      num_units=attention_depth,
+      memory=tiled_inputs,
+      memory_sequence_length=tiled_sequence_length)
+  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
+  decoder_initial_state = attention_cell.zero_state(
+      dtype, batch_size=true_batch_size * beam_width)
+  decoder_initial_state = decoder_initial_state.clone(
+      cell_state=tiled_encoder_final_state)
+  ```
+
+  Meanwhile, when using `AttentionWrapper`, it is suggested to enable the
+  coverage penalty when computing scores (https://arxiv.org/pdf/1609.08144.pdf),
+  which encourages the decoder to cover all inputs.
+  """
+
+  def __init__(self,
+               cell,
+               embedding,
+               start_tokens,
+               end_token,
+               initial_state,
+               beam_width,
+               output_layer=None,
+               length_penalty_weight=0.0,
+               coverage_penalty_weight=0.0,
+               reorder_tensor_arrays=True):
+    """Initialize the BeamSearchDecoder.
+
+    Args:
+      cell: An `RNNCell` instance.
+      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
+        or the `params` argument for `embedding_lookup`.
+      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
+      end_token: `int32` scalar, the token that marks end of decoding.
+      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
+      beam_width: Python integer, the number of beams.
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
+      length_penalty_weight: Float weight to penalize length. Disabled with
+        0.0.
+      coverage_penalty_weight: Float weight to penalize the coverage of source
+        sentence. Disabled with 0.0.
+      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
+        cell state will be reordered according to the beam search path. If the
+        `TensorArray` can be reordered, the stacked form will be returned.
+        Otherwise, the `TensorArray` will be returned as is. Set this flag to
+        `False` if the cell state contains `TensorArray`s that are not
+        amenable to reordering.
+
+    Raises:
+      TypeError: if `cell` is not an instance of `RNNCell`,
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
+      ValueError: If `start_tokens` is not a vector or
+        `end_token` is not a scalar.
+ """ + super(BeamSearchDecoder, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays) + + if callable(embedding): + self._embedding_fn = embedding + else: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + def initialize(self, name=None): + """Initialize the decoder. + + Args: + name: Name scope for any created operations. + + Returns: + `(finished, start_inputs, initial_state)`. + """ + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + +class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. + + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. 
+  - The initial state created with `zero_state` above contains a
+    `cell_state` value containing properly tiled final state from the
+    encoder.
+
+  An example:
+
+  ```
+  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
+      encoder_outputs, multiplier=beam_width)
+  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
+      encoder_final_state, multiplier=beam_width)
+  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
+      sequence_length, multiplier=beam_width)
+  attention_mechanism = MyFavoriteAttentionMechanism(
+      num_units=attention_depth,
+      memory=tiled_inputs,
+      memory_sequence_length=tiled_sequence_length)
+  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
+  decoder_initial_state = attention_cell.zero_state(
+      dtype, batch_size=true_batch_size * beam_width)
+  decoder_initial_state = decoder_initial_state.clone(
+      cell_state=tiled_encoder_final_state)
+  ```
+
+  Meanwhile, when using `AttentionWrapper`, it is suggested to enable the
+  coverage penalty when computing scores (https://arxiv.org/pdf/1609.08144.pdf),
+  which encourages the decoder to cover all inputs.
+  """
+
+  def __init__(self,
+               cell,
+               beam_width,
+               embedding_fn=None,
+               output_layer=None,
+               length_penalty_weight=0.0,
+               coverage_penalty_weight=0.0,
+               reorder_tensor_arrays=True,
+               **kwargs):
+    """Initialize the BeamSearchDecoderV2.
+
+    Args:
+      cell: An `RNNCell` instance.
+      beam_width: Python integer, the number of beams.
+      embedding_fn: A callable that takes a vector tensor of `ids` (argmax
+        ids).
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
+      length_penalty_weight: Float weight to penalize length. Disabled with
+        0.0.
+      coverage_penalty_weight: Float weight to penalize the coverage of source
+        sentence. Disabled with 0.0.
+      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
+        cell state will be reordered according to the beam search path. If the
+        `TensorArray` can be reordered, the stacked form will be returned.
+        Otherwise, the `TensorArray` will be returned as is. Set this flag to
+        `False` if the cell state contains `TensorArray`s that are not
+        amenable to reordering.
+      **kwargs: Dict, other keyword arguments for initialization.
+
+    Raises:
+      TypeError: if `cell` is not an instance of `RNNCell`,
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
+    """
+    super(BeamSearchDecoderV2, self).__init__(
+        cell,
+        beam_width,
+        output_layer=output_layer,
+        length_penalty_weight=length_penalty_weight,
+        coverage_penalty_weight=coverage_penalty_weight,
+        reorder_tensor_arrays=reorder_tensor_arrays,
+        **kwargs)
+
+    if embedding_fn is None or callable(embedding_fn):
+      self._embedding_fn = embedding_fn
+    else:
+      raise ValueError("embedding_fn is expected to be a callable, got %s" %
+                       type(embedding_fn))
+
+  def initialize(self,
+                 embedding,
+                 start_tokens,
+                 end_token,
+                 initial_state):
+    """Initialize the decoder.
+
+    Args:
+      embedding: A tensor of embedding params, used as the `params` argument
+        for `embedding_lookup`.
+      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
+      end_token: `int32` scalar, the token that marks end of decoding.
+      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
+
+    Returns:
+      `(finished, start_inputs, initial_state)`.
+
+    Raises:
+      ValueError: If `start_tokens` is not a vector or `end_token` is not a
+        scalar.
+ """ + if embedding is not None and self._embedding_fn is not None: + raise ValueError( + "embedding and embedding_fn cannot be provided at same time") + elif embedding is not None: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + def call(self, embeddning, start_tokens, end_token, initial_state, **kwargs): + init_kwargs = kwargs + init_kwargs["start_tokens"] = start_tokens + init_kwargs["end_token"] = end_token + init_kwargs["initial_state"] = initial_state + return decoder.dynamic_decode(self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + decoder_init_input=embeddning, + decoder_init_kwargs=init_kwargs) + + def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, beam_width, end_token, length_penalty_weight, coverage_penalty_weight): @@ -1068,7 +1323,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, """ if isinstance(gather_from, tensor_array_ops.TensorArray): return gather_from - _check_maybe(gather_from) + _check_ndims(gather_from) if gather_from.shape.ndims >= len(gather_shape): return _tensor_gather_helper( gather_indices=gather_indices,