Update rnn_cell.DeviceWrapper to be 2.0 compatible.

This was the last symbol under tf.nn.rnn_cell, which is now deleted.

PiperOrigin-RevId: 236138948
This commit is contained in:
Scott Zhu 2019-02-28 09:41:49 -08:00 committed by TensorFlower Gardener
parent 8b91877f66
commit 42a1de008f
12 changed files with 117 additions and 128 deletions

View File

@ -2861,21 +2861,20 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
# States are left untouched # States are left untouched
self.assertAllClose(res_m_new, res_m_new_res) self.assertAllClose(res_m_new, res_m_new_res)
@test_util.run_v1_only("b/124229375") @parameterized.parameters(
def testDeviceWrapper(self): [rnn_cell_impl.DeviceWrapper, rnn_cell_impl.DeviceWrapperV2])
with variable_scope.variable_scope( def testDeviceWrapper(self, wrapper_type):
"root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3])
x = array_ops.zeros([1, 3]) m = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3]) cell = rnn_cell_impl.GRUCell(3)
wrapped = rnn_cell_impl.GRUCell(3) wrapped_cell = wrapper_type(cell, "/cpu:0")
cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") (name, dep), = wrapped_cell._checkpoint_dependencies
(name, dep), = cell._checkpoint_dependencies wrapped_cell.get_config() # Should not throw an error
cell.get_config() # Should not throw an error self.assertIs(dep, cell)
self.assertIs(dep, wrapped) self.assertEqual("cell", name)
self.assertEqual("cell", name)
outputs, _ = cell(x, m) outputs, _ = wrapped_cell(x, m)
self.assertTrue("cpu:14159" in outputs.device.lower()) self.assertIn("cpu:0", outputs.device.lower())
def _retrieve_cpu_gpu_stats(self, run_metadata): def _retrieve_cpu_gpu_stats(self, run_metadata):
cpu_stats = None cpu_stats = None
@ -2975,7 +2974,7 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
def testWrapperKerasStyle(self, wrapper, wrapper_v2): def testWrapperKerasStyle(self, wrapper, wrapper_v2):
"""Tests if wrapper cell is instantiated in keras style scope.""" """Tests if wrapper cell is instantiated in keras style scope."""
wrapped_cell_v2 = wrapper_v2(rnn_cell_impl.BasicRNNCell(1)) wrapped_cell_v2 = wrapper_v2(rnn_cell_impl.BasicRNNCell(1))
self.assertTrue(wrapped_cell_v2._keras_style) self.assertIsNone(getattr(wrapped_cell_v2, "_keras_style", None))
wrapped_cell = wrapper(rnn_cell_impl.BasicRNNCell(1)) wrapped_cell = wrapper(rnn_cell_impl.BasicRNNCell(1))
self.assertFalse(wrapped_cell._keras_style) self.assertFalse(wrapped_cell._keras_style)
@ -3014,22 +3013,21 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testWrapperWeights(self, wrapper): def testWrapperWeights(self, wrapper):
"""Tests that wrapper weights contain wrapped cells weights.""" """Tests that wrapper weights contain wrapped cells weights."""
base_cell = keras_layers.SimpleRNNCell(1, name="basic_rnn_cell")
with base_layer.keras_style_scope():
base_cell = rnn_cell_impl.BasicRNNCell(1, name="basic_rnn_cell")
rnn_cell = wrapper(base_cell) rnn_cell = wrapper(base_cell)
rnn_layer = keras_layers.RNN(rnn_cell) rnn_layer = keras_layers.RNN(rnn_cell)
inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32) inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32)
rnn_layer(inputs) rnn_layer(inputs)
expected_weights = ["rnn/" + var for var in ("kernel:0", "bias:0")] expected_weights = ["rnn/" + var for var in
self.assertEqual(len(rnn_cell.weights), 2) ("kernel:0", "recurrent_kernel:0", "bias:0")]
self.assertEqual(len(rnn_cell.weights), 3)
self.assertCountEqual([v.name for v in rnn_cell.weights], expected_weights) self.assertCountEqual([v.name for v in rnn_cell.weights], expected_weights)
self.assertCountEqual([v.name for v in rnn_cell.trainable_variables], self.assertCountEqual([v.name for v in rnn_cell.trainable_variables],
expected_weights) expected_weights)
self.assertCountEqual([v.name for v in rnn_cell.non_trainable_variables], self.assertCountEqual([v.name for v in rnn_cell.non_trainable_variables],
[]) [])
self.assertCountEqual([v.name for v in rnn_cell._cell.weights], self.assertCountEqual([v.name for v in rnn_cell.cell.weights],
expected_weights) expected_weights)
@parameterized.parameters( @parameterized.parameters(

View File

@ -1097,9 +1097,9 @@ class _RNNCellWrapperV1(RNNCell):
def __init__(self, cell): def __init__(self, cell):
super(_RNNCellWrapperV1, self).__init__() super(_RNNCellWrapperV1, self).__init__()
self._cell = cell self.cell = cell
if isinstance(cell, trackable.Trackable): if isinstance(cell, trackable.Trackable):
self._track_trackable(self._cell, name="cell") self._track_trackable(self.cell, name="cell")
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Calls the wrapped cell and performs the wrapping logic. """Calls the wrapped cell and performs the wrapping logic.
@ -1143,29 +1143,19 @@ class _RNNCellWrapperV1(RNNCell):
- New state: A tensor or tuple of tensors with new wrapped cell's state. - New state: A tensor or tuple of tensors with new wrapped cell's state.
""" """
return self._call_wrapped_cell( return self._call_wrapped_cell(
inputs, state, cell_call_fn=self._cell.__call__, scope=scope) inputs, state, cell_call_fn=self.cell.__call__, scope=scope)
class _RNNCellWrapperV2(keras_layer.AbstractRNNCell, _RNNCellWrapperV1): class _RNNCellWrapperV2(keras_layer.AbstractRNNCell):
"""Base class for cells wrappers V2 compatibility. """Base class for cells wrappers V2 compatibility.
This class along with `_RNNCellWrapperV1` allows to define cells wrappers that This class along with `_RNNCellWrapperV1` allows to define cells wrappers that
are compatible with V1 and V2, and defines helper methods for this purpose. are compatible with V1 and V2, and defines helper methods for this purpose.
""" """
def __init__(self, *args, **kwargs): def __init__(self, cell, *args, **kwargs):
super(_RNNCellWrapperV2, self).__init__(*args, **kwargs) super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
self._layers = [self._cell] self.cell = cell
# Force the keras_style to be true so that it will suppress all the legacy
# behavior from base_layer.
self._keras_style = True
def __call__(self, inputs, *args, **kwargs):
# Override the __call__() so that it does not fall to
# _RNNCellWrapperV1.__call__(), instead, directly use keras base_layer so
# that it will properly invoke the self.call() here.
return keras_layer.Layer.__call__(self, inputs, *args, **kwargs)
def call(self, inputs, state, **kwargs): def call(self, inputs, state, **kwargs):
"""Runs the RNN cell step computation. """Runs the RNN cell step computation.
@ -1189,16 +1179,15 @@ class _RNNCellWrapperV2(keras_layer.AbstractRNNCell, _RNNCellWrapperV1):
- New state: A tensor or tuple of tensors with new wrapped cell's state. - New state: A tensor or tuple of tensors with new wrapped cell's state.
""" """
return self._call_wrapped_cell( return self._call_wrapped_cell(
inputs, state, cell_call_fn=self._cell.call, **kwargs) inputs, state, cell_call_fn=self.cell.call, **kwargs)
def build(self, inputs_shape): def build(self, inputs_shape):
"""Builds the wrapped cell.""" """Builds the wrapped cell."""
self._cell.build(inputs_shape) self.cell.build(inputs_shape)
self.built = True self.built = True
@tf_export(v1=["nn.rnn_cell.DropoutWrapper"]) class DropoutWrapperBase(object):
class DropoutWrapper(_RNNCellWrapperV1):
"""Operator adding dropout to inputs and outputs of the given cell.""" """Operator adding dropout to inputs and outputs of the given cell."""
def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
@ -1268,7 +1257,7 @@ class DropoutWrapper(_RNNCellWrapperV1):
but not `callable`. but not `callable`.
ValueError: if any of the keep_probs are not between 0 and 1. ValueError: if any of the keep_probs are not between 0 and 1.
""" """
super(DropoutWrapper, self).__init__(cell) super(DropoutWrapperBase, self).__init__(cell)
assert_like_rnncell("cell", cell) assert_like_rnncell("cell", cell)
if (dropout_state_filter_visitor is not None if (dropout_state_filter_visitor is not None
@ -1345,19 +1334,19 @@ class DropoutWrapper(_RNNCellWrapperV1):
@property @property
def wrapped_cell(self): def wrapped_cell(self):
return self._cell return self.cell
@property @property
def state_size(self): def state_size(self):
return self._cell.state_size return self.cell.state_size
@property @property
def output_size(self): def output_size(self):
return self._cell.output_size return self.cell.output_size
def zero_state(self, batch_size, dtype): def zero_state(self, batch_size, dtype):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
return self._cell.zero_state(batch_size, dtype) return self.cell.zero_state(batch_size, dtype)
def _variational_recurrent_dropout_value( def _variational_recurrent_dropout_value(
self, index, value, noise, keep_prob): self, index, value, noise, keep_prob):
@ -1440,18 +1429,27 @@ class DropoutWrapper(_RNNCellWrapperV1):
return output, new_state return output, new_state
@tf_export("nn.DropoutWrapper", v1=[]) @tf_export(v1=["nn.rnn_cell.DropoutWrapper"])
class DropoutWrapperV2(DropoutWrapper, _RNNCellWrapperV2): class DropoutWrapper(DropoutWrapperBase, _RNNCellWrapperV1):
"""Operator adding dropout to inputs and outputs of the given cell."""
def __init__(self, *args, **kwargs):
super(DropoutWrapper, self).__init__(*args, **kwargs)
__init__.__doc__ = DropoutWrapperBase.__init__.__doc__
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapperV2(DropoutWrapperBase, _RNNCellWrapperV2):
"""Operator adding dropout to inputs and outputs of the given cell.""" """Operator adding dropout to inputs and outputs of the given cell."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DropoutWrapperV2, self).__init__(*args, **kwargs) super(DropoutWrapperV2, self).__init__(*args, **kwargs)
__init__.__doc__ = DropoutWrapper.__init__.__doc__ __init__.__doc__ = DropoutWrapperBase.__init__.__doc__
@tf_export(v1=["nn.rnn_cell.ResidualWrapper"]) class ResidualWrapperBase(object):
class ResidualWrapper(_RNNCellWrapperV1):
"""RNNCell wrapper that ensures cell inputs are added to the outputs.""" """RNNCell wrapper that ensures cell inputs are added to the outputs."""
def __init__(self, cell, residual_fn=None): def __init__(self, cell, residual_fn=None):
@ -1464,20 +1462,20 @@ class ResidualWrapper(_RNNCellWrapperV1):
Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs
and outputs. and outputs.
""" """
super(ResidualWrapper, self).__init__(cell) super(ResidualWrapperBase, self).__init__(cell)
self._residual_fn = residual_fn self._residual_fn = residual_fn
@property @property
def state_size(self): def state_size(self):
return self._cell.state_size return self.cell.state_size
@property @property
def output_size(self): def output_size(self):
return self._cell.output_size return self.cell.output_size
def zero_state(self, batch_size, dtype): def zero_state(self, batch_size, dtype):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
return self._cell.zero_state(batch_size, dtype) return self.cell.zero_state(batch_size, dtype)
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Run the cell and then apply the residual_fn on its inputs to its outputs. """Run the cell and then apply the residual_fn on its inputs to its outputs.
@ -1508,18 +1506,27 @@ class ResidualWrapper(_RNNCellWrapperV1):
return (res_outputs, new_state) return (res_outputs, new_state)
@tf_export("nn.ResidualWrapper", v1=[]) @tf_export(v1=["nn.rnn_cell.ResidualWrapper"])
class ResidualWrapperV2(ResidualWrapper, _RNNCellWrapperV2): class ResidualWrapper(ResidualWrapperBase, _RNNCellWrapperV1):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
def __init__(self, *args, **kwargs):
super(ResidualWrapper, self).__init__(*args, **kwargs)
__init__.__doc__ = ResidualWrapperBase.__init__.__doc__
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapperV2(ResidualWrapperBase, _RNNCellWrapperV2):
"""RNNCell wrapper that ensures cell inputs are added to the outputs.""" """RNNCell wrapper that ensures cell inputs are added to the outputs."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ResidualWrapperV2, self).__init__(*args, **kwargs) super(ResidualWrapperV2, self).__init__(*args, **kwargs)
__init__.__doc__ = ResidualWrapper.__init__.__doc__ __init__.__doc__ = ResidualWrapperBase.__init__.__doc__
@tf_export("nn.rnn_cell.DeviceWrapper") class DeviceWrapperBase(object):
class DeviceWrapper(RNNCell):
"""Operator that ensures an RNNCell runs on a particular device.""" """Operator that ensures an RNNCell runs on a particular device."""
def __init__(self, cell, device): def __init__(self, cell, device):
@ -1531,29 +1538,45 @@ class DeviceWrapper(RNNCell):
cell: An instance of `RNNCell`. cell: An instance of `RNNCell`.
device: A device string or function, for passing to `tf.device`. device: A device string or function, for passing to `tf.device`.
""" """
super(DeviceWrapper, self).__init__() super(DeviceWrapperBase, self).__init__(cell)
self._cell = cell
if isinstance(cell, trackable.Trackable):
self._track_trackable(self._cell, name="cell")
self._device = device self._device = device
@property @property
def state_size(self): def state_size(self):
return self._cell.state_size return self.cell.state_size
@property @property
def output_size(self): def output_size(self):
return self._cell.output_size return self.cell.output_size
def zero_state(self, batch_size, dtype): def zero_state(self, batch_size, dtype):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
with ops.device(self._device): with ops.device(self._device):
return self._cell.zero_state(batch_size, dtype) return self.cell.zero_state(batch_size, dtype)
def __call__(self, inputs, state, scope=None): def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Run the cell on specified device.""" """Run the cell on specified device."""
with ops.device(self._device): with ops.device(self._device):
return self._cell(inputs, state, scope=scope) return cell_call_fn(inputs, state, **kwargs)
@tf_export(v1=["nn.rnn_cell.DeviceWrapper"])
class DeviceWrapper(DeviceWrapperBase, _RNNCellWrapperV1):
def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
super(DeviceWrapper, self).__init__(*args, **kwargs)
__init__.__doc__ = DeviceWrapperBase.__init__.__doc__
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapperV2(DeviceWrapperBase, _RNNCellWrapperV2):
"""Operator that ensures an RNNCell runs on a particular device."""
def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
super(DeviceWrapperV2, self).__init__(*args, **kwargs)
__init__.__doc__ = DeviceWrapperBase.__init__.__doc__
@tf_export(v1=["nn.rnn_cell.MultiRNNCell"]) @tf_export(v1=["nn.rnn_cell.MultiRNNCell"])

View File

@ -38,7 +38,6 @@ TENSORFLOW_API_INIT_FILES = [
"math/__init__.py", "math/__init__.py",
"nest/__init__.py", "nest/__init__.py",
"nn/__init__.py", "nn/__init__.py",
"nn/rnn_cell/__init__.py",
"quantization/__init__.py", "quantization/__init__.py",
"ragged/__init__.py", "ragged/__init__.py",
"random/__init__.py", "random/__init__.py",

View File

@ -1,6 +1,8 @@
path: "tensorflow.nn.rnn_cell.DeviceWrapper" path: "tensorflow.nn.rnn_cell.DeviceWrapper"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapper\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapper\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapperBase\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>" is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
@ -104,7 +106,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'cell\', \'device\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -1,6 +1,7 @@
path: "tensorflow.nn.rnn_cell.DropoutWrapper" path: "tensorflow.nn.rnn_cell.DropoutWrapper"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapper\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapper\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapperBase\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>" is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
@ -109,7 +110,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'cell\', \'input_keep_prob\', \'output_keep_prob\', \'state_keep_prob\', \'variational_recurrent\', \'input_size\', \'dtype\', \'seed\', \'dropout_state_filter_visitor\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'1.0\', \'False\', \'None\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -1,6 +1,7 @@
path: "tensorflow.nn.rnn_cell.ResidualWrapper" path: "tensorflow.nn.rnn_cell.ResidualWrapper"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapperBase\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>" is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
@ -105,7 +106,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'cell\', \'residual_fn\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -1,8 +1,9 @@
path: "tensorflow.nn.rnn_cell.DeviceWrapper" path: "tensorflow.nn.RNNCellDeviceWrapper"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapper\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapperV2\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapperBase\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV2\'>"
is_instance: "<class \'tensorflow.python.keras.layers.recurrent.AbstractRNNCell\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
@ -18,10 +19,6 @@ tf_class {
name: "dynamic" name: "dynamic"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "graph"
mtype: "<type \'property\'>"
}
member { member {
name: "inbound_nodes" name: "inbound_nodes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
@ -74,10 +71,6 @@ tf_class {
name: "output_size" name: "output_size"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member { member {
name: "state_size" name: "state_size"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
@ -104,7 +97,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'cell\', \'device\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
} }
member_method { member_method {
name: "add_loss" name: "add_loss"
@ -124,7 +117,7 @@ tf_class {
} }
member_method { member_method {
name: "add_weight" name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
} }
member_method { member_method {
name: "apply" name: "apply"
@ -132,11 +125,11 @@ tf_class {
} }
member_method { member_method {
name: "build" name: "build"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None"
} }
member_method { member_method {
name: "call" name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=kwargs, defaults=None"
} }
member_method { member_method {
name: "compute_mask" name: "compute_mask"

View File

@ -1,12 +1,9 @@
path: "tensorflow.nn.DropoutWrapper" path: "tensorflow.nn.RNNCellDropoutWrapper"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapperV2\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapperV2\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapper\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapperBase\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV2\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV2\'>"
is_instance: "<class \'tensorflow.python.keras.layers.recurrent.AbstractRNNCell\'>" is_instance: "<class \'tensorflow.python.keras.layers.recurrent.AbstractRNNCell\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
@ -22,10 +19,6 @@ tf_class {
name: "dynamic" name: "dynamic"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "graph"
mtype: "<type \'property\'>"
}
member { member {
name: "inbound_nodes" name: "inbound_nodes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
@ -78,10 +71,6 @@ tf_class {
name: "output_size" name: "output_size"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member { member {
name: "state_size" name: "state_size"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
@ -132,7 +121,7 @@ tf_class {
} }
member_method { member_method {
name: "add_weight" name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
} }
member_method { member_method {
name: "apply" name: "apply"

View File

@ -1,12 +1,9 @@
path: "tensorflow.nn.ResidualWrapper" path: "tensorflow.nn.RNNCellResidualWrapper"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapperV2\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapperV2\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapperBase\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV2\'>" is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV2\'>"
is_instance: "<class \'tensorflow.python.keras.layers.recurrent.AbstractRNNCell\'>" is_instance: "<class \'tensorflow.python.keras.layers.recurrent.AbstractRNNCell\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
@ -22,10 +19,6 @@ tf_class {
name: "dynamic" name: "dynamic"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "graph"
mtype: "<type \'property\'>"
}
member { member {
name: "inbound_nodes" name: "inbound_nodes"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
@ -78,10 +71,6 @@ tf_class {
name: "output_size" name: "output_size"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member { member {
name: "state_size" name: "state_size"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
@ -128,7 +117,7 @@ tf_class {
} }
member_method { member_method {
name: "add_weight" name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
} }
member_method { member_method {
name: "apply" name: "apply"

View File

@ -1,16 +1,16 @@
path: "tensorflow.nn" path: "tensorflow.nn"
tf_module { tf_module {
member { member {
name: "DropoutWrapper" name: "RNNCellDeviceWrapper"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member { member {
name: "ResidualWrapper" name: "RNNCellDropoutWrapper"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member { member {
name: "rnn_cell" name: "RNNCellResidualWrapper"
mtype: "<type \'module\'>" mtype: "<type \'type\'>"
} }
member { member {
name: "swish" name: "swish"

View File

@ -1,7 +0,0 @@
path: "tensorflow.nn.rnn_cell"
tf_module {
member {
name: "DeviceWrapper"
mtype: "<type \'type\'>"
}
}

View File

@ -441,6 +441,7 @@ renames = {
'tf.nn.relu_layer': 'tf.compat.v1.nn.relu_layer', 'tf.nn.relu_layer': 'tf.compat.v1.nn.relu_layer',
'tf.nn.rnn_cell.BasicLSTMCell': 'tf.compat.v1.nn.rnn_cell.BasicLSTMCell', 'tf.nn.rnn_cell.BasicLSTMCell': 'tf.compat.v1.nn.rnn_cell.BasicLSTMCell',
'tf.nn.rnn_cell.BasicRNNCell': 'tf.compat.v1.nn.rnn_cell.BasicRNNCell', 'tf.nn.rnn_cell.BasicRNNCell': 'tf.compat.v1.nn.rnn_cell.BasicRNNCell',
'tf.nn.rnn_cell.DeviceWrapper': 'tf.compat.v1.nn.rnn_cell.DeviceWrapper',
'tf.nn.rnn_cell.DropoutWrapper': 'tf.compat.v1.nn.rnn_cell.DropoutWrapper', 'tf.nn.rnn_cell.DropoutWrapper': 'tf.compat.v1.nn.rnn_cell.DropoutWrapper',
'tf.nn.rnn_cell.GRUCell': 'tf.compat.v1.nn.rnn_cell.GRUCell', 'tf.nn.rnn_cell.GRUCell': 'tf.compat.v1.nn.rnn_cell.GRUCell',
'tf.nn.rnn_cell.LSTMCell': 'tf.compat.v1.nn.rnn_cell.LSTMCell', 'tf.nn.rnn_cell.LSTMCell': 'tf.compat.v1.nn.rnn_cell.LSTMCell',