Make is_resource_variable() a tf.__internal__ API.

PiperOrigin-RevId: 351272430
Change-Id: Ic628427e5ecc15628ef4f5d63e99a8b0604f3a81
A. Unique TensorFlower 2021-01-11 17:44:19 -08:00 committed by TensorFlower Gardener
parent c999815aac
commit 68024de2a7
5 changed files with 21 additions and 11 deletions
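
The export named in the title corresponds to the @tf_export("__internal__.ops.is_resource_variable", v1=[]) decorator visible in the resource_variable_ops hunk below; such an export registers the helper under the TF2 namespace as tf.__internal__.ops.is_resource_variable and, via v1=[], omits it from compat.v1. A minimal, hedged sketch of calling the symbol through that namespace, assuming the export is registered in the installed TensorFlow build:

# Hedged sketch: access the helper via the tf.__internal__ namespace.
# Assumes the "__internal__.ops.is_resource_variable" export is registered
# in the installed TensorFlow build; otherwise fall back to the module path
# tensorflow.python.ops.resource_variable_ops.is_resource_variable.
import tensorflow as tf

v = tf.Variable(1.0)  # an eager tf.Variable is resource-backed in TF 2.x
print(tf.__internal__.ops.is_resource_variable(v))  # expected: True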


@@ -37,6 +37,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_cudnn_rnn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import sysconfig
@@ -418,6 +419,19 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
     if _use_new_code():
       self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'gru')
 
+  def build(self, input_shape):
+    super(GRU, self).build(input_shape)
+    if not all(isinstance(v, resource_variable_ops.ResourceVariable)
+               for v in self.weights):
+      # Non-resource variables, such as DistributedVariables and
+      # AutoCastVariables, do not work properly with the implementation
+      # selector, which is used when cuDNN is used. However, by chance, such
+      # variables happen to work in LSTM, so this check is only needed for GRU.
+      # TODO(b/136512020): Make non-resource variables work with the
+      # implementation selector.
+      self._could_use_gpu_kernel = False
+
   def call(self, inputs, mask=None, training=None, initial_state=None):
     # The input should be dense, padded with zeros. If a ragged input is fed
     # into the layer, it is padded and the row lengths are used for masking.
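
The build() override added above disables the cuDNN-backed kernel whenever any layer weight fails the isinstance check against ResourceVariable. A small hedged sketch of that same guard applied to plain eager variables (the DistributedVariable and AutoCastVariable wrappers mentioned in the comment are the cases expected to fail it):

# Hedged sketch of the guard used in the GRU.build() change above: plain
# eager tf.Variables are resource-backed, so the all(...) check passes and
# the cuDNN path would remain eligible.
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

weights = [tf.Variable(0.0), tf.Variable([1.0, 2.0])]
could_use_gpu_kernel = all(
    isinstance(v, resource_variable_ops.ResourceVariable) for v in weights)
print(could_use_gpu_kernel)  # True for plain eager variables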


@@ -280,7 +280,7 @@ cuda_py_test(
     ],
 )
 
-cuda_py_test(
+tf_py_test(
     name = "layer_correctness_test",
     size = "medium",
     srcs = ["layer_correctness_test.py"],


@@ -39,6 +39,7 @@ from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.module import module
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import template
 from tensorflow.python.ops import variable_scope
@@ -272,7 +273,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
     # Optimizer slot variables are created when the original variable is
     # restored.
     self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
-    dummy_var = variables_lib.Variable([1.])
+    dummy_var = resource_variable_ops.ResourceVariable([1.])
     on_create_optimizer.minimize(loss=dummy_var.read_value,
                                  var_list=[dummy_var])
     status.assert_existing_objects_matched()
@@ -458,8 +459,8 @@ class CheckpointingTests(keras_parameterized.TestCase):
       def __init__(self):
         super(Model, self).__init__()
-        self.w = variables_lib.Variable(0.0)
-        self.b = variables_lib.Variable(0.0)
+        self.w = resource_variable_ops.ResourceVariable(0.0)
+        self.b = resource_variable_ops.ResourceVariable(0.0)
         self.vars = [self.w, self.b]
 
       def call(self, x):
@@ -873,7 +874,8 @@ class CheckpointCompatibilityTests(keras_parameterized.TestCase):
       self._check_sentinels(root)
       # Check that there is no error when keys are missing from the name-based
      # checkpoint.
-      root.not_in_name_checkpoint = variables_lib.Variable([1.])
+      root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable(
+          [1.])
       status = object_saver.restore(save_path)
       with self.assertRaises(AssertionError):
         status.assert_existing_objects_matched()
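
The test edits above construct the dummy variable directly as a ResourceVariable and pass it to minimize(loss=..., var_list=...) with a callable loss. A standalone hedged sketch of that pattern; the Adam optimizer here is a stand-in for the test's on_create_optimizer, not the test code itself:

# Hedged sketch of the minimize(loss=..., var_list=...) pattern from the test.
# tf.keras.optimizers.Adam stands in for the test's on_create_optimizer.
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

dummy_var = resource_variable_ops.ResourceVariable([1.])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
# A callable loss (here, the variable's read_value method) is required when
# minimize() runs eagerly.
optimizer.minimize(loss=dummy_var.read_value, var_list=[dummy_var])
print(dummy_var.numpy())  # the value moves away from [1.] after one step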


@@ -55,7 +55,6 @@ from tensorflow.python.types import core
 from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util.deprecation import deprecated
-from tensorflow.python.util.tf_export import tf_export
 
 acd.register_read_only_resource_op("ReadVariableOp")
 acd.register_read_only_resource_op("VariableShape")
@@ -2212,7 +2211,6 @@ ops.register_proto_function(
     from_proto=_from_proto_fn)
 
 
-@tf_export("__internal__.ops.is_resource_variable", v1=[])
 def is_resource_variable(var):
   """"Returns True if `var` is to be considered a ResourceVariable."""
   return isinstance(var, BaseResourceVariable) or hasattr(
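
Whether or not the tf_export decorator above is attached, the helper remains callable through its module path, and its single-argument signature matches the args=['var'] argspec recorded in the golden file below. A brief hedged usage sketch:

# Hedged sketch: is_resource_variable reports whether an object behaves as a
# resource-backed variable; plain Python values do not.
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

print(resource_variable_ops.is_resource_variable(tf.Variable(3.0)))  # True
print(resource_variable_ops.is_resource_variable(3.0))               # False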


@@ -4,8 +4,4 @@ tf_module {
     name: "broadcast_weights"
     argspec: "args=[\'weights\', \'values\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "is_resource_variable"
-    argspec: "args=[\'var\'], varargs=None, keywords=None, defaults=None"
-  }
 }