[tf.contrib.seq2seq] Reserve -1s in GatherTree for error states.
GatherTree now emits end_token after the first decoded end_token in the path, instead of -1s at the end of each sequence. PiperOrigin-RevId: 172816652
This commit is contained in:
parent
f2250bfe85
commit
e0e4f69397
tensorflow/contrib/seq2seq
kernels
ops
python/kernel_tests
@ -112,7 +112,7 @@ struct GatherTree<CPUDevice, int32> {
|
||||
const int32 max_time = parent_ids.dimension(0);
|
||||
const int32 batch_size = parent_ids.dimension(1);
|
||||
const int32 beam_width = parent_ids.dimension(2);
|
||||
beams.setConstant(-1);
|
||||
beams.setConstant(end_token);
|
||||
|
||||
auto DoWork = [&, ctx, end_token](int start_batch_beam,
|
||||
int limit_batch_beam) {
|
||||
@ -138,10 +138,13 @@ struct GatherTree<CPUDevice, int32> {
|
||||
beams(level, batch, beam) = step_ids(level, batch, parent);
|
||||
parent = parent_ids(level, batch, parent);
|
||||
}
|
||||
// Not necessary when using a BeamSearchDecoder, but necessary
|
||||
// when a user feeds in possibly broken trajectory (i.e., non-eos
|
||||
// entries in a beam following eos entries).
|
||||
bool finished = false;
|
||||
for (int32 time = 0; time < max_seq_len_b; ++time) {
|
||||
if (finished) {
|
||||
beams(time, batch, beam) = -1;
|
||||
beams(time, batch, beam) = end_token;
|
||||
} else if (beams(time, batch, beam) == end_token) {
|
||||
finished = true;
|
||||
}
|
||||
|
@ -46,24 +46,31 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
|
||||
const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
|
||||
beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
|
||||
int32 parent = ldg(parent_ids + initial_beam_ix);
|
||||
bool found_bad = false;
|
||||
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
|
||||
const int32 level_beam_ix = GET_IX(level, beam);
|
||||
const int32 level_parent_ix = GET_IX(level, parent);
|
||||
if (parent < 0 || parent > beam_width) {
|
||||
beams[level_beam_ix] = -1;
|
||||
parent = -1;
|
||||
found_bad = true;
|
||||
} else {
|
||||
beams[level_beam_ix] = ldg(step_ids + level_parent_ix);
|
||||
parent = ldg(parent_ids + level_parent_ix);
|
||||
}
|
||||
}
|
||||
bool finished = false;
|
||||
for (int32 time = 0; time < max_seq_len_b; ++time) {
|
||||
const int32 level_beam_ix = GET_IX(time, beam);
|
||||
if (finished) {
|
||||
beams[level_beam_ix] = -1;
|
||||
} else if (beams[level_beam_ix] == end_token) {
|
||||
finished = true;
|
||||
// Not necessary when using a BeamSearchDecoder, but necessary
|
||||
// when a user feeds in possibly broken trajectory (i.e., non-eos
|
||||
// entries in a beam following eos entries).
|
||||
if (!found_bad) {
|
||||
bool finished = false;
|
||||
for (int32 time = 0; time < max_seq_len_b; ++time) {
|
||||
const int32 level_beam_ix = GET_IX(time, beam);
|
||||
if (finished) {
|
||||
beams[level_beam_ix] = end_token;
|
||||
} else if (beams[level_beam_ix] == end_token) {
|
||||
finished = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
#undef GET_IX
|
||||
@ -80,8 +87,8 @@ struct GatherTree<GPUDevice, T> {
|
||||
const int32 max_time = parent_ids.dimension(0);
|
||||
const int32 batch_size = parent_ids.dimension(1);
|
||||
const int32 beam_width = parent_ids.dimension(2);
|
||||
// First kernel launch to zero things out
|
||||
beams.device(d) = beams.constant(T(-1));
|
||||
// First kernel launch to "zero" things out
|
||||
beams.device(d) = beams.constant(end_token);
|
||||
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
|
||||
// clang-format off
|
||||
|
@ -53,11 +53,14 @@ REGISTER_OP("GatherTree")
|
||||
.Doc(R"doc(
|
||||
Calculates the full beams from the per-step ids and parent beam ids.
|
||||
|
||||
This op implements the following mathematical equations:
|
||||
On CPU, if an out of bound parent id is found, an error is returned.
|
||||
On GPU, if an out of bound parent id is found, a -1 is stored in the
|
||||
corresponding output value and the execution for that beam returns early.
|
||||
|
||||
```python
|
||||
TODO(ebrevdo): fill in
|
||||
```
|
||||
For a given beam, past the time step containing the first decoded `end_token`
|
||||
all values are filled in with `end_token`.
|
||||
|
||||
TODO(ebrevdo): fill in the remainder of this docstring.
|
||||
|
||||
step_ids: `[max_time, batch_size, beam_width]`.
|
||||
parent_ids: `[max_time, batch_size, beam_width]`.
|
||||
|
@ -36,24 +36,26 @@ class GatherTreeTest(test.TestCase):
|
||||
|
||||
def testGatherTreeOne(self):
|
||||
# (max_time = 4, batch_size = 1, beams = 3)
|
||||
end_token = 10
|
||||
step_ids = _transpose_batch_time(
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
max_sequence_lengths = [3]
|
||||
expected_result = _transpose_batch_time(
|
||||
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9],
|
||||
[10, 10, 10]]])
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=10)
|
||||
end_token=end_token)
|
||||
with self.test_session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
|
||||
def testBadParentValuesOnCPU(self):
|
||||
# (batch_size = 1, max_time = 4, beams = 3)
|
||||
# bad parent in beam 1 time 1
|
||||
end_token = 10
|
||||
step_ids = _transpose_batch_time(
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
@ -64,7 +66,7 @@ class GatherTreeTest(test.TestCase):
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=10)
|
||||
end_token=end_token)
|
||||
with self.test_session():
|
||||
with self.assertRaisesOpError(
|
||||
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
|
||||
@ -77,19 +79,20 @@ class GatherTreeTest(test.TestCase):
|
||||
return
|
||||
# (max_time = 4, batch_size = 1, beams = 3)
|
||||
# bad parent in beam 1 time 1; appears as a negative index at time 0
|
||||
end_token = 10
|
||||
step_ids = _transpose_batch_time(
|
||||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
parent_ids = _transpose_batch_time(
|
||||
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
max_sequence_lengths = [3]
|
||||
expected_result = _transpose_batch_time(
|
||||
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
|
||||
expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9],
|
||||
[10, 10, 10]]])
|
||||
with ops.device("/device:GPU:0"):
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=10)
|
||||
end_token=end_token)
|
||||
with self.test_session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
|
||||
@ -115,24 +118,24 @@ class GatherTreeTest(test.TestCase):
|
||||
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
|
||||
beams_value = beams.eval()
|
||||
for b in range(batch_size):
|
||||
# Past max_sequence_lengths[b], we emit all -1s.
|
||||
# Past max_sequence_lengths[b], we emit all end tokens.
|
||||
b_value = beams_value[max_sequence_lengths[b]:, b, :]
|
||||
self.assertAllClose(b_value, -1. * np.ones_like(b_value))
|
||||
self.assertAllClose(b_value, end_token * np.ones_like(b_value))
|
||||
for batch, beam in itertools.product(
|
||||
range(batch_size), range(beam_width)):
|
||||
v = np.squeeze(beams_value[:, batch, beam])
|
||||
if end_token in v:
|
||||
found_bad = np.where(v == -1)[0]
|
||||
self.assertEqual(0, len(found_bad))
|
||||
found = np.where(v == end_token)[0]
|
||||
# Should be up to 1 instance of end_token per beam.
|
||||
self.assertEqual(len(found), 1)
|
||||
found = found[0]
|
||||
found = found[0] # First occurrence of end_token.
|
||||
# If an end_token is found, everything before it should be a
|
||||
# valid id and everything after it should be -1.
|
||||
if found > 0:
|
||||
self.assertAllEqual(
|
||||
v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool))
|
||||
self.assertAllClose(
|
||||
v[found + 1:], -1 * np.ones_like(v[found + 1:]))
|
||||
self.assertAllClose(v[found + 1:],
|
||||
end_token * np.ones_like(v[found + 1:]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user