[tf.contrib.seq2seq] Bugfixes to BeamSearchDecoder and GatherTree.
1. Begin the gather tree at the maximum sequence length across all beams (within the batch). 2. Take a second pass starting from t=0 and mask out any beam ids past the *first* beam occurrence of end_token. 3. Update the final sequence lengths to include the first <eos> token in the beam. 4. Update dynamic_decode to allow the BeamSearchDecoder to keep track of its own "finished" states, as the shuffling in the decoder confused the tracking mechanism in dynamic_decode. This fixes a bug where beam search decoding stops early. 5. Cap sequence length used in GatherTree to min(max_time, max_seq_len(b)) to avoid accessing memory outside the dimensions of input matrices. Bugs caught by @bdaskalov on github and Pavel Sountsov. Proper solution and analysis thanks to Rui Zhao. Thanks all! Fixes #13536. PiperOrigin-RevId: 172471462
This commit is contained in:
parent
a1ba9f3bf1
commit
18f89c81d2
@ -49,40 +49,46 @@ class GatherTreeOp : public OpKernel {
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
const Tensor& step_ids = ctx->input(0);
|
||||
const Tensor& parent_ids = ctx->input(1);
|
||||
const Tensor& sequence_length = ctx->input(2);
|
||||
const Tensor& max_sequence_lengths = ctx->input(2);
|
||||
const Tensor& end_token = ctx->input(3);
|
||||
const TensorShape& step_ids_shape = step_ids.shape();
|
||||
OP_REQUIRES(
|
||||
ctx, step_ids_shape.dims() == 3,
|
||||
errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
|
||||
step_ids_shape.DebugString()));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()),
|
||||
errors::InvalidArgument(
|
||||
"max_sequence_lengths must be a vector, saw shape: ",
|
||||
max_sequence_lengths.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()),
|
||||
errors::InvalidArgument("sequence_length must be a matrix, saw shape: ",
|
||||
sequence_length.shape().DebugString()));
|
||||
OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
|
||||
errors::InvalidArgument(
|
||||
"Inconsistent batch sizes: sequence_length.shape[0] (",
|
||||
sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
|
||||
step_ids_shape.dim_size(1), ")"));
|
||||
OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2),
|
||||
errors::InvalidArgument(
|
||||
"Inconsistent batch sizes: sequence_length.shape[1] (",
|
||||
sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (",
|
||||
step_ids_shape.dim_size(2), ")"));
|
||||
ctx, TensorShapeUtils::IsScalar(end_token.shape()),
|
||||
errors::InvalidArgument("end_token must be a scalar, saw shape: ",
|
||||
end_token.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, step_ids_shape == parent_ids.shape(),
|
||||
errors::InvalidArgument(
|
||||
"step_ids.shape must match parent_ids.shape. but shapes are: ",
|
||||
step_ids_shape.DebugString(), " and ",
|
||||
parent_ids.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0),
|
||||
errors::InvalidArgument("batch size dimensions step_ids.shape[1] and "
|
||||
"max_seqeuence_lengths.shape[0] must match. "
|
||||
"but shapes are: ",
|
||||
step_ids_shape.DebugString(), " and ",
|
||||
max_sequence_lengths.shape().DebugString()));
|
||||
Tensor* beams;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
|
||||
typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
|
||||
typename TTypes<T>::ConstMatrix seq_len_t = sequence_length.matrix<T>();
|
||||
typename TTypes<int32>::ConstVec max_seq_lens_t =
|
||||
max_sequence_lengths.vec<int32>();
|
||||
typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>();
|
||||
typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
|
||||
const T end_token_value = end_token_t();
|
||||
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
|
||||
seq_len_t, beams_t);
|
||||
max_seq_lens_t, end_token_value, beams_t);
|
||||
}
|
||||
};
|
||||
|
||||
@ -99,27 +105,29 @@ namespace functor {
|
||||
template <>
|
||||
struct GatherTree<CPUDevice, int32> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d,
|
||||
typename TTypes<int32, 3>::ConstTensor step_ids,
|
||||
typename TTypes<int32, 3>::ConstTensor parent_ids,
|
||||
typename TTypes<int32>::ConstMatrix sequence_length,
|
||||
typename TTypes<int32, 3>::Tensor beams) {
|
||||
const int64 max_time = parent_ids.dimension(0);
|
||||
const int64 batch_size = parent_ids.dimension(1);
|
||||
const int64 beam_width = parent_ids.dimension(2);
|
||||
TTypes<int32, 3>::ConstTensor step_ids,
|
||||
TTypes<int32, 3>::ConstTensor parent_ids,
|
||||
TTypes<int32>::ConstVec max_sequence_lengths,
|
||||
const int32 end_token, TTypes<int32, 3>::Tensor beams) {
|
||||
const int32 max_time = parent_ids.dimension(0);
|
||||
const int32 batch_size = parent_ids.dimension(1);
|
||||
const int32 beam_width = parent_ids.dimension(2);
|
||||
beams.setConstant(-1);
|
||||
|
||||
auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
|
||||
auto DoWork = [&, ctx, end_token](int start_batch_beam,
|
||||
int limit_batch_beam) {
|
||||
for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
|
||||
const int32 batch = i / beam_width;
|
||||
const int32 beam = i % beam_width;
|
||||
int32 seq_len_b = sequence_length(batch, beam);
|
||||
if (seq_len_b <= 0) {
|
||||
const int32 max_seq_len_b =
|
||||
Eigen::numext::mini(max_time, max_sequence_lengths(batch));
|
||||
if (max_seq_len_b <= 0) {
|
||||
continue;
|
||||
}
|
||||
beams(seq_len_b - 1, batch, beam) =
|
||||
step_ids(seq_len_b - 1, batch, beam);
|
||||
int32 parent = parent_ids(seq_len_b - 1, batch, beam);
|
||||
for (int32 level = seq_len_b - 2; level >= 0; --level) {
|
||||
beams(max_seq_len_b - 1, batch, beam) =
|
||||
step_ids(max_seq_len_b - 1, batch, beam);
|
||||
int32 parent = parent_ids(max_seq_len_b - 1, batch, beam);
|
||||
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
|
||||
if (parent < 0 || parent > beam_width) {
|
||||
ctx->SetStatus(
|
||||
errors::InvalidArgument("Saw invalid parent id ", parent,
|
||||
@ -130,6 +138,14 @@ struct GatherTree<CPUDevice, int32> {
|
||||
beams(level, batch, beam) = step_ids(level, batch, parent);
|
||||
parent = parent_ids(level, batch, parent);
|
||||
}
|
||||
bool finished = false;
|
||||
for (int32 time = 0; time < max_seq_len_b; ++time) {
|
||||
if (finished) {
|
||||
beams(time, batch, beam) = -1;
|
||||
} else if (beams(time, batch, beam) == end_token) {
|
||||
finished = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
// Guesstimate of cost; ~5 lookup/store/compare per inner beam
|
||||
@ -137,7 +153,7 @@ struct GatherTree<CPUDevice, int32> {
|
||||
const int64 batch_beam_cost =
|
||||
Eigen::TensorOpCost::DivCost<int32>() +
|
||||
6 * Eigen::TensorOpCost::AddCost<int32>() +
|
||||
max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
|
||||
2 * max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
|
||||
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||
Shard(worker_threads.num_threads, worker_threads.workers,
|
||||
batch_size * beam_width, batch_beam_cost, DoWork);
|
||||
@ -148,24 +164,26 @@ struct GatherTree<CPUDevice, int32> {
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void GatherTree<GPUDevice, T>::operator()( \
|
||||
OpKernelContext* ctx, const GPUDevice& d, \
|
||||
typename TTypes<T, 3>::ConstTensor step_ids, \
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids, \
|
||||
typename TTypes<T>::ConstMatrix sequence_length, \
|
||||
typename TTypes<T, 3>::Tensor beams); \
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void GatherTree<GPUDevice, T>::operator()( \
|
||||
OpKernelContext* ctx, const GPUDevice& d, \
|
||||
typename TTypes<T, 3>::ConstTensor step_ids, \
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids, \
|
||||
TTypes<int32>::ConstVec max_sequence_lengths, const T end_token, \
|
||||
typename TTypes<T, 3>::Tensor beams); \
|
||||
extern template struct GatherTree<GPUDevice, T>;
|
||||
|
||||
DECLARE_GPU_SPEC(int32);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // end namespace functor
|
||||
|
||||
#define REGISTER_GPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("GatherTree").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
|
||||
GatherTreeOp<GPUDevice, T>);
|
||||
#define REGISTER_GPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("GatherTree") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("end_token"), \
|
||||
GatherTreeOp<GPUDevice, T>);
|
||||
|
||||
REGISTER_GPU_KERNEL(int32);
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
|
@ -31,8 +31,8 @@ struct GatherTree {
|
||||
void operator()(OpKernelContext* ctx, const Device& d,
|
||||
typename TTypes<T, 3>::ConstTensor step_ids,
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids,
|
||||
typename TTypes<T>::ConstMatrix sequence_length,
|
||||
typename TTypes<T, 3>::Tensor beams);
|
||||
TTypes<int32>::ConstVec max_sequence_lengths,
|
||||
const T end_token, typename TTypes<T, 3>::Tensor beams);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
@ -29,20 +29,24 @@ template <typename T>
|
||||
__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
|
||||
const int32 beam_width, const T* step_ids,
|
||||
const T* parent_ids,
|
||||
const T* sequence_length, T* beams) {
|
||||
const int32* max_sequence_lengths,
|
||||
const T end_token, T* beams) {
|
||||
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
|
||||
const int32 batch = i / beam_width;
|
||||
const int32 beam = i % beam_width;
|
||||
|
||||
const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
|
||||
if (seq_len_b <= 0) continue;
|
||||
const int32 max_seq_len_b =
|
||||
Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch));
|
||||
if (max_seq_len_b <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
#define GET_IX(time_ix, beam_ix) \
|
||||
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
|
||||
const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
|
||||
const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
|
||||
beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
|
||||
int32 parent = ldg(parent_ids + initial_beam_ix);
|
||||
for (int32 level = seq_len_b - 2; level >= 0; --level) {
|
||||
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
|
||||
const int32 level_beam_ix = GET_IX(level, beam);
|
||||
const int32 level_parent_ix = GET_IX(level, parent);
|
||||
if (parent < 0 || parent > beam_width) {
|
||||
@ -53,6 +57,15 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
|
||||
parent = ldg(parent_ids + level_parent_ix);
|
||||
}
|
||||
}
|
||||
bool finished = false;
|
||||
for (int32 time = 0; time < max_seq_len_b; ++time) {
|
||||
const int32 level_beam_ix = GET_IX(time, beam);
|
||||
if (finished) {
|
||||
beams[level_beam_ix] = -1;
|
||||
} else if (beams[level_beam_ix] == end_token) {
|
||||
finished = true;
|
||||
}
|
||||
}
|
||||
#undef GET_IX
|
||||
}
|
||||
}
|
||||
@ -62,8 +75,8 @@ struct GatherTree<GPUDevice, T> {
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& d,
|
||||
typename TTypes<T, 3>::ConstTensor step_ids,
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids,
|
||||
typename TTypes<T>::ConstMatrix sequence_length,
|
||||
typename TTypes<T, 3>::Tensor beams) {
|
||||
TTypes<int32>::ConstVec max_sequence_length,
|
||||
const T end_token, typename TTypes<T, 3>::Tensor beams) {
|
||||
const int32 max_time = parent_ids.dimension(0);
|
||||
const int32 batch_size = parent_ids.dimension(1);
|
||||
const int32 beam_width = parent_ids.dimension(2);
|
||||
@ -75,7 +88,10 @@ struct GatherTree<GPUDevice, T> {
|
||||
GatherTreeOpKernel<T>
|
||||
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
batch_size, max_time, beam_width,
|
||||
step_ids.data(), parent_ids.data(), sequence_length.data(),
|
||||
step_ids.data(),
|
||||
parent_ids.data(),
|
||||
max_sequence_length.data(),
|
||||
end_token,
|
||||
beams.data());
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -25,27 +25,27 @@ using shape_inference::ShapeHandle;
|
||||
REGISTER_OP("GatherTree")
|
||||
.Input("step_ids: T")
|
||||
.Input("parent_ids: T")
|
||||
.Input("sequence_length: T")
|
||||
.Input("max_sequence_lengths: int32")
|
||||
.Input("end_token: T")
|
||||
.Output("beams: T")
|
||||
.Attr("T: {int32}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle step_ids, parent_ids, sequence_length;
|
||||
ShapeHandle step_ids, parent_ids, max_sequence_lengths, end_token;
|
||||
|
||||
// step_ids, parent_ids, and output are all shaped:
|
||||
// [max_time, batch_size, beam_width].
|
||||
// sequence_length is shaped [batch_size, beam_width].
|
||||
// max_sequence_length is shaped [batch_size] and end_token is a scalar.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length));
|
||||
|
||||
DimensionHandle batch_size = c->Dim(step_ids, 1);
|
||||
DimensionHandle beam_width = c->Dim(step_ids, 2);
|
||||
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max_sequence_lengths));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &end_token));
|
||||
TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
|
||||
DimensionHandle batch_size = c->Dim(step_ids, 1);
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width));
|
||||
c->Merge(batch_size, c->Dim(max_sequence_lengths, 0), &batch_size));
|
||||
ShapeHandle step_ids_prefix = c->Matrix(c->Dim(step_ids, 0), batch_size);
|
||||
TF_RETURN_IF_ERROR(c->MergePrefix(step_ids, step_ids_prefix, &step_ids,
|
||||
&step_ids_prefix));
|
||||
|
||||
c->set_output(0, step_ids);
|
||||
return tensorflow::Status::OK();
|
||||
@ -61,7 +61,8 @@ TODO(ebrevdo): fill in
|
||||
|
||||
step_ids: `[max_time, batch_size, beam_width]`.
|
||||
parent_ids: `[max_time, batch_size, beam_width]`.
|
||||
sequence_length: `[batch_size, beam_width]`.
|
||||
max_sequence_lengths: `[batch_size]`.
|
||||
end_token: `[]`.
|
||||
beams: `[max_time, batch_size, beam_width]`.
|
||||
)doc");
|
||||
|
||||
|
@ -54,15 +54,18 @@ class TestGatherTree(test.TestCase):
|
||||
[[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
|
||||
dtype=np.int32).transpose([1, 0, 2])
|
||||
|
||||
# sequence_lengths is shaped (batch_size = 2, beam_width = 3)
|
||||
sequence_lengths = [[3, 3, 3], [3, 3, 3]]
|
||||
# sequence_lengths is shaped (batch_size = 3)
|
||||
max_sequence_lengths = [3, 3]
|
||||
|
||||
expected_result = np.array(
|
||||
[[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
|
||||
[[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
|
||||
|
||||
res = beam_search_ops.gather_tree(
|
||||
predicted_ids, parent_ids, sequence_lengths)
|
||||
predicted_ids,
|
||||
parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=11)
|
||||
|
||||
with self.test_session() as sess:
|
||||
res_ = sess.run(res)
|
||||
|
@ -19,6 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
|
||||
@ -38,12 +40,14 @@ class GatherTreeTest(test.TestCase):
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
sequence_length = [[3, 3, 3]]
|
||||
max_sequence_lengths = [3]
|
||||
expected_result = _transpose_batch_time(
|
||||
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids, parent_ids=parent_ids,
|
||||
sequence_length=sequence_length)
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=10)
|
||||
with self.test_session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
|
||||
@ -54,11 +58,13 @@ class GatherTreeTest(test.TestCase):
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
sequence_length = [[3, 3, 3]]
|
||||
max_sequence_lengths = [3]
|
||||
with ops.device("/cpu:0"):
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids, parent_ids=parent_ids,
|
||||
sequence_length=sequence_length)
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=10)
|
||||
with self.test_session():
|
||||
with self.assertRaisesOpError(
|
||||
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
|
||||
@ -75,78 +81,58 @@ class GatherTreeTest(test.TestCase):
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
sequence_length = [[3, 3, 3]]
|
||||
max_sequence_lengths = [3]
|
||||
expected_result = _transpose_batch_time(
|
||||
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
with ops.device("/device:GPU:0"):
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids, parent_ids=parent_ids,
|
||||
sequence_length=sequence_length)
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=10)
|
||||
with self.test_session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
|
||||
def testGatherTreeBatch(self):
|
||||
# sequence_length is [batch_size, beam_width] = [4, 5]
|
||||
sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5]
|
||||
batch_size = 10
|
||||
beam_width = 15
|
||||
max_time = 8
|
||||
max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
|
||||
end_token = 5
|
||||
|
||||
with self.test_session(use_gpu=True):
|
||||
# (max_time = 4, batch_size = 4, beam_width = 5)
|
||||
step_ids = _transpose_batch_time(
|
||||
[[[3, 4, 0, 4, 0],
|
||||
[4, 2, 0, 3, 1],
|
||||
[1, 1, 3, 2, 2],
|
||||
[3, 1, 2, 3, 4]],
|
||||
[[3, 4, 0, 4, 0],
|
||||
[4, 2, 0, 3, 1],
|
||||
[1, 1, 3, 2, 2],
|
||||
[3, 1, 2, 3, 4]],
|
||||
[[1, 2, 3, 4, 2],
|
||||
[2, 1, 1, 3, 2],
|
||||
[3, 0, 1, 0, 0],
|
||||
[3, 4, 0, 2, 4]],
|
||||
[[0, 2, 2, 3, 1],
|
||||
[3, 2, 2, 2, 3],
|
||||
[3, 4, 3, 0, 3],
|
||||
[1, 2, 2, 2, 4]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[4, 2, 4, 3, 4],
|
||||
[3, 4, 0, 2, 0],
|
||||
[3, 1, 3, 2, 2],
|
||||
[0, 2, 1, 4, 2]],
|
||||
[[4, 2, 4, 3, 4],
|
||||
[3, 4, 0, 2, 0],
|
||||
[3, 1, 3, 2, 2],
|
||||
[0, 2, 1, 4, 2]],
|
||||
[[3, 0, 0, 4, 0],
|
||||
[1, 2, 4, 2, 2],
|
||||
[4, 4, 0, 3, 0],
|
||||
[2, 4, 4, 3, 0]],
|
||||
[[3, 1, 4, 1, 3],
|
||||
[3, 2, 4, 0, 4],
|
||||
[1, 0, 1, 4, 2],
|
||||
[0, 3, 2, 0, 1]]])
|
||||
expected_beams = _transpose_batch_time(
|
||||
[[[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1]],
|
||||
[[3, 4, 0, 4, 0],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1]],
|
||||
[[2, 3, 2, 3, 3],
|
||||
[2, 1, 1, 3, 2],
|
||||
[-1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1]],
|
||||
[[2, 3, 2, 1, 1],
|
||||
[2, 3, 2, 3, 2],
|
||||
[3, 4, 3, 0, 3],
|
||||
[-1, -1, -1, -1, -1]]])
|
||||
step_ids = np.random.randint(
|
||||
0, high=end_token + 1, size=(max_time, batch_size, beam_width))
|
||||
parent_ids = np.random.randint(
|
||||
0, high=beam_width - 1, size=(max_time, batch_size, beam_width))
|
||||
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids, parent_ids=parent_ids,
|
||||
sequence_length=sequence_length)
|
||||
self.assertAllEqual(expected_beams, beams.eval())
|
||||
step_ids=step_ids.astype(np.int32),
|
||||
parent_ids=parent_ids.astype(np.int32),
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
|
||||
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
|
||||
beams_value = beams.eval()
|
||||
for b in range(batch_size):
|
||||
# Past max_sequence_lengths[b], we emit all -1s.
|
||||
b_value = beams_value[max_sequence_lengths[b]:, b, :]
|
||||
self.assertAllClose(b_value, -1. * np.ones_like(b_value))
|
||||
for batch, beam in itertools.product(
|
||||
range(batch_size), range(beam_width)):
|
||||
v = np.squeeze(beams_value[:, batch, beam])
|
||||
if end_token in v:
|
||||
found = np.where(v == end_token)[0]
|
||||
# Should be up to 1 instance of end_token per beam.
|
||||
self.assertEqual(len(found), 1)
|
||||
found = found[0]
|
||||
# If an end_token is found, everything before it should be a
|
||||
# valid id and everything after it should be -1.
|
||||
if found > 0:
|
||||
self.assertAllEqual(
|
||||
v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool))
|
||||
self.assertAllClose(
|
||||
v[found + 1:], -1 * np.ones_like(v[found + 1:]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -253,6 +253,20 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
output_shape_with_unknown_batch)
|
||||
return nest.map_structure(lambda s: s[1:], layer_output_shape)
|
||||
|
||||
@property
|
||||
def tracks_own_finished(self):
|
||||
"""The BeamSearchDecoder shuffles its beams and their finished state.
|
||||
|
||||
For this reason, it conflicts with the `dynamic_decode` function's
|
||||
tracking of finished states. Setting this property to true avoids
|
||||
early stopping of decoding due to mismanagement of the finished state
|
||||
in `dynamic_decode`.
|
||||
|
||||
Returns:
|
||||
`True`.
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
# Return the cell output and the id
|
||||
@ -303,15 +317,23 @@ class BeamSearchDecoder(decoder.Decoder):
|
||||
output.
|
||||
sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
|
||||
The sequence lengths determined for each beam during decode.
|
||||
**NOTE** These are ignored; the updated sequence lengths are stored in
|
||||
`final_state.lengths`.
|
||||
|
||||
Returns:
|
||||
outputs: An instance of FinalBeamSearchDecoderOutput where the
|
||||
outputs: An instance of `FinalBeamSearchDecoderOutput` where the
|
||||
predicted_ids are the result of calling _gather_tree.
|
||||
final_state: The same input instance of BeamSearchDecoderState.
|
||||
final_state: The same input instance of `BeamSearchDecoderState`.
|
||||
"""
|
||||
del sequence_lengths
|
||||
# Get max_sequence_length across all beams for each batch.
|
||||
max_sequence_lengths = math_ops.to_int32(
|
||||
math_ops.reduce_max(final_state.lengths, axis=1))
|
||||
predicted_ids = beam_search_ops.gather_tree(
|
||||
outputs.predicted_ids, outputs.parent_ids,
|
||||
sequence_length=sequence_lengths)
|
||||
outputs.predicted_ids,
|
||||
outputs.parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=self._end_token)
|
||||
outputs = FinalBeamSearchDecoderOutput(
|
||||
beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
|
||||
return outputs, final_state
|
||||
@ -588,10 +610,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
|
||||
name="next_beam_finished")
|
||||
|
||||
# Calculate the length of the next predictions.
|
||||
# 1. Finished beams remain unchanged
|
||||
# 2. Beams that are now finished (EOS predicted) remain unchanged
|
||||
# 3. Beams that are not yet finished have their length increased by 1
|
||||
lengths_to_add = math_ops.to_int64(math_ops.logical_not(next_finished))
|
||||
# 1. Finished beams remain unchanged.
|
||||
# 2. Beams that are now finished (EOS predicted) have their length
|
||||
# increased by 1.
|
||||
# 3. Beams that are not yet finished have their length increased by 1.
|
||||
lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished))
|
||||
next_prediction_len = _tensor_gather_helper(
|
||||
gather_indices=next_beam_ids,
|
||||
gather_from=beam_state.lengths,
|
||||
|
@ -100,16 +100,36 @@ class Decoder(object):
|
||||
|
||||
Returns:
|
||||
`(outputs, next_state, next_inputs, finished)`: `outputs` is an object
|
||||
containing the decoder output, `next_state` is a (structure of) state tensors
|
||||
and TensorArrays, `next_inputs` is the tensor that should be used as input for
|
||||
the next step, `finished` is a boolean tensor telling whether the sequence
|
||||
is complete, for each sequence in the batch.
|
||||
containing the decoder output, `next_state` is a (structure of) state
|
||||
tensors and TensorArrays, `next_inputs` is the tensor that should be used
|
||||
as input for the next step, `finished` is a boolean tensor telling whether
|
||||
the sequence is complete, for each sequence in the batch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize(self, outputs, final_state, sequence_lengths):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def tracks_own_finished(self):
|
||||
"""Describes whether the Decoder keeps track of finished states.
|
||||
|
||||
Most decoders will emit a true/false `finished` value independently
|
||||
at each time step. In this case, the `dynamic_decode` function keeps track
|
||||
of which batch entries are already finished, and performs a logical OR to
|
||||
insert new batches to the finished set.
|
||||
|
||||
Some decoders, however, shuffle batches / beams between time steps and
|
||||
`dynamic_decode` will mix up the finished state across these entries because
|
||||
it does not track the reshuffle across time steps. In this case, it is
|
||||
up to the decoder to declare that it will keep track of its own finished
|
||||
state by setting this property to `True`.
|
||||
|
||||
Returns:
|
||||
Python bool.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def _create_zero_outputs(size, dtype, batch_size):
|
||||
"""Create a zero outputs Tensor structure."""
|
||||
@ -232,7 +252,10 @@ def dynamic_decode(decoder,
|
||||
"""
|
||||
(next_outputs, decoder_state, next_inputs,
|
||||
decoder_finished) = decoder.step(time, inputs, state)
|
||||
next_finished = math_ops.logical_or(decoder_finished, finished)
|
||||
if decoder.tracks_own_finished:
|
||||
next_finished = decoder_finished
|
||||
else:
|
||||
next_finished = math_ops.logical_or(decoder_finished, finished)
|
||||
if maximum_iterations is not None:
|
||||
next_finished = math_ops.logical_or(
|
||||
next_finished, time + 1 >= maximum_iterations)
|
||||
|
Loading…
Reference in New Issue
Block a user