[tf contrib seq2seq] Provide informative error messages in AttentionWrapper.call.
Related issues: #11077, #11540. PiperOrigin-RevId: 162362081
This commit is contained in:
parent
235060d7e7
commit
a3d3ce1603
@ -1137,7 +1137,14 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
|
|||||||
- `attention_or_cell_output` depending on `output_attention`.
|
- `attention_or_cell_output` depending on `output_attention`.
|
||||||
- `next_state` is an instance of `AttentionWrapperState`
|
- `next_state` is an instance of `AttentionWrapperState`
|
||||||
containing the state calculated at this time step.
|
containing the state calculated at this time step.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `state` is not an instance of `AttentionWrapperState`.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(state, AttentionWrapperState):
|
||||||
|
raise TypeError("Expected state to be instance of AttentionWrapperState. "
|
||||||
|
"Received type %s instead." % type(state))
|
||||||
|
|
||||||
# Step 1: Calculate the true inputs to the cell based on the
|
# Step 1: Calculate the true inputs to the cell based on the
|
||||||
# previous attention value.
|
# previous attention value.
|
||||||
cell_inputs = self._cell_input_fn(inputs, state.attention)
|
cell_inputs = self._cell_input_fn(inputs, state.attention)
|
||||||
|
Loading…
Reference in New Issue
Block a user