Add BeamSearchDecoderV2 which can be used as a keras layer.
PiperOrigin-RevId: 233982439
This commit is contained in:
parent 03aa9d18f5
commit 3c9b46c245
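Since the commit message is terse, here is a minimal sketch of the layer-style API this change introduces, distilled from the test added below. It is an illustration only: the sizes, the no-attention setup, and the TF1-era import paths are assumptions, not part of the commit.

```
import numpy as np
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder

# Illustrative sizes (assumed, mirroring the new test).
vocab_size, embedding_dim, cell_depth = 20, 50, 9
batch_size, beam_width = 5, 3

embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = tf.nn.rnn_cell.LSTMCell(cell_depth)

# Configure the decoder once, like a Keras layer...
bsd = beam_search_decoder.BeamSearchDecoderV2(
    cell=cell,
    beam_width=beam_width,
    output_layer=tf.keras.layers.Dense(vocab_size),
    maximum_iterations=4)

# ...then call it with per-batch inputs; the call runs dynamic_decode.
initial_state = cell.zero_state(batch_size * beam_width, tf.float32)
final_outputs, final_state, final_seq_lengths = bsd(
    embedding,
    start_tokens=tf.fill([batch_size], 0),
    end_token=vocab_size - 1,
    initial_state=initial_state)
```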
@@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import layers
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
@@ -530,11 +533,10 @@ class BeamSearchDecoderTest(test.TestCase):
        return (shape[1], shape[0]) + shape[2:]
      return shape

    self.assertTrue(
        isinstance(final_outputs,
                   beam_search_decoder.FinalBeamSearchDecoderOutput))
    self.assertTrue(
        isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))
    self.assertIsInstance(
        final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
    self.assertIsInstance(
        final_state, beam_search_decoder.BeamSearchDecoderState)

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    self.assertEqual(
@@ -574,5 +576,119 @@ class BeamSearchDecoderTest(test.TestCase):
        with_alignment_history=True)


@test_util.run_all_in_graph_and_eager_modes
class BeamSearchDecoderV2Test(test.TestCase):

  def _testDynamicDecodeRNN(self, time_major, has_attention,
                            with_alignment_history=False):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    with self.cached_session():
      batch_size_tensor = constant_op.constant(batch_size)
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      initial_state = cell.zero_state(batch_size, dtypes.float32)
      coverage_penalty_weight = 0.0
      if has_attention:
        coverage_penalty_weight = 0.2
        inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                np.float32),
            shape=(None, None, input_depth))
        tiled_inputs = beam_search_decoder.tile_batch(
            inputs, multiplier=beam_width)
        tiled_sequence_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length)
        initial_state = beam_search_decoder.tile_batch(
            initial_state, multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history)
      cell_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
      if has_attention:
        cell_state = cell_state.clone(cell_state=initial_state)
      bsd = beam_search_decoder.BeamSearchDecoderV2(
          cell=cell,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0,
          coverage_penalty_weight=coverage_penalty_weight,
          output_time_major=time_major,
          maximum_iterations=max_out)

      final_outputs, final_state, final_sequence_lengths = bsd(
          embedding,
          start_tokens=array_ops.fill([batch_size_tensor], start_token),
          end_token=end_token,
          initial_state=cell_state)

      def _t(shape):
        if time_major:
          return (shape[1], shape[0]) + shape[2:]
        return shape

      self.assertIsInstance(
          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
      self.assertIsInstance(
          final_state, beam_search_decoder.BeamSearchDecoderState)

      beam_search_decoder_output = final_outputs.beam_search_decoder_output
      expected_seq_length = 3 if context.executing_eagerly() else None
      self.assertEqual(
          _t((batch_size, expected_seq_length, beam_width)),
          tuple(beam_search_decoder_output.scores.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, expected_seq_length, beam_width)),
          tuple(final_outputs.predicted_ids.get_shape().as_list()))

      self.evaluate(variables.global_variables_initializer())
      eval_results = self.evaluate({
          'final_outputs': final_outputs,
          'final_sequence_lengths': final_sequence_lengths
      })

    max_sequence_length = np.max(eval_results['final_sequence_lengths'])

    # A smoke test
    self.assertEqual(
        _t((batch_size, max_sequence_length, beam_width)),
        eval_results['final_outputs'].beam_search_decoder_output.scores.shape)
    self.assertEqual(
        _t((batch_size, max_sequence_length, beam_width)), eval_results[
            'final_outputs'].beam_search_decoder_output.predicted_ids.shape)

  def testDynamicDecodeRNNBatchMajorNoAttention(self):
    self._testDynamicDecodeRNN(time_major=False, has_attention=False)

  def testDynamicDecodeRNNBatchMajorYesAttention(self):
    self._testDynamicDecodeRNN(time_major=False, has_attention=True)

  def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
    self._testDynamicDecodeRNN(
        time_major=False,
        has_attention=True,
        with_alignment_history=True)


if __name__ == '__main__':
  test.main()
@@ -25,6 +25,7 @@ import math
import numpy as np

from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -1919,7 +1920,15 @@ class AttentionWrapperState(
    def with_same_shape(old, new):
      """Check and set new tensor's shape."""
      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
        if not context.executing_eagerly():
          return tensor_util.with_same_shape(old, new)
        else:
          if old.shape.as_list() != new.shape.as_list():
            raise ValueError("The shape of the AttentionWrapperState is "
                             "expected to be same as the one to clone. "
                             "self.shape: %s, input.shape: %s" %
                             (old.shape, new.shape))
          return new
      return new

    return nest.map_structure(
@@ -2048,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state,
  # the batched matmul is over memory_time, so the output shape is
  #   [batch_size, 1, memory_size].
  # we then squeeze out the singleton dim.
  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context = array_ops.squeeze(context, [1])
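  # A plausible reading of the rename below: the local variable `context`
  # would now shadow the `context` module (tensorflow.python.eager.context)
  # imported at the top of this file in this same commit.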
  context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context_ = array_ops.squeeze(context_, [1])

  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
    attention = attention_layer(array_ops.concat([cell_output, context_], 1))
  else:
    attention = context
    attention = context_

  return attention, alignments, next_attention_state

@@ -24,11 +24,12 @@ import numpy as np
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as layers_base
from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
@@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length):
  return ordered


def _check_maybe(t):
def _check_ndims(t):
  if t.shape.ndims is None:
    raise ValueError(
        "Expected tensor (%s) to have known rank, but ndims == None." % t)


def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
  """Raises an exception if dimensions are known statically and can not be
  reshaped to [batch_size, beam_size, -1].
@@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
    return False
  return True


def _check_batch_beam(t, batch_size, beam_width):
  """Returns an Assert operation checking that the elements of the stacked
  TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
@@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width):
  return control_flow_ops.Assert(condition, [error_message])


class BeamSearchDecoderMixin(object):
  """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.

class BeamSearchDecoder(decoder.Decoder):
  """BeamSearch sampling decoder.

  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
  `AttentionWrapper`, then you must ensure that:

  - The encoder output has been tiled to `beam_width` via
    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
  - The `batch_size` argument passed to the `zero_state` method of this
    wrapper is equal to `true_batch_size * beam_width`.
  - The initial state created with `zero_state` above contains a
    `cell_state` value containing properly tiled final state from the
    encoder.

  An example:

  ```
  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs, multiplier=beam_width)
  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
      encoder_final_state, multiplier=beam_width)
  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
      sequence_length, multiplier=beam_width)
  attention_mechanism = MyFavoriteAttentionMechanism(
      num_units=attention_depth,
      memory=tiled_inputs,
      memory_sequence_length=tiled_sequence_length)
  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
  decoder_initial_state = attention_cell.zero_state(
      dtype, batch_size=true_batch_size * beam_width)
  decoder_initial_state = decoder_initial_state.clone(
      cell_state=tiled_encoder_final_state)
  ```

  Meanwhile, when using `AttentionWrapper`, a coverage penalty is suggested
  when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
  the translation to cover all inputs.
  It is expected to be used as a base class for a concrete BeamSearchDecoder.
  Since this is a mixin class, it is expected to be used together with another
  class as the base.
  """
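  # For reference, the penalties referred to above are plausibly those of the
  # GNMT paper (https://arxiv.org/pdf/1609.08144.pdf), with alpha and beta
  # playing the roles of length_penalty_weight and coverage_penalty_weight:
  #
  #   lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha
  #   cp(X; Y) = beta * sum_i log(min(sum_j p_{i,j}, 1.0))
  #   score(Y, X) = log P(Y | X) / lp(Y) + cp(X; Y)
  #
  # where p_{i,j} is the attention probability of the j-th target word on the
  # i-th source word. This annotation is an editorial gloss, not code from the
  # commit.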

  def __init__(self,
               cell,
               embedding,
               start_tokens,
               end_token,
               initial_state,
               beam_width,
               output_layer=None,
               length_penalty_weight=0.0,
               coverage_penalty_weight=0.0,
               reorder_tensor_arrays=True):
    """Initialize the BeamSearchDecoder.
               reorder_tensor_arrays=True,
               **kwargs):
    """Initialize the BeamSearchDecoderMixin.

    Args:
      cell: An `RNNCell` instance.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width: Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`. Optional layer to apply to the RNN output prior
        to storing the result or sampling.
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
@@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder):
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.
      **kwargs: Dict, other keyword arguments for the parent class.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.layers.Layer`.
      ValueError: If `start_tokens` is not a vector or
        `end_token` is not a scalar.
        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
    """
    rnn_cell_impl.assert_like_rnncell("cell", cell)  # pylint: disable=protected-access
    if (output_layer is not None and
        not isinstance(output_layer, layers_base.Layer)):
        not isinstance(output_layer, layers.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._output_layer = output_layer
    self._reorder_tensor_arrays = reorder_tensor_arrays

    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")

    self._batch_size = array_ops.size(start_tokens)
    self._start_tokens = None
    self._end_token = None
    self._batch_size = None
    self._beam_width = beam_width
    self._length_penalty_weight = length_penalty_weight
    self._coverage_penalty_weight = coverage_penalty_weight
    self._initial_cell_state = nest.map_structure(
        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    self._start_tokens = array_ops.tile(
        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    self._start_inputs = self._embedding_fn(self._start_tokens)

    self._finished = array_ops.one_hot(
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
    super(BeamSearchDecoderMixin, self).__init__(**kwargs)

  @property
  def batch_size(self):
    return self._batch_size

  def _rnn_output_size(self):
    """Get the output shape from the RNN layer."""
    size = self._cell.output_size
    if self._output_layer is None:
      return size
@@ -393,50 +332,6 @@ class BeamSearchDecoder(decoder.Decoder):
          predicted_ids=tensor_shape.TensorShape([self._beam_width]),
          parent_ids=tensor_shape.TensorShape([self._beam_width]))

  @property
  def output_dtype(self):
    # Assume the dtype of the cell is the output_size structure
    # containing the input_state's first component's dtype.
    # Return that structure and int32 (the id)
    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    return BeamSearchDecoderOutput(
        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
        predicted_ids=dtypes.int32,
        parent_ids=dtypes.int32)

  def initialize(self, name=None):
    """Initialize the decoder.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, start_inputs, initial_state)`.
    """
    finished, start_inputs = self._finished, self._start_inputs

    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
        dtype=dtype)
    init_attention_probs = get_attention_probs(
        self._initial_cell_state, self._coverage_penalty_weight)
    if init_attention_probs is None:
      init_attention_probs = ()

    initial_state = BeamSearchDecoderState(
        cell_state=self._initial_cell_state,
        log_probs=log_probs,
        finished=finished,
        lengths=array_ops.zeros(
            [self._batch_size, self._beam_width], dtype=dtypes.int64),
        accumulated_attention_probs=init_attention_probs)

    return (finished, start_inputs, initial_state)

  def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids.

@@ -562,7 +457,7 @@ class BeamSearchDecoder(decoder.Decoder):
    """
    if isinstance(t, tensor_array_ops.TensorArray):
      return t
    _check_maybe(t)
    _check_ndims(t)
    if t.shape.ndims >= 1:
      return self._split_batch_beams(t, s)
    else:
@@ -586,7 +481,7 @@ class BeamSearchDecoder(decoder.Decoder):
    """
    if isinstance(t, tensor_array_ops.TensorArray):
      return t
    _check_maybe(t)
    _check_ndims(t)
    if t.shape.ndims >= 2:
      return self._merge_batch_beams(t, s)
    else:
@@ -609,11 +504,18 @@ class BeamSearchDecoder(decoder.Decoder):
    if not isinstance(t, tensor_array_ops.TensorArray):
      return t
    # pylint: disable=protected-access
    if (not t._infer_shape or not t._element_shape
        or t._element_shape[0].ndims is None
        or t._element_shape[0].ndims < 1):
    # This is a bad hack due to the implementation detail of eager/graph TA.
    # TODO(b/124374427): Update this to use public property of TensorArray.
    if context.executing_eagerly():
      element_shape = t._element_shape
    else:
      element_shape = t._element_shape[0]
    if (not t._infer_shape
        or not t._element_shape
        or element_shape.ndims is None
        or element_shape.ndims < 1):
      shape = (
          t._element_shape[0] if t._infer_shape and t._element_shape
          element_shape if t._infer_shape and t._element_shape
          else tensor_shape.TensorShape(None))
      tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
                      "sorting based on the beam search result. For a "
@@ -621,10 +523,10 @@ class BeamSearchDecoder(decoder.Decoder):
                      "defined and have at least a rank of 1, but saw shape: %s"
                      % (t.handle.name, shape))
      return t
    shape = t._element_shape[0]
    # pylint: enable=protected-access
    if not _check_static_batch_beam_maybe(
        shape, tensor_util.constant_value(self._batch_size), self._beam_width):
        element_shape, tensor_util.constant_value(self._batch_size),
        self._beam_width):
      return t
    t = t.stack()
    with ops.control_dependencies(
@@ -684,6 +586,359 @@ class BeamSearchDecoder(decoder.Decoder):
    return (beam_search_output, beam_search_state, next_inputs, finished)


class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder):
  # Note that the inheritance hierarchy is important here. The Mixin has to be
  # the first parent class since we will use super().__init__(), and the
  # Mixin, which subclasses plain `object`, will properly invoke the
  # __init__ method of the other parent class.
  """BeamSearch sampling decoder.

  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
  `AttentionWrapper`, then you must ensure that:

  - The encoder output has been tiled to `beam_width` via
    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
  - The `batch_size` argument passed to the `zero_state` method of this
    wrapper is equal to `true_batch_size * beam_width`.
  - The initial state created with `zero_state` above contains a
    `cell_state` value containing properly tiled final state from the
    encoder.

  An example:

  ```
  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs, multiplier=beam_width)
  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
      encoder_final_state, multiplier=beam_width)
  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
      sequence_length, multiplier=beam_width)
  attention_mechanism = MyFavoriteAttentionMechanism(
      num_units=attention_depth,
      memory=tiled_inputs,
      memory_sequence_length=tiled_sequence_length)
  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
  decoder_initial_state = attention_cell.zero_state(
      dtype, batch_size=true_batch_size * beam_width)
  decoder_initial_state = decoder_initial_state.clone(
      cell_state=tiled_encoder_final_state)
  ```

  Meanwhile, when using `AttentionWrapper`, a coverage penalty is suggested
  when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
  the decoder to cover all inputs.
  """

  def __init__(self,
               cell,
               embedding,
               start_tokens,
               end_token,
               initial_state,
               beam_width,
               output_layer=None,
               length_penalty_weight=0.0,
               coverage_penalty_weight=0.0,
               reorder_tensor_arrays=True):
    """Initialize the BeamSearchDecoder.

    Args:
      cell: An `RNNCell` instance.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width: Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
        state will be reordered according to the beam search path. If the
        `TensorArray` can be reordered, the stacked form will be returned.
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
      ValueError: If `start_tokens` is not a vector or
        `end_token` is not a scalar.
    """
    super(BeamSearchDecoder, self).__init__(
        cell,
        beam_width,
        output_layer=output_layer,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
        reorder_tensor_arrays=reorder_tensor_arrays)

    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")

    self._batch_size = array_ops.size(start_tokens)
    self._initial_cell_state = nest.map_structure(
        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    self._start_tokens = array_ops.tile(
        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    self._start_inputs = self._embedding_fn(self._start_tokens)

    self._finished = array_ops.one_hot(
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)

  def initialize(self, name=None):
    """Initialize the decoder.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, start_inputs, initial_state)`.
    """
    finished, start_inputs = self._finished, self._start_inputs

    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
        dtype=dtype)
    init_attention_probs = get_attention_probs(
        self._initial_cell_state, self._coverage_penalty_weight)
    if init_attention_probs is None:
      init_attention_probs = ()

    initial_state = BeamSearchDecoderState(
        cell_state=self._initial_cell_state,
        log_probs=log_probs,
        finished=finished,
        lengths=array_ops.zeros(
            [self._batch_size, self._beam_width], dtype=dtypes.int64),
        accumulated_attention_probs=init_attention_probs)

    return (finished, start_inputs, initial_state)

  @property
  def output_dtype(self):
    # Assume the dtype of the cell is the output_size structure
    # containing the input_state's first component's dtype.
    # Return that structure and int32 (the id)
    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    return BeamSearchDecoderOutput(
        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
        predicted_ids=dtypes.int32,
        parent_ids=dtypes.int32)

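An aside on the base-class ordering the comments above and below insist on: a minimal, self-contained sketch (hypothetical class names, not part of this commit) of how Python's MRO makes the cooperative `super().__init__` chain work only when the mixin is listed first.

```
class _Base(object):

  def __init__(self, name=None):
    self.name = name


class _Mixin(object):

  def __init__(self, width, **kwargs):
    self.width = width
    # Cooperative: forwards the remaining kwargs to the next class in the
    # MRO, which is _Base when a subclass lists _Mixin first.
    super(_Mixin, self).__init__(**kwargs)


class _Decoder(_Mixin, _Base):  # MRO: _Decoder -> _Mixin -> _Base -> object
  pass


d = _Decoder(width=3, name="d")
assert d.width == 3 and d.name == "d"
# With class _Decoder(_Base, _Mixin), _Base.__init__ would run first and
# never chain into _Mixin.__init__.
```
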
class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder):
  # Note that the inheritance hierarchy is important here. The Mixin has to be
  # the first parent class since we will use super().__init__(), and the
  # Mixin, which subclasses plain `object`, will properly invoke the
  # __init__ method of the other parent class.
  """BeamSearch sampling decoder.

  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
  `AttentionWrapper`, then you must ensure that:

  - The encoder output has been tiled to `beam_width` via
    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
  - The `batch_size` argument passed to the `zero_state` method of this
    wrapper is equal to `true_batch_size * beam_width`.
  - The initial state created with `zero_state` above contains a
    `cell_state` value containing properly tiled final state from the
    encoder.

  An example:

  ```
  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs, multiplier=beam_width)
  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
      encoder_final_state, multiplier=beam_width)
  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
      sequence_length, multiplier=beam_width)
  attention_mechanism = MyFavoriteAttentionMechanism(
      num_units=attention_depth,
      memory=tiled_inputs,
      memory_sequence_length=tiled_sequence_length)
  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
  decoder_initial_state = attention_cell.zero_state(
      dtype, batch_size=true_batch_size * beam_width)
  decoder_initial_state = decoder_initial_state.clone(
      cell_state=tiled_encoder_final_state)
  ```

  Meanwhile, when using `AttentionWrapper`, a coverage penalty is suggested
  when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
  the decoding to cover all inputs.
  """

  def __init__(self,
               cell,
               beam_width,
               embedding_fn=None,
               output_layer=None,
               length_penalty_weight=0.0,
               coverage_penalty_weight=0.0,
               reorder_tensor_arrays=True,
               **kwargs):
    """Initialize the BeamSearchDecoderV2.

    Args:
      cell: An `RNNCell` instance.
      beam_width: Python integer, the number of beams.
      embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids).
      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
        prior to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
        state will be reordered according to the beam search path. If the
        `TensorArray` can be reordered, the stacked form will be returned.
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.
      **kwargs: Dict, other keyword arguments for initialization.

    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
    """
    super(BeamSearchDecoderV2, self).__init__(
        cell,
        beam_width,
        output_layer=output_layer,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
        reorder_tensor_arrays=reorder_tensor_arrays,
        **kwargs)

    if embedding_fn is None or callable(embedding_fn):
      self._embedding_fn = embedding_fn
    else:
      raise ValueError("embedding_fn is expected to be a callable, got %s" %
                       type(embedding_fn))

  def initialize(self,
                 embedding,
                 start_tokens,
                 end_token,
                 initial_state):
    """Initialize the decoder.

    Args:
      embedding: A tensor from the embedding layer output, which is the
        `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.

    Returns:
      `(finished, start_inputs, initial_state)`.

    Raises:
      ValueError: If `start_tokens` is not a vector or `end_token` is not a
        scalar.
    """
    if embedding is not None and self._embedding_fn is not None:
      raise ValueError(
          "embedding and embedding_fn cannot be provided at the same time")
    elif embedding is not None:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")

    self._batch_size = array_ops.size(start_tokens)
    self._initial_cell_state = nest.map_structure(
        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    self._start_tokens = array_ops.tile(
        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    self._start_inputs = self._embedding_fn(self._start_tokens)

    self._finished = array_ops.one_hot(
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)

    finished, start_inputs = self._finished, self._start_inputs

    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
        dtype=dtype)
    init_attention_probs = get_attention_probs(
        self._initial_cell_state, self._coverage_penalty_weight)
    if init_attention_probs is None:
      init_attention_probs = ()

    initial_state = BeamSearchDecoderState(
        cell_state=self._initial_cell_state,
        log_probs=log_probs,
        finished=finished,
        lengths=array_ops.zeros(
            [self._batch_size, self._beam_width], dtype=dtypes.int64),
        accumulated_attention_probs=init_attention_probs)

    return (finished, start_inputs, initial_state)

  @property
  def output_dtype(self):
    # Assume the dtype of the cell is the output_size structure
    # containing the input_state's first component's dtype.
    # Return that structure and int32 (the id)
    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    return BeamSearchDecoderOutput(
        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
        predicted_ids=dtypes.int32,
        parent_ids=dtypes.int32)

  def call(self, embedding, start_tokens, end_token, initial_state, **kwargs):
    init_kwargs = kwargs
    init_kwargs["start_tokens"] = start_tokens
    init_kwargs["end_token"] = end_token
    init_kwargs["initial_state"] = initial_state
    return decoder.dynamic_decode(self,
                                  output_time_major=self.output_time_major,
                                  impute_finished=self.impute_finished,
                                  maximum_iterations=self.maximum_iterations,
                                  parallel_iterations=self.parallel_iterations,
                                  swap_memory=self.swap_memory,
                                  decoder_init_input=embedding,
                                  decoder_init_kwargs=init_kwargs)


def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight,
                      coverage_penalty_weight):
@@ -1068,7 +1323,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
  """
  if isinstance(gather_from, tensor_array_ops.TensorArray):
    return gather_from
  _check_maybe(gather_from)
  _check_ndims(gather_from)
  if gather_from.shape.ndims >= len(gather_shape):
    return _tensor_gather_helper(
        gather_indices=gather_indices,