diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index cd65a54b837..9c6939bb46f 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -1137,7 +1137,14 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): - `attention_or_cell_output` depending on `output_attention`. - `next_state` is an instance of `AttentionWrapperState` containing the state calculated at this time step. + + Raises: + TypeError: If `state` is not an instance of `AttentionWrapperState`. """ + if not isinstance(state, AttentionWrapperState): + raise TypeError("Expected state to be instance of AttentionWrapperState. " + "Received type %s instead." % type(state)) + # Step 1: Calculate the true inputs to the cell based on the # previous attention value. cell_inputs = self._cell_input_fn(inputs, state.attention)