diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index 95273e2b33e..64973ccccdc 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -112,7 +112,7 @@ struct GatherTree { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); - beams.setConstant(-1); + beams.setConstant(end_token); auto DoWork = [&, ctx, end_token](int start_batch_beam, int limit_batch_beam) { @@ -138,10 +138,13 @@ struct GatherTree { beams(level, batch, beam) = step_ids(level, batch, parent); parent = parent_ids(level, batch, parent); } + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). bool finished = false; for (int32 time = 0; time < max_seq_len_b; ++time) { if (finished) { - beams(time, batch, beam) = -1; + beams(time, batch, beam) = end_token; } else if (beams(time, batch, beam) == end_token) { finished = true; } diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index e71efc48cec..bc28d492fe1 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -46,24 +46,31 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam); beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix); int32 parent = ldg(parent_ids + initial_beam_ix); + bool found_bad = false; for (int32 level = max_seq_len_b - 2; level >= 0; --level) { const int32 level_beam_ix = GET_IX(level, beam); const int32 level_parent_ix = GET_IX(level, parent); if (parent < 0 || parent > beam_width) { 
beams[level_beam_ix] = -1; parent = -1; + found_bad = true; } else { beams[level_beam_ix] = ldg(step_ids + level_parent_ix); parent = ldg(parent_ids + level_parent_ix); } } - bool finished = false; - for (int32 time = 0; time < max_seq_len_b; ++time) { - const int32 level_beam_ix = GET_IX(time, beam); - if (finished) { - beams[level_beam_ix] = -1; - } else if (beams[level_beam_ix] == end_token) { - finished = true; + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). + if (!found_bad) { + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + const int32 level_beam_ix = GET_IX(time, beam); + if (finished) { + beams[level_beam_ix] = end_token; + } else if (beams[level_beam_ix] == end_token) { + finished = true; + } } } #undef GET_IX @@ -80,8 +87,8 @@ struct GatherTree { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); - // First kernel launch to zero things out - beams.device(d) = beams.constant(T(-1)); + // First kernel launch to "zero" things out + beams.device(d) = beams.constant(end_token); CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); // clang-format off diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc index 231504bfbb3..71539b6f592 100644 --- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc @@ -53,11 +53,14 @@ REGISTER_OP("GatherTree") .Doc(R"doc( Calculates the full beams from the per-step ids and parent beam ids. -This op implements the following mathematical equations: +On CPU, if an out of bound parent id is found, an error is returned. 
+On GPU, if an out of bound parent id is found, a -1 is stored in the +corresponding output value and the execution for that beam returns early. -```python -TODO(ebrevdo): fill in -``` +For a given beam, past the time step containing the first decoded `end_token` +all values are filled in with `end_token`. + +TODO(ebrevdo): fill in the remainder of this docstring. step_ids: `[max_time, batch_size, beam_width]`. parent_ids: `[max_time, batch_size, beam_width]`. diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index f3013148720..277c5b6ef76 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -36,24 +36,26 @@ class GatherTreeTest(test.TestCase): def testGatherTreeOne(self): # (max_time = 4, batch_size = 1, beams = 3) + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) max_sequence_lengths = [3] - expected_result = _transpose_batch_time( - [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) beams = beam_search_ops.gather_tree( step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - end_token=10) + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) def testBadParentValuesOnCPU(self): # (batch_size = 1, max_time = 4, beams = 3) # bad parent in beam 1 time 1 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( @@ -64,7 +66,7 @@ class GatherTreeTest(test.TestCase): step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - 
end_token=10) + end_token=end_token) with self.test_session(): with self.assertRaisesOpError( r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): @@ -77,19 +79,20 @@ class GatherTreeTest(test.TestCase): return # (max_time = 4, batch_size = 1, beams = 3) # bad parent in beam 1 time 1; appears as a negative index at time 0 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) max_sequence_lengths = [3] - expected_result = _transpose_batch_time( - [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) with ops.device("/device:GPU:0"): beams = beam_search_ops.gather_tree( step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - end_token=10) + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) @@ -115,24 +118,24 @@ class GatherTreeTest(test.TestCase): self.assertEqual((max_time, batch_size, beam_width), beams.shape) beams_value = beams.eval() for b in range(batch_size): - # Past max_sequence_lengths[b], we emit all -1s. + # Past max_sequence_lengths[b], we emit all end tokens. b_value = beams_value[max_sequence_lengths[b]:, b, :] - self.assertAllClose(b_value, -1. * np.ones_like(b_value)) + self.assertAllClose(b_value, end_token * np.ones_like(b_value)) for batch, beam in itertools.product( range(batch_size), range(beam_width)): v = np.squeeze(beams_value[:, batch, beam]) if end_token in v: + found_bad = np.where(v == -1)[0] + self.assertEqual(0, len(found_bad)) found = np.where(v == end_token)[0] - # Should be up to 1 instance of end_token per beam. - self.assertEqual(len(found), 1) - found = found[0] + found = found[0] # First occurrence of end_token. 
# If an end_token is found, everything before it should be a - # valid id and everything after it should be -1. + # valid id and everything after it should be end_token. if found > 0: self.assertAllEqual( v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool)) - self.assertAllClose( - v[found + 1:], -1 * np.ones_like(v[found + 1:])) + self.assertAllClose(v[found + 1:], + end_token * np.ones_like(v[found + 1:])) if __name__ == "__main__":