Modify beam search decoder to use symbolic shape for vocab size if the static shape is not present.
PiperOrigin-RevId: 159852297
This commit is contained in:
parent
e955cb3f66
commit
1757a2c117
@ -486,7 +486,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
|
||||
total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs
|
||||
|
||||
# Calculate the continuation lengths by adding to all continuing beams.
|
||||
vocab_size = logits.shape[-1].value
|
||||
vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
|
||||
lengths_to_add = array_ops.one_hot(
|
||||
indices=array_ops.tile(
|
||||
array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]),
|
||||
|
Loading…
Reference in New Issue
Block a user