From a3d3ce160368efe4b0b3f7f8df23cc8423680364 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 18 Jul 2017 09:30:16 -0700 Subject: [PATCH] [tf contrib seq2seq] Provide informative error messages in AttentionWrapper.call. Related issues: #11077, #11540. PiperOrigin-RevId: 162362081 --- tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index cd65a54b837..9c6939bb46f 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -1137,7 +1137,14 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): - `attention_or_cell_output` depending on `output_attention`. - `next_state` is an instance of `AttentionWrapperState` containing the state calculated at this time step. + + Raises: + TypeError: If `state` is not an instance of `AttentionWrapperState`. """ + if not isinstance(state, AttentionWrapperState): + raise TypeError("Expected state to be instance of AttentionWrapperState. " + "Received type %s instead." % type(state)) + # Step 1: Calculate the true inputs to the cell based on the # previous attention value. cell_inputs = self._cell_input_fn(inputs, state.attention)