Change rnn ops to check types via duck-typing, instead of a private attribute.

PiperOrigin-RevId: 156190878
This commit is contained in:
Adria Puigdomenech 2017-05-16 09:29:45 -07:00 committed by TensorFlower Gardener
parent f7040ddf1a
commit 3038bc9137
8 changed files with 33 additions and 25 deletions

View File

@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@ -32,6 +31,7 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access
_concat = rnn_cell_impl._concat
_like_rnncell = rnn_cell_impl._like_rnncell
_infer_state_dtype = rnn._infer_state_dtype
_reverse_seq = rnn._reverse_seq
_rnn_step = rnn._rnn_step
@ -99,7 +99,7 @@ def static_rnn(cell, inputs, initial_state=None, dtype=None,
(column size) cannot be inferred from inputs via shape inference.
"""
if not isinstance(cell, core_rnn_cell.RNNCell):
if not _like_rnncell(cell):
raise TypeError("cell must be an instance of RNNCell")
if not nest.is_sequence(inputs):
raise TypeError("inputs must be a sequence")
@ -319,9 +319,9 @@ def static_bidirectional_rnn(cell_fw, cell_bw, inputs,
ValueError: If inputs is None or an empty list.
"""
if not isinstance(cell_fw, core_rnn_cell.RNNCell):
if not _like_rnncell(cell_fw):
raise TypeError("cell_fw must be an instance of RNNCell")
if not isinstance(cell_bw, core_rnn_cell.RNNCell):
if not _like_rnncell(cell_bw):
raise TypeError("cell_bw must be an instance of RNNCell")
if not nest.is_sequence(inputs):
raise TypeError("inputs must be a sequence")

View File

@ -42,16 +42,21 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh
from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
# pylint: disable=protected-access
RNNCell = rnn_cell_impl._RNNCell # pylint: disable=invalid-name
_like_rnncell = rnn_cell_impl._like_rnncell
# pylint: enable=protected-access
_BIAS_VARIABLE_NAME = "biases"
_WEIGHTS_VARIABLE_NAME = "weights"
@ -424,7 +429,7 @@ class OutputProjectionWrapper(RNNCell):
ValueError: if output_size is not positive.
"""
super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
if not isinstance(cell, RNNCell):
if not _like_rnncell(cell):
raise TypeError("The parameter cell is not RNNCell.")
if output_size < 1:
raise ValueError("Parameter output_size must be > 0: %d." % output_size)
@ -480,7 +485,7 @@ class InputProjectionWrapper(RNNCell):
super(InputProjectionWrapper, self).__init__(_reuse=reuse)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
if not isinstance(cell, RNNCell):
if not _like_rnncell(cell):
raise TypeError("The parameter cell is not RNNCell.")
self._cell = cell
self._num_proj = num_proj
@ -556,7 +561,7 @@ class DropoutWrapper(RNNCell):
TypeError: if cell is not an RNNCell.
ValueError: if any of the keep_probs are not between 0 and 1.
"""
if not isinstance(cell, RNNCell):
if not _like_rnncell(cell):
raise TypeError("The parameter cell is not a RNNCell.")
with ops.name_scope("DropoutWrapperInit"):
def tensor_and_const_value(v):
@ -791,7 +796,7 @@ class EmbeddingWrapper(RNNCell):
ValueError: if embedding_classes is not positive.
"""
super(EmbeddingWrapper, self).__init__(_reuse=reuse)
if not isinstance(cell, RNNCell):
if not _like_rnncell(cell):
raise TypeError("The parameter cell is not RNNCell.")
if embedding_classes <= 0 or embedding_size <= 0:
raise ValueError("Both embedding_classes and embedding_size must be > 0: "

View File

@ -34,6 +34,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@ -1057,7 +1058,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
`state_is_tuple` is `False` or if attn_length is zero or less.
"""
super(AttentionCellWrapper, self).__init__(_reuse=reuse)
if not isinstance(cell, core_rnn_cell.RNNCell):
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError("The parameter cell is not RNNCell.")
if nest.is_sequence(cell.state_size) and not state_is_tuple:
raise ValueError("Cell returns tuple of states, but the flag "

View File

@ -543,7 +543,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
name: Name to use when creating ops.
"""
super(AttentionWrapper, self).__init__(name=name)
if not isinstance(cell, core_rnn_cell.RNNCell):
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError(
"cell must be an RNNCell, saw type: %s" % type(cell).__name__)
if not isinstance(attention_mechanism, AttentionMechanism):

View File

@ -21,13 +21,13 @@ from __future__ import print_function
import collections
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.util import nest
@ -60,7 +60,7 @@ class BasicDecoder(decoder.Decoder):
Raises:
TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
"""
if not isinstance(cell, core_rnn_cell.RNNCell):
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
if not isinstance(helper, helper_py.Helper):
raise TypeError("helper must be a Helper, received: %s" % type(helper))

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import collections
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
@ -33,6 +32,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
@ -143,7 +143,7 @@ class BeamSearchDecoder(decoder.Decoder):
ValueError: If `start_tokens` is not a vector or
`end_token` is not a scalar.
"""
if not isinstance(cell, core_rnn_cell.RNNCell):
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
if (output_layer is not None
and not isinstance(output_layer, layers_base.Layer)):

View File

@ -34,6 +34,7 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access
_concat = rnn_cell_impl._concat
_like_rnncell = rnn_cell_impl._like_rnncell
# pylint: enable=protected-access
@ -361,12 +362,10 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
"""
# pylint: disable=protected-access
if not isinstance(cell_fw, rnn_cell_impl._RNNCell):
if not _like_rnncell(cell_fw):
raise TypeError("cell_fw must be an instance of RNNCell")
if not isinstance(cell_bw, rnn_cell_impl._RNNCell):
if not _like_rnncell(cell_bw):
raise TypeError("cell_bw must be an instance of RNNCell")
# pylint: enable=protected-access
with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
@ -507,10 +506,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
ValueError: If inputs is None or an empty list.
"""
# pylint: disable=protected-access
if not isinstance(cell, rnn_cell_impl._RNNCell):
if not _like_rnncell(cell):
raise TypeError("cell must be an instance of RNNCell")
# pylint: enable=protected-access
# By default, time_major==False and inputs are batch-major: shaped
# [batch, time, depth]
@ -921,10 +918,8 @@ def raw_rnn(cell, loop_fn,
a `callable`.
"""
# pylint: disable=protected-access
if not isinstance(cell, rnn_cell_impl._RNNCell):
if not _like_rnncell(cell):
raise TypeError("cell must be an instance of RNNCell")
# pylint: enable=protected-access
if not callable(loop_fn):
raise TypeError("loop_fn must be a callable")

View File

@ -36,6 +36,13 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util import nest
def _like_rnncell(cell):
"""Checks that a given object is an RNNCell by using duck typing."""
conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
hasattr(cell, "zero_state"), callable(cell)]
return all(conditions)
def _concat(prefix, suffix, static=False):
"""Concat that enables int, Tensor, or TensorShape values.