diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn.py b/tensorflow/contrib/rnn/python/ops/core_rnn.py index bbfa6b88506..3ce075ce9c3 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -32,6 +31,7 @@ from tensorflow.python.util import nest # pylint: disable=protected-access _concat = rnn_cell_impl._concat +_like_rnncell = rnn_cell_impl._like_rnncell _infer_state_dtype = rnn._infer_state_dtype _reverse_seq = rnn._reverse_seq _rnn_step = rnn._rnn_step @@ -99,7 +99,7 @@ def static_rnn(cell, inputs, initial_state=None, dtype=None, (column size) cannot be inferred from inputs via shape inference. """ - if not isinstance(cell, core_rnn_cell.RNNCell): + if not _like_rnncell(cell): raise TypeError("cell must be an instance of RNNCell") if not nest.is_sequence(inputs): raise TypeError("inputs must be a sequence") @@ -319,9 +319,9 @@ def static_bidirectional_rnn(cell_fw, cell_bw, inputs, ValueError: If inputs is None or an empty list. """ - if not isinstance(cell_fw, core_rnn_cell.RNNCell): + if not _like_rnncell(cell_fw): raise TypeError("cell_fw must be an instance of RNNCell") - if not isinstance(cell_bw, core_rnn_cell.RNNCell): + if not _like_rnncell(cell_bw): raise TypeError("cell_bw must be an instance of RNNCell") if not nest.is_sequence(inputs): raise TypeError("inputs must be a sequence") diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index eba2c0d2acb..f3e57cd3ec7 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -42,16 +42,21 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops.math_ops import sigmoid from tensorflow.python.ops.math_ops import tanh -from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +# pylint: disable=protected-access +RNNCell = rnn_cell_impl._RNNCell # pylint: disable=invalid-name +_like_rnncell = rnn_cell_impl._like_rnncell +# pylint: enable=protected-access + _BIAS_VARIABLE_NAME = "biases" _WEIGHTS_VARIABLE_NAME = "weights" @@ -424,7 +429,7 @@ class OutputProjectionWrapper(RNNCell): ValueError: if output_size is not positive. """ super(OutputProjectionWrapper, self).__init__(_reuse=reuse) - if not isinstance(cell, RNNCell): + if not _like_rnncell(cell): raise TypeError("The parameter cell is not RNNCell.") if output_size < 1: raise ValueError("Parameter output_size must be > 0: %d." % output_size) @@ -480,7 +485,7 @@ class InputProjectionWrapper(RNNCell): super(InputProjectionWrapper, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) - if not isinstance(cell, RNNCell): + if not _like_rnncell(cell): raise TypeError("The parameter cell is not RNNCell.") self._cell = cell self._num_proj = num_proj @@ -556,7 +561,7 @@ class DropoutWrapper(RNNCell): TypeError: if cell is not an RNNCell. ValueError: if any of the keep_probs are not between 0 and 1. """ - if not isinstance(cell, RNNCell): + if not _like_rnncell(cell): raise TypeError("The parameter cell is not a RNNCell.") with ops.name_scope("DropoutWrapperInit"): def tensor_and_const_value(v): @@ -791,7 +796,7 @@ class EmbeddingWrapper(RNNCell): ValueError: if embedding_classes is not positive. """ super(EmbeddingWrapper, self).__init__(_reuse=reuse) - if not isinstance(cell, RNNCell): + if not _like_rnncell(cell): raise TypeError("The parameter cell is not RNNCell.") if embedding_classes <= 0 or embedding_size <= 0: raise ValueError("Both embedding_classes and embedding_size must be > 0: " diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 7a0f894404c..217c379c36f 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -1057,7 +1058,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell): `state_is_tuple` is `False` or if attn_length is zero or less. """ super(AttentionCellWrapper, self).__init__(_reuse=reuse) - if not isinstance(cell, core_rnn_cell.RNNCell): + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("The parameter cell is not RNNCell.") if nest.is_sequence(cell.state_size) and not state_is_tuple: raise ValueError("Cell returns tuple of states, but the flag " diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 686a85e4e73..fd76882d846 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -543,7 +543,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell): name: Name to use when creating ops. """ super(AttentionWrapper, self).__init__(name=name) - if not isinstance(cell, core_rnn_cell.RNNCell): + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError( "cell must be an RNNCell, saw type: %s" % type(cell).__name__) if not isinstance(attention_mechanism, AttentionMechanism): diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index 6231a1fdf90..8ae175b6b59 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -21,13 +21,13 @@ from __future__ import print_function import collections -from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.contrib.seq2seq.python.ops import helper as helper_py from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as layers_base +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.util import nest @@ -60,7 +60,7 @@ class BasicDecoder(decoder.Decoder): Raises: TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. """ - if not isinstance(cell, core_rnn_cell.RNNCell): + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) if not isinstance(helper, helper_py.Helper): raise TypeError("helper must be a Helper, received: %s" % type(helper)) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 63ce9dafc0d..eb494bda4b5 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections -from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import dtypes @@ -33,6 +32,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest @@ -143,7 +143,7 @@ class BeamSearchDecoder(decoder.Decoder): ValueError: If `start_tokens` is not a vector or `end_token` is not a scalar. """ - if not isinstance(cell, core_rnn_cell.RNNCell): + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 475c49091e9..2aa288e36ac 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -34,6 +34,7 @@ from tensorflow.python.util import nest # pylint: disable=protected-access _concat = rnn_cell_impl._concat +_like_rnncell = rnn_cell_impl._like_rnncell # pylint: enable=protected-access @@ -361,12 +362,10 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. """ - # pylint: disable=protected-access - if not isinstance(cell_fw, rnn_cell_impl._RNNCell): + if not _like_rnncell(cell_fw): raise TypeError("cell_fw must be an instance of RNNCell") - if not isinstance(cell_bw, rnn_cell_impl._RNNCell): + if not _like_rnncell(cell_bw): raise TypeError("cell_bw must be an instance of RNNCell") - # pylint: enable=protected-access with vs.variable_scope(scope or "bidirectional_rnn"): # Forward direction @@ -507,10 +506,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, ValueError: If inputs is None or an empty list. """ - # pylint: disable=protected-access - if not isinstance(cell, rnn_cell_impl._RNNCell): + if not _like_rnncell(cell): raise TypeError("cell must be an instance of RNNCell") - # pylint: enable=protected-access # By default, time_major==False and inputs are batch-major: shaped # [batch, time, depth] @@ -921,10 +918,8 @@ def raw_rnn(cell, loop_fn, a `callable`. """ - # pylint: disable=protected-access - if not isinstance(cell, rnn_cell_impl._RNNCell): + if not _like_rnncell(cell): raise TypeError("cell must be an instance of RNNCell") - # pylint: enable=protected-access if not callable(loop_fn): raise TypeError("loop_fn must be a callable") diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 10d23eb09f5..9c0fb1db23d 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -36,6 +36,13 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.util import nest +def _like_rnncell(cell): + """Checks that a given object is an RNNCell by using duck typing.""" + conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"), + hasattr(cell, "zero_state"), callable(cell)] + return all(conditions) + + def _concat(prefix, suffix, static=False): """Concat that enables int, Tensor, or TensorShape values.