[tf.contrib.seq2seq] Bugfixes to BeamSearchDecoder and GatherTree.

1. Begin the gather-tree backtrace at the maximum sequence length across all beams within each batch entry.
2. Take a second pass starting from t=0 and mask out any beam ids past the *first* occurrence of end_token in the beam.
3. Update the final sequence lengths to include the first end_token (<eos>) in the beam.
4. Update dynamic_decode to let the BeamSearchDecoder track its own "finished" states, since the beam shuffling in the decoder confused the tracking mechanism in dynamic_decode.  This fixes a bug where beam search decoding stopped early.
5. Cap the sequence length used in GatherTree to min(max_time, max_seq_len(b)) to avoid accessing memory outside the dimensions of the input matrices.  (A NumPy sketch of the combined behavior follows this list.)
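
For reference, here is a minimal NumPy sketch of the behavior items 1, 2, and 5 describe. The helper name gather_tree_ref is hypothetical and only for illustration; the real op is the C++/CUDA GatherTree kernel changed in the diff below.

import numpy as np

def gather_tree_ref(step_ids, parent_ids, max_sequence_lengths, end_token):
  # step_ids, parent_ids: int arrays shaped [max_time, batch_size, beam_width].
  # max_sequence_lengths: int array shaped [batch_size].
  max_time, batch_size, beam_width = step_ids.shape
  beams = np.full_like(step_ids, -1)
  for batch in range(batch_size):
    # Item 5: never trace back past the time dimension of the inputs.
    max_seq_len_b = min(max_time, max_sequence_lengths[batch])
    if max_seq_len_b <= 0:
      continue
    for beam in range(beam_width):
      # Item 1: start the backtrace at the per-batch maximum sequence length.
      beams[max_seq_len_b - 1, batch, beam] = step_ids[max_seq_len_b - 1, batch, beam]
      parent = parent_ids[max_seq_len_b - 1, batch, beam]
      for level in range(max_seq_len_b - 2, -1, -1):
        beams[level, batch, beam] = step_ids[level, batch, parent]
        parent = parent_ids[level, batch, parent]
      # Item 2: second pass from t=0, masking everything past the first end_token.
      finished = False
      for time in range(max_seq_len_b):
        if finished:
          beams[time, batch, beam] = -1
        elif beams[time, batch, beam] == end_token:
          finished = True
  return beams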

Bugs caught by @bdaskalov on GitHub and by Pavel Sountsov.  Proper solution and analysis thanks to Rui Zhao.  Thanks all!

Fixes #13536.

PiperOrigin-RevId: 172471462
Author: Eugene Brevdo, 2017-10-17 08:48:29 -07:00 (committed by TensorFlower Gardener)
Commit: 18f89c81d2 (parent: a1ba9f3bf1)
8 changed files with 217 additions and 147 deletions


@@ -49,40 +49,46 @@ class GatherTreeOp : public OpKernel {
const Device& device = ctx->eigen_device<Device>();
const Tensor& step_ids = ctx->input(0);
const Tensor& parent_ids = ctx->input(1);
const Tensor& sequence_length = ctx->input(2);
const Tensor& max_sequence_lengths = ctx->input(2);
const Tensor& end_token = ctx->input(3);
const TensorShape& step_ids_shape = step_ids.shape();
OP_REQUIRES(
ctx, step_ids_shape.dims() == 3,
errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
step_ids_shape.DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()),
errors::InvalidArgument(
"max_sequence_lengths must be a vector, saw shape: ",
max_sequence_lengths.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()),
errors::InvalidArgument("sequence_length must be a matrix, saw shape: ",
sequence_length.shape().DebugString()));
OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
errors::InvalidArgument(
"Inconsistent batch sizes: sequence_length.shape[0] (",
sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
step_ids_shape.dim_size(1), ")"));
OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2),
errors::InvalidArgument(
"Inconsistent batch sizes: sequence_length.shape[1] (",
sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (",
step_ids_shape.dim_size(2), ")"));
ctx, TensorShapeUtils::IsScalar(end_token.shape()),
errors::InvalidArgument("end_token must be a scalar, saw shape: ",
end_token.shape().DebugString()));
OP_REQUIRES(
ctx, step_ids_shape == parent_ids.shape(),
errors::InvalidArgument(
"step_ids.shape must match parent_ids.shape. but shapes are: ",
step_ids_shape.DebugString(), " and ",
parent_ids.shape().DebugString()));
OP_REQUIRES(
ctx,
step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0),
errors::InvalidArgument("batch size dimensions step_ids.shape[1] and "
"max_seqeuence_lengths.shape[0] must match. "
"but shapes are: ",
step_ids_shape.DebugString(), " and ",
max_sequence_lengths.shape().DebugString()));
Tensor* beams;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
typename TTypes<T>::ConstMatrix seq_len_t = sequence_length.matrix<T>();
typename TTypes<int32>::ConstVec max_seq_lens_t =
max_sequence_lengths.vec<int32>();
typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>();
typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
const T end_token_value = end_token_t();
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
seq_len_t, beams_t);
max_seq_lens_t, end_token_value, beams_t);
}
};
@@ -99,27 +105,29 @@ namespace functor {
template <>
struct GatherTree<CPUDevice, int32> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
typename TTypes<int32, 3>::ConstTensor step_ids,
typename TTypes<int32, 3>::ConstTensor parent_ids,
typename TTypes<int32>::ConstMatrix sequence_length,
typename TTypes<int32, 3>::Tensor beams) {
const int64 max_time = parent_ids.dimension(0);
const int64 batch_size = parent_ids.dimension(1);
const int64 beam_width = parent_ids.dimension(2);
TTypes<int32, 3>::ConstTensor step_ids,
TTypes<int32, 3>::ConstTensor parent_ids,
TTypes<int32>::ConstVec max_sequence_lengths,
const int32 end_token, TTypes<int32, 3>::Tensor beams) {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
beams.setConstant(-1);
auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
auto DoWork = [&, ctx, end_token](int start_batch_beam,
int limit_batch_beam) {
for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
int32 seq_len_b = sequence_length(batch, beam);
if (seq_len_b <= 0) {
const int32 max_seq_len_b =
Eigen::numext::mini(max_time, max_sequence_lengths(batch));
if (max_seq_len_b <= 0) {
continue;
}
beams(seq_len_b - 1, batch, beam) =
step_ids(seq_len_b - 1, batch, beam);
int32 parent = parent_ids(seq_len_b - 1, batch, beam);
for (int32 level = seq_len_b - 2; level >= 0; --level) {
beams(max_seq_len_b - 1, batch, beam) =
step_ids(max_seq_len_b - 1, batch, beam);
int32 parent = parent_ids(max_seq_len_b - 1, batch, beam);
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
if (parent < 0 || parent > beam_width) {
ctx->SetStatus(
errors::InvalidArgument("Saw invalid parent id ", parent,
@@ -130,6 +138,14 @@ struct GatherTree<CPUDevice, int32> {
beams(level, batch, beam) = step_ids(level, batch, parent);
parent = parent_ids(level, batch, parent);
}
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
if (finished) {
beams(time, batch, beam) = -1;
} else if (beams(time, batch, beam) == end_token) {
finished = true;
}
}
}
};
// Guesstimate of cost; ~5 lookup/store/compare per inner beam
@@ -137,7 +153,7 @@ struct GatherTree<CPUDevice, int32> {
const int64 batch_beam_cost =
Eigen::TensorOpCost::DivCost<int32>() +
6 * Eigen::TensorOpCost::AddCost<int32>() +
max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
2 * max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
batch_size * beam_width, batch_beam_cost, DoWork);
@@ -148,24 +164,26 @@ struct GatherTree<CPUDevice, int32> {
#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void GatherTree<GPUDevice, T>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T, 3>::ConstTensor step_ids, \
typename TTypes<T, 3>::ConstTensor parent_ids, \
typename TTypes<T>::ConstMatrix sequence_length, \
typename TTypes<T, 3>::Tensor beams); \
#define DECLARE_GPU_SPEC(T) \
template <> \
void GatherTree<GPUDevice, T>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T, 3>::ConstTensor step_ids, \
typename TTypes<T, 3>::ConstTensor parent_ids, \
TTypes<int32>::ConstVec max_sequence_lengths, const T end_token, \
typename TTypes<T, 3>::Tensor beams); \
extern template struct GatherTree<GPUDevice, T>;
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
} // end namespace functor
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("GatherTree").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
GatherTreeOp<GPUDevice, T>);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("GatherTree") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("end_token"), \
GatherTreeOp<GPUDevice, T>);
REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL


@@ -31,8 +31,8 @@ struct GatherTree {
void operator()(OpKernelContext* ctx, const Device& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstMatrix sequence_length,
typename TTypes<T, 3>::Tensor beams);
TTypes<int32>::ConstVec max_sequence_lengths,
const T end_token, typename TTypes<T, 3>::Tensor beams);
};
} // namespace functor


@@ -29,20 +29,24 @@ template <typename T>
__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
const int32 beam_width, const T* step_ids,
const T* parent_ids,
const T* sequence_length, T* beams) {
const int32* max_sequence_lengths,
const T end_token, T* beams) {
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
if (seq_len_b <= 0) continue;
const int32 max_seq_len_b =
Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch));
if (max_seq_len_b <= 0) {
continue;
}
#define GET_IX(time_ix, beam_ix) \
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
int32 parent = ldg(parent_ids + initial_beam_ix);
for (int32 level = seq_len_b - 2; level >= 0; --level) {
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
const int32 level_beam_ix = GET_IX(level, beam);
const int32 level_parent_ix = GET_IX(level, parent);
if (parent < 0 || parent > beam_width) {
@@ -53,6 +57,15 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
parent = ldg(parent_ids + level_parent_ix);
}
}
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
const int32 level_beam_ix = GET_IX(time, beam);
if (finished) {
beams[level_beam_ix] = -1;
} else if (beams[level_beam_ix] == end_token) {
finished = true;
}
}
#undef GET_IX
}
}
@@ -62,8 +75,8 @@ struct GatherTree<GPUDevice, T> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstMatrix sequence_length,
typename TTypes<T, 3>::Tensor beams) {
TTypes<int32>::ConstVec max_sequence_length,
const T end_token, typename TTypes<T, 3>::Tensor beams) {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
@@ -75,7 +88,10 @@ struct GatherTree<GPUDevice, T> {
GatherTreeOpKernel<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
batch_size, max_time, beam_width,
step_ids.data(), parent_ids.data(), sequence_length.data(),
step_ids.data(),
parent_ids.data(),
max_sequence_length.data(),
end_token,
beams.data());
// clang-format on
}


@@ -25,27 +25,27 @@ using shape_inference::ShapeHandle;
REGISTER_OP("GatherTree")
.Input("step_ids: T")
.Input("parent_ids: T")
.Input("sequence_length: T")
.Input("max_sequence_lengths: int32")
.Input("end_token: T")
.Output("beams: T")
.Attr("T: {int32}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle step_ids, parent_ids, sequence_length;
ShapeHandle step_ids, parent_ids, max_sequence_lengths, end_token;
// step_ids, parent_ids, and output are all shaped:
// [max_time, batch_size, beam_width].
// sequence_length is shaped [batch_size, beam_width].
// max_sequence_length is shaped [batch_size] and end_token is a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length));
DimensionHandle batch_size = c->Dim(step_ids, 1);
DimensionHandle beam_width = c->Dim(step_ids, 2);
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max_sequence_lengths));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &end_token));
TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
DimensionHandle batch_size = c->Dim(step_ids, 1);
TF_RETURN_IF_ERROR(
c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
TF_RETURN_IF_ERROR(
c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width));
c->Merge(batch_size, c->Dim(max_sequence_lengths, 0), &batch_size));
ShapeHandle step_ids_prefix = c->Matrix(c->Dim(step_ids, 0), batch_size);
TF_RETURN_IF_ERROR(c->MergePrefix(step_ids, step_ids_prefix, &step_ids,
&step_ids_prefix));
c->set_output(0, step_ids);
return tensorflow::Status::OK();
@@ -61,7 +61,8 @@ TODO(ebrevdo): fill in
step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
sequence_length: `[batch_size, beam_width]`.
max_sequence_lengths: `[batch_size]`.
end_token: `[]`.
beams: `[max_time, batch_size, beam_width]`.
)doc");


@@ -54,15 +54,18 @@ class TestGatherTree(test.TestCase):
[[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
dtype=np.int32).transpose([1, 0, 2])
# sequence_lengths is shaped (batch_size = 2, beam_width = 3)
sequence_lengths = [[3, 3, 3], [3, 3, 3]]
# max_sequence_lengths is shaped (batch_size = 2)
max_sequence_lengths = [3, 3]
expected_result = np.array(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
[[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
res = beam_search_ops.gather_tree(
predicted_ids, parent_ids, sequence_lengths)
predicted_ids,
parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=11)
with self.test_session() as sess:
res_ = sess.run(res)


@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import itertools
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
@@ -38,12 +40,14 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [[3, 3, 3]]
max_sequence_lengths = [3]
expected_result = _transpose_batch_time(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length)
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
@@ -54,11 +58,13 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [[3, 3, 3]]
max_sequence_lengths = [3]
with ops.device("/cpu:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length)
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
with self.test_session():
with self.assertRaisesOpError(
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
@@ -75,78 +81,58 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [[3, 3, 3]]
max_sequence_lengths = [3]
expected_result = _transpose_batch_time(
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
with ops.device("/device:GPU:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length)
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
def testGatherTreeBatch(self):
# sequence_length is [batch_size, beam_width] = [4, 5]
sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5]
batch_size = 10
beam_width = 15
max_time = 8
max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
end_token = 5
with self.test_session(use_gpu=True):
# (max_time = 4, batch_size = 4, beam_width = 5)
step_ids = _transpose_batch_time(
[[[3, 4, 0, 4, 0],
[4, 2, 0, 3, 1],
[1, 1, 3, 2, 2],
[3, 1, 2, 3, 4]],
[[3, 4, 0, 4, 0],
[4, 2, 0, 3, 1],
[1, 1, 3, 2, 2],
[3, 1, 2, 3, 4]],
[[1, 2, 3, 4, 2],
[2, 1, 1, 3, 2],
[3, 0, 1, 0, 0],
[3, 4, 0, 2, 4]],
[[0, 2, 2, 3, 1],
[3, 2, 2, 2, 3],
[3, 4, 3, 0, 3],
[1, 2, 2, 2, 4]]])
parent_ids = _transpose_batch_time(
[[[4, 2, 4, 3, 4],
[3, 4, 0, 2, 0],
[3, 1, 3, 2, 2],
[0, 2, 1, 4, 2]],
[[4, 2, 4, 3, 4],
[3, 4, 0, 2, 0],
[3, 1, 3, 2, 2],
[0, 2, 1, 4, 2]],
[[3, 0, 0, 4, 0],
[1, 2, 4, 2, 2],
[4, 4, 0, 3, 0],
[2, 4, 4, 3, 0]],
[[3, 1, 4, 1, 3],
[3, 2, 4, 0, 4],
[1, 0, 1, 4, 2],
[0, 3, 2, 0, 1]]])
expected_beams = _transpose_batch_time(
[[[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1]],
[[3, 4, 0, 4, 0],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1]],
[[2, 3, 2, 3, 3],
[2, 1, 1, 3, 2],
[-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1]],
[[2, 3, 2, 1, 1],
[2, 3, 2, 3, 2],
[3, 4, 3, 0, 3],
[-1, -1, -1, -1, -1]]])
step_ids = np.random.randint(
0, high=end_token + 1, size=(max_time, batch_size, beam_width))
parent_ids = np.random.randint(
0, high=beam_width - 1, size=(max_time, batch_size, beam_width))
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
sequence_length=sequence_length)
self.assertAllEqual(expected_beams, beams.eval())
step_ids=step_ids.astype(np.int32),
parent_ids=parent_ids.astype(np.int32),
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
beams_value = beams.eval()
for b in range(batch_size):
# Past max_sequence_lengths[b], we emit all -1s.
b_value = beams_value[max_sequence_lengths[b]:, b, :]
self.assertAllClose(b_value, -1. * np.ones_like(b_value))
for batch, beam in itertools.product(
range(batch_size), range(beam_width)):
v = np.squeeze(beams_value[:, batch, beam])
if end_token in v:
found = np.where(v == end_token)[0]
# Should be up to 1 instance of end_token per beam.
self.assertEqual(len(found), 1)
found = found[0]
# If an end_token is found, everything before it should be a
# valid id and everything after it should be -1.
if found > 0:
self.assertAllEqual(
v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool))
self.assertAllClose(
v[found + 1:], -1 * np.ones_like(v[found + 1:]))
if __name__ == "__main__":


@@ -253,6 +253,20 @@ class BeamSearchDecoder(decoder.Decoder):
output_shape_with_unknown_batch)
return nest.map_structure(lambda s: s[1:], layer_output_shape)
@property
def tracks_own_finished(self):
"""The BeamSearchDecoder shuffles its beams and their finished state.
For this reason, it conflicts with the `dynamic_decode` function's
tracking of finished states. Setting this property to true avoids
early stopping of decoding due to mismanagement of the finished state
in `dynamic_decode`.
Returns:
`True`.
"""
return True
@property
def output_size(self):
# Return the cell output and the id
@@ -303,15 +317,23 @@
output.
sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
The sequence lengths determined for each beam during decode.
**NOTE** These are ignored; the updated sequence lengths are stored in
`final_state.lengths`.
Returns:
outputs: An instance of FinalBeamSearchDecoderOutput where the
outputs: An instance of `FinalBeamSearchDecoderOutput` where the
predicted_ids are the result of calling _gather_tree.
final_state: The same input instance of BeamSearchDecoderState.
final_state: The same input instance of `BeamSearchDecoderState`.
"""
del sequence_lengths
# Get max_sequence_length across all beams for each batch.
max_sequence_lengths = math_ops.to_int32(
math_ops.reduce_max(final_state.lengths, axis=1))
predicted_ids = beam_search_ops.gather_tree(
outputs.predicted_ids, outputs.parent_ids,
sequence_length=sequence_lengths)
outputs.predicted_ids,
outputs.parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=self._end_token)
outputs = FinalBeamSearchDecoderOutput(
beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
return outputs, final_state
@@ -588,10 +610,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
name="next_beam_finished")
# Calculate the length of the next predictions.
# 1. Finished beams remain unchanged
# 2. Beams that are now finished (EOS predicted) remain unchanged
# 3. Beams that are not yet finished have their length increased by 1
lengths_to_add = math_ops.to_int64(math_ops.logical_not(next_finished))
# 1. Finished beams remain unchanged.
# 2. Beams that are now finished (EOS predicted) have their length
# increased by 1.
# 3. Beams that are not yet finished have their length increased by 1.
lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished))
next_prediction_len = _tensor_gather_helper(
gather_indices=next_beam_ids,
gather_from=beam_state.lengths,


@@ -100,16 +100,36 @@ class Decoder(object):
Returns:
`(outputs, next_state, next_inputs, finished)`: `outputs` is an object
containing the decoder output, `next_state` is a (structure of) state tensors
and TensorArrays, `next_inputs` is the tensor that should be used as input for
the next step, `finished` is a boolean tensor telling whether the sequence
is complete, for each sequence in the batch.
containing the decoder output, `next_state` is a (structure of) state
tensors and TensorArrays, `next_inputs` is the tensor that should be used
as input for the next step, `finished` is a boolean tensor telling whether
the sequence is complete, for each sequence in the batch.
"""
raise NotImplementedError
def finalize(self, outputs, final_state, sequence_lengths):
raise NotImplementedError
@property
def tracks_own_finished(self):
"""Describes whether the Decoder keeps track of finished states.
Most decoders will emit a true/false `finished` value independently
at each time step. In this case, the `dynamic_decode` function keeps track
of which batch entries are already finished, and performs a logical OR to
add newly finished batch entries to the finished set.
Some decoders, however, shuffle batches / beams between time steps and
`dynamic_decode` will mix up the finished state across these entries because
it does not track the reshuffle across time steps. In this case, it is
up to the decoder to declare that it will keep track of its own finished
state by setting this property to `True`.
Returns:
Python bool.
"""
return False
def _create_zero_outputs(size, dtype, batch_size):
"""Create a zero outputs Tensor structure."""
@@ -232,7 +252,10 @@ def dynamic_decode(decoder,
"""
(next_outputs, decoder_state, next_inputs,
decoder_finished) = decoder.step(time, inputs, state)
next_finished = math_ops.logical_or(decoder_finished, finished)
if decoder.tracks_own_finished:
next_finished = decoder_finished
else:
next_finished = math_ops.logical_or(decoder_finished, finished)
if maximum_iterations is not None:
next_finished = math_ops.logical_or(
next_finished, time + 1 >= maximum_iterations)
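
For illustration, a minimal sketch that isolates the finished-state update dynamic_decode now performs. The function name next_finished_flags is hypothetical and exists only for this example; the real logic lives inline in the while-loop body above.

import tensorflow as tf

def next_finished_flags(decoder_finished, finished, tracks_own_finished,
                        time, maximum_iterations=None):
  # Decoders that shuffle beams between steps (e.g. BeamSearchDecoder) keep
  # their own finished bookkeeping; all other decoders are OR-ed with the
  # running finished state tracked by dynamic_decode.
  if tracks_own_finished:
    next_finished = decoder_finished
  else:
    next_finished = tf.logical_or(decoder_finished, finished)
  if maximum_iterations is not None:
    next_finished = tf.logical_or(
        next_finished, time + 1 >= maximum_iterations)
  return next_finished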