Modify beam search decoder to use symbolic shape for vocab size if the static shape is not present.

PiperOrigin-RevId: 159852297
This commit is contained in:
A. Unique TensorFlower 2017-06-22 11:25:58 -07:00 committed by TensorFlower Gardener
parent e955cb3f66
commit 1757a2c117

View File

@ -486,7 +486,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs
# Calculate the continuation lengths by adding to all continuing beams.
vocab_size = logits.shape[-1].value
vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
lengths_to_add = array_ops.one_hot(
indices=array_ops.tile(
array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]),