Add BeamSearchDecoderV2 which can be used as a keras layer.
PiperOrigin-RevId: 233982439
parent 03aa9d18f5
commit 3c9b46c245
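The new `BeamSearchDecoderV2` is a `decoder.BaseDecoder`, i.e. a Keras-layer-style object: it is constructed once with the cell and the beam-search options, then invoked like a layer on the embedding together with the decoding arguments. A minimal usage sketch distilled from the test added in this commit (`cell`, `output_layer`, `embedding`, `cell_state`, and the scalar hyperparameters stand in for the user's own setup):

```
bsd = beam_search_decoder.BeamSearchDecoderV2(
    cell=cell,                  # e.g. an (attention-wrapped) RNN cell
    beam_width=3,
    output_layer=output_layer,  # e.g. tf.keras.layers.Dense(vocab_size)
    length_penalty_weight=0.0,
    coverage_penalty_weight=0.0,
    output_time_major=False,
    maximum_iterations=max_out)

# Calling the instance runs dynamic_decode internally (see call() below).
final_outputs, final_state, final_sequence_lengths = bsd(
    embedding,  # embedding matrix; alternatively pass embedding_fn= above
    start_tokens=array_ops.fill([batch_size], start_token),
    end_token=end_token,
    initial_state=cell_state)
```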
tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import layers
 from tensorflow.python.layers import core as layers_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn_ops
@@ -530,11 +533,10 @@ class BeamSearchDecoderTest(test.TestCase):
           return (shape[1], shape[0]) + shape[2:]
         return shape
 
-      self.assertTrue(
-          isinstance(final_outputs,
-                     beam_search_decoder.FinalBeamSearchDecoderOutput))
-      self.assertTrue(
-          isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))
+      self.assertIsInstance(
+          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
+      self.assertIsInstance(
+          final_state, beam_search_decoder.BeamSearchDecoderState)
 
       beam_search_decoder_output = final_outputs.beam_search_decoder_output
       self.assertEqual(
@@ -574,5 +576,119 @@ class BeamSearchDecoderTest(test.TestCase):
         with_alignment_history=True)
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class BeamSearchDecoderV2Test(test.TestCase):
+
+  def _testDynamicDecodeRNN(self, time_major, has_attention,
+                            with_alignment_history=False):
+    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
+    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
+    batch_size = 5
+    decoder_max_time = 4
+    input_depth = 7
+    cell_depth = 9
+    attention_depth = 6
+    vocab_size = 20
+    end_token = vocab_size - 1
+    start_token = 0
+    embedding_dim = 50
+    max_out = max(decoder_sequence_length)
+    output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)
+    beam_width = 3
+
+    with self.cached_session():
+      batch_size_tensor = constant_op.constant(batch_size)
+      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
+      cell = rnn_cell.LSTMCell(cell_depth)
+      initial_state = cell.zero_state(batch_size, dtypes.float32)
+      coverage_penalty_weight = 0.0
+      if has_attention:
+        coverage_penalty_weight = 0.2
+        inputs = array_ops.placeholder_with_default(
+            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
+                np.float32),
+            shape=(None, None, input_depth))
+        tiled_inputs = beam_search_decoder.tile_batch(
+            inputs, multiplier=beam_width)
+        tiled_sequence_length = beam_search_decoder.tile_batch(
+            encoder_sequence_length, multiplier=beam_width)
+        attention_mechanism = attention_wrapper.BahdanauAttention(
+            num_units=attention_depth,
+            memory=tiled_inputs,
+            memory_sequence_length=tiled_sequence_length)
+        initial_state = beam_search_decoder.tile_batch(
+            initial_state, multiplier=beam_width)
+        cell = attention_wrapper.AttentionWrapper(
+            cell=cell,
+            attention_mechanism=attention_mechanism,
+            attention_layer_size=attention_depth,
+            alignment_history=with_alignment_history)
+      cell_state = cell.zero_state(
+          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
+      if has_attention:
+        cell_state = cell_state.clone(cell_state=initial_state)
+      bsd = beam_search_decoder.BeamSearchDecoderV2(
+          cell=cell,
+          beam_width=beam_width,
+          output_layer=output_layer,
+          length_penalty_weight=0.0,
+          coverage_penalty_weight=coverage_penalty_weight,
+          output_time_major=time_major,
+          maximum_iterations=max_out)
+
+      final_outputs, final_state, final_sequence_lengths = bsd(
+          embedding,
+          start_tokens=array_ops.fill([batch_size_tensor], start_token),
+          end_token=end_token,
+          initial_state=cell_state)
+
+      def _t(shape):
+        if time_major:
+          return (shape[1], shape[0]) + shape[2:]
+        return shape
+
+      self.assertIsInstance(
+          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
+      self.assertIsInstance(
+          final_state, beam_search_decoder.BeamSearchDecoderState)
+
+      beam_search_decoder_output = final_outputs.beam_search_decoder_output
+      expected_seq_length = 3 if context.executing_eagerly() else None
+      self.assertEqual(
+          _t((batch_size, expected_seq_length, beam_width)),
+          tuple(beam_search_decoder_output.scores.get_shape().as_list()))
+      self.assertEqual(
+          _t((batch_size, expected_seq_length, beam_width)),
+          tuple(final_outputs.predicted_ids.get_shape().as_list()))
+
+      self.evaluate(variables.global_variables_initializer())
+      eval_results = self.evaluate({
+          'final_outputs': final_outputs,
+          'final_sequence_lengths': final_sequence_lengths
+      })
+
+      max_sequence_length = np.max(eval_results['final_sequence_lengths'])
+
+      # A smoke test
+      self.assertEqual(
+          _t((batch_size, max_sequence_length, beam_width)),
+          eval_results['final_outputs'].beam_search_decoder_output.scores.shape)
+      self.assertEqual(
+          _t((batch_size, max_sequence_length, beam_width)), eval_results[
+              'final_outputs'].beam_search_decoder_output.predicted_ids.shape)
+
+  def testDynamicDecodeRNNBatchMajorNoAttention(self):
+    self._testDynamicDecodeRNN(time_major=False, has_attention=False)
+
+  def testDynamicDecodeRNNBatchMajorYesAttention(self):
+    self._testDynamicDecodeRNN(time_major=False, has_attention=True)
+
+  def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self):
+    self._testDynamicDecodeRNN(
+        time_major=False,
+        has_attention=True,
+        with_alignment_history=True)
+
+
 if __name__ == '__main__':
   test.main()
tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -25,6 +25,7 @@ import math
 import numpy as np
 
 from tensorflow.contrib.framework.python.framework import tensor_util
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -1919,7 +1920,15 @@ class AttentionWrapperState(
     def with_same_shape(old, new):
       """Check and set new tensor's shape."""
       if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
-        return tensor_util.with_same_shape(old, new)
+        if not context.executing_eagerly():
+          return tensor_util.with_same_shape(old, new)
+        else:
+          if old.shape.as_list() != new.shape.as_list():
+            raise ValueError("The shape of the AttentionWrapperState is "
+                             "expected to be same as the one to clone. "
+                             "self.shape: %s, input.shape: %s" %
+                             (old.shape, new.shape))
+          return new
       return new
 
     return nest.map_structure(
@@ -2048,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state,
   # the batched matmul is over memory_time, so the output shape is
   #   [batch_size, 1, memory_size].
   # we then squeeze out the singleton dim.
-  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
-  context = array_ops.squeeze(context, [1])
+  context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values)
+  context_ = array_ops.squeeze(context_, [1])
 
   if attention_layer is not None:
-    attention = attention_layer(array_ops.concat([cell_output, context], 1))
+    attention = attention_layer(array_ops.concat([cell_output, context_], 1))
   else:
-    attention = context
+    attention = context_
 
   return attention, alignments, next_attention_state
 
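(A note on the two `attention_wrapper.py` hunks above: the local variable `context` in `_compute_attention` is renamed to `context_`, presumably so it no longer shadows the newly imported `tensorflow.python.eager.context` module.)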
tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -24,11 +24,12 @@ import numpy as np
 from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.layers import base as layers_base
+from tensorflow.python.keras import layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
@@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length):
   return ordered
 
 
-def _check_maybe(t):
+def _check_ndims(t):
   if t.shape.ndims is None:
     raise ValueError(
         "Expected tensor (%s) to have known rank, but ndims == None." % t)
 
+
 def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
   """Raises an exception if dimensions are known statically and can not be
   reshaped to [batch_size, beam_size, -1].
@@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
     return False
   return True
 
+
 def _check_batch_beam(t, batch_size, beam_width):
   """Returns an Assert operation checking that the elements of the stacked
   TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
@@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width):
   return control_flow_ops.Assert(condition, [error_message])
 
 
-class BeamSearchDecoder(decoder.Decoder):
-  """BeamSearch sampling decoder.
-
-  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
-  `AttentionWrapper`, then you must ensure that:
-
-  - The encoder output has been tiled to `beam_width` via
-    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
-  - The `batch_size` argument passed to the `zero_state` method of this
-    wrapper is equal to `true_batch_size * beam_width`.
-  - The initial state created with `zero_state` above contains a
-    `cell_state` value containing properly tiled final state from the
-    encoder.
-
-  An example:
-
-  ```
-  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
-      encoder_outputs, multiplier=beam_width)
-  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
-      encoder_final_state, multiplier=beam_width)
-  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
-      sequence_length, multiplier=beam_width)
-  attention_mechanism = MyFavoriteAttentionMechanism(
-      num_units=attention_depth,
-      memory=tiled_inputs,
-      memory_sequence_length=tiled_sequence_length)
-  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
-  decoder_initial_state = attention_cell.zero_state(
-      dtype, batch_size=true_batch_size * beam_width)
-  decoder_initial_state = decoder_initial_state.clone(
-      cell_state=tiled_encoder_final_state)
-  ```
-
-  Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use
-  when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages
-  the translation to cover all inputs.
+class BeamSearchDecoderMixin(object):
+  """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.
+
+  It is expected to be used as a base class for a concrete BeamSearchDecoder.
+  Since this is a mixin class, it is expected to be used together with another
+  class as the base.
   """
 
   def __init__(self,
                cell,
-               embedding,
-               start_tokens,
-               end_token,
-               initial_state,
                beam_width,
                output_layer=None,
                length_penalty_weight=0.0,
                coverage_penalty_weight=0.0,
-               reorder_tensor_arrays=True):
-    """Initialize the BeamSearchDecoder.
+               reorder_tensor_arrays=True,
+               **kwargs):
+    """Initialize the BeamSearchDecoderMixin.
 
     Args:
      cell: An `RNNCell` instance.
-      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
-        or the `params` argument for `embedding_lookup`.
-      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
-      end_token: `int32` scalar, the token that marks end of decoding.
-      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width: Python integer, the number of beams.
-      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
-        `tf.layers.Dense`. Optional layer to apply to the RNN output prior
-        to storing the result or sampling.
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
@@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder):
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.
+      **kwargs: Dict, other keyword arguments for parent class.
 
    Raises:
      TypeError: if `cell` is not an instance of `RNNCell`,
-        or `output_layer` is not an instance of `tf.layers.Layer`.
-      ValueError: If `start_tokens` is not a vector or
-        `end_token` is not a scalar.
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
    """
    rnn_cell_impl.assert_like_rnncell("cell", cell)  # pylint: disable=protected-access
    if (output_layer is not None and
-        not isinstance(output_layer, layers_base.Layer)):
+        not isinstance(output_layer, layers.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._output_layer = output_layer
    self._reorder_tensor_arrays = reorder_tensor_arrays
 
-    if callable(embedding):
-      self._embedding_fn = embedding
-    else:
-      self._embedding_fn = (
-          lambda ids: embedding_ops.embedding_lookup(embedding, ids))
-
-    self._start_tokens = ops.convert_to_tensor(
-        start_tokens, dtype=dtypes.int32, name="start_tokens")
-    if self._start_tokens.get_shape().ndims != 1:
-      raise ValueError("start_tokens must be a vector")
-    self._end_token = ops.convert_to_tensor(
-        end_token, dtype=dtypes.int32, name="end_token")
-    if self._end_token.get_shape().ndims != 0:
-      raise ValueError("end_token must be a scalar")
-
-    self._batch_size = array_ops.size(start_tokens)
+    self._start_tokens = None
+    self._end_token = None
+    self._batch_size = None
    self._beam_width = beam_width
    self._length_penalty_weight = length_penalty_weight
    self._coverage_penalty_weight = coverage_penalty_weight
-    self._initial_cell_state = nest.map_structure(
-        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
-    self._start_tokens = array_ops.tile(
-        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
-    self._start_inputs = self._embedding_fn(self._start_tokens)
-
-    self._finished = array_ops.one_hot(
-        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
-        depth=self._beam_width,
-        on_value=False,
-        off_value=True,
-        dtype=dtypes.bool)
+    super(BeamSearchDecoderMixin, self).__init__(**kwargs)
 
  @property
  def batch_size(self):
    return self._batch_size
 
  def _rnn_output_size(self):
+    """Get the output shape from the RNN layer."""
    size = self._cell.output_size
    if self._output_layer is None:
      return size
@@ -393,50 +332,6 @@ class BeamSearchDecoder(decoder.Decoder):
        predicted_ids=tensor_shape.TensorShape([self._beam_width]),
        parent_ids=tensor_shape.TensorShape([self._beam_width]))
 
-  @property
-  def output_dtype(self):
-    # Assume the dtype of the cell is the output_size structure
-    # containing the input_state's first component's dtype.
-    # Return that structure and int32 (the id)
-    dtype = nest.flatten(self._initial_cell_state)[0].dtype
-    return BeamSearchDecoderOutput(
-        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
-        predicted_ids=dtypes.int32,
-        parent_ids=dtypes.int32)
-
-  def initialize(self, name=None):
-    """Initialize the decoder.
-
-    Args:
-      name: Name scope for any created operations.
-
-    Returns:
-      `(finished, start_inputs, initial_state)`.
-    """
-    finished, start_inputs = self._finished, self._start_inputs
-
-    dtype = nest.flatten(self._initial_cell_state)[0].dtype
-    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
-        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
-        depth=self._beam_width,
-        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
-        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
-        dtype=dtype)
-    init_attention_probs = get_attention_probs(
-        self._initial_cell_state, self._coverage_penalty_weight)
-    if init_attention_probs is None:
-      init_attention_probs = ()
-
-    initial_state = BeamSearchDecoderState(
-        cell_state=self._initial_cell_state,
-        log_probs=log_probs,
-        finished=finished,
-        lengths=array_ops.zeros(
-            [self._batch_size, self._beam_width], dtype=dtypes.int64),
-        accumulated_attention_probs=init_attention_probs)
-
-    return (finished, start_inputs, initial_state)
-
  def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids.
 
@@ -562,7 +457,7 @@ class BeamSearchDecoder(decoder.Decoder):
    """
    if isinstance(t, tensor_array_ops.TensorArray):
      return t
-    _check_maybe(t)
+    _check_ndims(t)
    if t.shape.ndims >= 1:
      return self._split_batch_beams(t, s)
    else:
@@ -586,7 +481,7 @@ class BeamSearchDecoder(decoder.Decoder):
    """
    if isinstance(t, tensor_array_ops.TensorArray):
      return t
-    _check_maybe(t)
+    _check_ndims(t)
    if t.shape.ndims >= 2:
      return self._merge_batch_beams(t, s)
    else:
@@ -609,11 +504,18 @@
    if not isinstance(t, tensor_array_ops.TensorArray):
      return t
    # pylint: disable=protected-access
-    if (not t._infer_shape or not t._element_shape
-        or t._element_shape[0].ndims is None
-        or t._element_shape[0].ndims < 1):
+    # This is a bad hack due to the implementation detail of eager/graph TA.
+    # TODO(b/124374427): Update this to use public property of TensorArray.
+    if context.executing_eagerly():
+      element_shape = t._element_shape
+    else:
+      element_shape = t._element_shape[0]
+    if (not t._infer_shape
+        or not t._element_shape
+        or element_shape.ndims is None
+        or element_shape.ndims < 1):
      shape = (
-          t._element_shape[0] if t._infer_shape and t._element_shape
+          element_shape if t._infer_shape and t._element_shape
          else tensor_shape.TensorShape(None))
      tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
                      "sorting based on the beam search result. For a "
@@ -621,10 +523,10 @@ class BeamSearchDecoder(decoder.Decoder):
                      "defined and have at least a rank of 1, but saw shape: %s"
                      % (t.handle.name, shape))
      return t
-    shape = t._element_shape[0]
    # pylint: enable=protected-access
    if not _check_static_batch_beam_maybe(
-        shape, tensor_util.constant_value(self._batch_size), self._beam_width):
+        element_shape, tensor_util.constant_value(self._batch_size),
+        self._beam_width):
      return t
    t = t.stack()
    with ops.control_dependencies(
@@ -684,6 +586,359 @@ class BeamSearchDecoder(decoder.Decoder):
    return (beam_search_output, beam_search_state, next_inputs, finished)
 
 
+class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder):
+  # Note that the inheritance hierarchy is important here. The Mixin has to be
+  # the first parent class since we will use super().__init__(), and the Mixin
+  # will then properly invoke the __init__ method of the other parent class.
+  """BeamSearch sampling decoder.
+
+  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
+  `AttentionWrapper`, then you must ensure that:
+
+  - The encoder output has been tiled to `beam_width` via
+    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
+  - The `batch_size` argument passed to the `zero_state` method of this
+    wrapper is equal to `true_batch_size * beam_width`.
+  - The initial state created with `zero_state` above contains a
+    `cell_state` value containing properly tiled final state from the
+    encoder.
+
+  An example:
+
+  ```
+  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
+      encoder_outputs, multiplier=beam_width)
+  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
+      encoder_final_state, multiplier=beam_width)
+  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
+      sequence_length, multiplier=beam_width)
+  attention_mechanism = MyFavoriteAttentionMechanism(
+      num_units=attention_depth,
+      memory=tiled_inputs,
+      memory_sequence_length=tiled_sequence_length)
+  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
+  decoder_initial_state = attention_cell.zero_state(
+      dtype, batch_size=true_batch_size * beam_width)
+  decoder_initial_state = decoder_initial_state.clone(
+      cell_state=tiled_encoder_final_state)
+  ```
+
+  Meanwhile, with `AttentionWrapper`, use of the coverage penalty is suggested
+  when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
+  the decoder to cover all inputs.
+  """
+
+  def __init__(self,
+               cell,
+               embedding,
+               start_tokens,
+               end_token,
+               initial_state,
+               beam_width,
+               output_layer=None,
+               length_penalty_weight=0.0,
+               coverage_penalty_weight=0.0,
+               reorder_tensor_arrays=True):
+    """Initialize the BeamSearchDecoder.
+
+    Args:
+      cell: An `RNNCell` instance.
+      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
+        or the `params` argument for `embedding_lookup`.
+      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
+      end_token: `int32` scalar, the token that marks end of decoding.
+      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
+      beam_width: Python integer, the number of beams.
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
+      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+      coverage_penalty_weight: Float weight to penalize the coverage of source
+        sentence. Disabled with 0.0.
+      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
+        state will be reordered according to the beam search path. If the
+        `TensorArray` can be reordered, the stacked form will be returned.
+        Otherwise, the `TensorArray` will be returned as is. Set this flag to
+        `False` if the cell state contains `TensorArray`s that are not amenable
+        to reordering.
+
+    Raises:
+      TypeError: if `cell` is not an instance of `RNNCell`,
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
+      ValueError: If `start_tokens` is not a vector or
+        `end_token` is not a scalar.
+    """
+    super(BeamSearchDecoder, self).__init__(
+        cell,
+        beam_width,
+        output_layer=output_layer,
+        length_penalty_weight=length_penalty_weight,
+        coverage_penalty_weight=coverage_penalty_weight,
+        reorder_tensor_arrays=reorder_tensor_arrays)
+
+    if callable(embedding):
+      self._embedding_fn = embedding
+    else:
+      self._embedding_fn = (
+          lambda ids: embedding_ops.embedding_lookup(embedding, ids))
+
+    self._start_tokens = ops.convert_to_tensor(
+        start_tokens, dtype=dtypes.int32, name="start_tokens")
+    if self._start_tokens.get_shape().ndims != 1:
+      raise ValueError("start_tokens must be a vector")
+    self._end_token = ops.convert_to_tensor(
+        end_token, dtype=dtypes.int32, name="end_token")
+    if self._end_token.get_shape().ndims != 0:
+      raise ValueError("end_token must be a scalar")
+
+    self._batch_size = array_ops.size(start_tokens)
+    self._initial_cell_state = nest.map_structure(
+        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
+    self._start_tokens = array_ops.tile(
+        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
+    self._start_inputs = self._embedding_fn(self._start_tokens)
+
+    self._finished = array_ops.one_hot(
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width,
+        on_value=False,
+        off_value=True,
+        dtype=dtypes.bool)
+
+  def initialize(self, name=None):
+    """Initialize the decoder.
+
+    Args:
+      name: Name scope for any created operations.
+
+    Returns:
+      `(finished, start_inputs, initial_state)`.
+    """
+    finished, start_inputs = self._finished, self._start_inputs
+
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
+    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width,
+        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
+        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
+        dtype=dtype)
+    init_attention_probs = get_attention_probs(
+        self._initial_cell_state, self._coverage_penalty_weight)
+    if init_attention_probs is None:
+      init_attention_probs = ()
+
+    initial_state = BeamSearchDecoderState(
+        cell_state=self._initial_cell_state,
+        log_probs=log_probs,
+        finished=finished,
+        lengths=array_ops.zeros(
+            [self._batch_size, self._beam_width], dtype=dtypes.int64),
+        accumulated_attention_probs=init_attention_probs)
+
+    return (finished, start_inputs, initial_state)
+
+  @property
+  def output_dtype(self):
+    # Assume the dtype of the cell is the output_size structure
+    # containing the input_state's first component's dtype.
+    # Return that structure and int32 (the id)
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
+    return BeamSearchDecoderOutput(
+        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
+        predicted_ids=dtypes.int32,
+        parent_ids=dtypes.int32)
+
+
+class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder):
+  # Note that the inheritance hierarchy is important here. The Mixin has to be
+  # the first parent class since we will use super().__init__(), and the Mixin
+  # will then properly invoke the __init__ method of the other parent class.
+  """BeamSearch sampling decoder.
+
+  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
+  `AttentionWrapper`, then you must ensure that:
+
+  - The encoder output has been tiled to `beam_width` via
+    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
+  - The `batch_size` argument passed to the `zero_state` method of this
+    wrapper is equal to `true_batch_size * beam_width`.
+  - The initial state created with `zero_state` above contains a
+    `cell_state` value containing properly tiled final state from the
+    encoder.
+
+  An example:
+
+  ```
+  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
+      encoder_outputs, multiplier=beam_width)
+  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
+      encoder_final_state, multiplier=beam_width)
+  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
+      sequence_length, multiplier=beam_width)
+  attention_mechanism = MyFavoriteAttentionMechanism(
+      num_units=attention_depth,
+      memory=tiled_inputs,
+      memory_sequence_length=tiled_sequence_length)
+  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
+  decoder_initial_state = attention_cell.zero_state(
+      dtype, batch_size=true_batch_size * beam_width)
+  decoder_initial_state = decoder_initial_state.clone(
+      cell_state=tiled_encoder_final_state)
+  ```
+
+  Meanwhile, with `AttentionWrapper`, use of the coverage penalty is suggested
+  when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
+  the decoding to cover all inputs.
+  """
+
+  def __init__(self,
+               cell,
+               beam_width,
+               embedding_fn=None,
+               output_layer=None,
+               length_penalty_weight=0.0,
+               coverage_penalty_weight=0.0,
+               reorder_tensor_arrays=True,
+               **kwargs):
+    """Initialize the BeamSearchDecoderV2.
+
+    Args:
+      cell: An `RNNCell` instance.
+      beam_width: Python integer, the number of beams.
+      embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids).
+      output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e.,
+        `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
+        prior to storing the result or sampling.
+      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+      coverage_penalty_weight: Float weight to penalize the coverage of source
+        sentence. Disabled with 0.0.
+      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
+        state will be reordered according to the beam search path. If the
+        `TensorArray` can be reordered, the stacked form will be returned.
+        Otherwise, the `TensorArray` will be returned as is. Set this flag to
+        `False` if the cell state contains `TensorArray`s that are not amenable
+        to reordering.
+      **kwargs: Dict, other keyword arguments for initialization.
+
+    Raises:
+      TypeError: if `cell` is not an instance of `RNNCell`,
+        or `output_layer` is not an instance of `tf.keras.layers.Layer`.
+    """
+    super(BeamSearchDecoderV2, self).__init__(
+        cell,
+        beam_width,
+        output_layer=output_layer,
+        length_penalty_weight=length_penalty_weight,
+        coverage_penalty_weight=coverage_penalty_weight,
+        reorder_tensor_arrays=reorder_tensor_arrays,
+        **kwargs)
+
+    if embedding_fn is None or callable(embedding_fn):
+      self._embedding_fn = embedding_fn
+    else:
+      raise ValueError("embedding_fn is expected to be a callable, got %s" %
+                       type(embedding_fn))
+
+  def initialize(self,
+                 embedding,
+                 start_tokens,
+                 end_token,
+                 initial_state):
+    """Initialize the decoder.
+
+    Args:
+      embedding: A tensor from the embedding layer output, which is the
+        `params` argument for `embedding_lookup`.
+      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
+      end_token: `int32` scalar, the token that marks end of decoding.
+      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
+    Returns:
+      `(finished, start_inputs, initial_state)`.
+    Raises:
+      ValueError: If `start_tokens` is not a vector or `end_token` is not a
+        scalar.
+    """
+    if embedding is not None and self._embedding_fn is not None:
+      raise ValueError(
+          "embedding and embedding_fn cannot be provided at same time")
+    elif embedding is not None:
+      self._embedding_fn = (
+          lambda ids: embedding_ops.embedding_lookup(embedding, ids))
+
+    self._start_tokens = ops.convert_to_tensor(
+        start_tokens, dtype=dtypes.int32, name="start_tokens")
+    if self._start_tokens.get_shape().ndims != 1:
+      raise ValueError("start_tokens must be a vector")
+    self._end_token = ops.convert_to_tensor(
+        end_token, dtype=dtypes.int32, name="end_token")
+    if self._end_token.get_shape().ndims != 0:
+      raise ValueError("end_token must be a scalar")
+
+    self._batch_size = array_ops.size(start_tokens)
+    self._initial_cell_state = nest.map_structure(
+        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
+    self._start_tokens = array_ops.tile(
+        array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
+    self._start_inputs = self._embedding_fn(self._start_tokens)
+
+    self._finished = array_ops.one_hot(
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width,
+        on_value=False,
+        off_value=True,
+        dtype=dtypes.bool)
+
+    finished, start_inputs = self._finished, self._start_inputs
+
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
+    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width,
+        on_value=ops.convert_to_tensor(0.0, dtype=dtype),
+        off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
+        dtype=dtype)
+    init_attention_probs = get_attention_probs(
+        self._initial_cell_state, self._coverage_penalty_weight)
+    if init_attention_probs is None:
+      init_attention_probs = ()
+
+    initial_state = BeamSearchDecoderState(
+        cell_state=self._initial_cell_state,
+        log_probs=log_probs,
+        finished=finished,
+        lengths=array_ops.zeros(
+            [self._batch_size, self._beam_width], dtype=dtypes.int64),
+        accumulated_attention_probs=init_attention_probs)
+
+    return (finished, start_inputs, initial_state)
+
+  @property
+  def output_dtype(self):
+    # Assume the dtype of the cell is the output_size structure
+    # containing the input_state's first component's dtype.
+    # Return that structure and int32 (the id)
+    dtype = nest.flatten(self._initial_cell_state)[0].dtype
+    return BeamSearchDecoderOutput(
+        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
+        predicted_ids=dtypes.int32,
+        parent_ids=dtypes.int32)
+
+  def call(self, embedding, start_tokens, end_token, initial_state, **kwargs):
+    init_kwargs = kwargs
+    init_kwargs["start_tokens"] = start_tokens
+    init_kwargs["end_token"] = end_token
+    init_kwargs["initial_state"] = initial_state
+    return decoder.dynamic_decode(self,
+                                  output_time_major=self.output_time_major,
+                                  impute_finished=self.impute_finished,
+                                  maximum_iterations=self.maximum_iterations,
+                                  parallel_iterations=self.parallel_iterations,
+                                  swap_memory=self.swap_memory,
+                                  decoder_init_input=embedding,
+                                  decoder_init_kwargs=init_kwargs)
+
+
 def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                       beam_width, end_token, length_penalty_weight,
                       coverage_penalty_weight):
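The inheritance-order comments in the hunk above are the crux of the design: with `BeamSearchDecoderMixin` first in the MRO, its `super().__init__(**kwargs)` forwards the remaining keyword arguments (`output_time_major`, `maximum_iterations`, ...) on to `decoder.Decoder` or `decoder.BaseDecoder`. A minimal, self-contained sketch of this cooperative-`__init__` pattern (the class names here are illustrative, not the TF classes):

```
class Base(object):
  def __init__(self, output_time_major=False, **kwargs):
    self.output_time_major = output_time_major
    super(Base, self).__init__(**kwargs)


class Mixin(object):
  def __init__(self, cell, beam_width, **kwargs):
    self._cell = cell
    self._beam_width = beam_width
    # Mixin precedes Base in DecoderV2's MRO, so this forwards to Base.
    super(Mixin, self).__init__(**kwargs)


class DecoderV2(Mixin, Base):
  pass


d = DecoderV2(cell=object(), beam_width=3, output_time_major=True)
assert d.output_time_major and d._beam_width == 3
```

The ordering matters in the real classes because the Keras `Layer.__init__` behind `BaseDecoder` does not forward decoder-specific arguments such as `cell` or `beam_width`, so the mixin must consume them first.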
@@ -1068,7 +1323,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
  """
  if isinstance(gather_from, tensor_array_ops.TensorArray):
    return gather_from
-  _check_maybe(gather_from)
+  _check_ndims(gather_from)
  if gather_from.shape.ndims >= len(gather_shape):
    return _tensor_gather_helper(
        gather_indices=gather_indices,