Make is_resource_variable() a tf.__internal__ API.
PiperOrigin-RevId: 351249930
Change-Id: I4c8aa09d5584531c723f0f4919cbf5f30080f705
commit 86339436a7
parent a0e249979f
@@ -37,7 +37,6 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_cudnn_rnn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import sysconfig
@@ -419,19 +418,6 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
     if _use_new_code():
       self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'gru')
 
-  def build(self, input_shape):
-    super(GRU, self).build(input_shape)
-
-    if not all(isinstance(v, resource_variable_ops.ResourceVariable)
-               for v in self.weights):
-      # Non-resource variables, such as DistributedVariables and
-      # AutoCastVariables, do not work properly with the implementation
-      # selector, which is used when cuDNN is used. However, by chance, such
-      # variables happen to work in LSTM, so this check is only needed for GRU.
-      # TODO(b/136512020): Make non-resource variables work with the
-      # implementation selector.
-      self._could_use_gpu_kernel = False
-
   def call(self, inputs, mask=None, training=None, initial_state=None):
     # The input should be dense, padded with zeros. If a ragged input is fed
     # into the layer, it is padded and the row lengths are used for masking.
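The removed build() override disabled the cuDNN-backed kernel whenever any weight failed a strict isinstance check, which wrapper types such as DistributedVariable and AutoCastVariable fail even though they act like resource variables. For contrast, a minimal sketch of the duck-typed predicate this commit exports; the helper name weights_are_resource is illustrative and not part of the change:

from tensorflow.python.ops import resource_variable_ops

def weights_are_resource(weights):
  # Unlike a bare isinstance() test, is_resource_variable() also accepts
  # wrapper types that merely declare resource-variable semantics.
  return all(resource_variable_ops.is_resource_variable(v) for v in weights)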
@@ -280,7 +280,7 @@ cuda_py_test(
     ],
 )
 
-tf_py_test(
+cuda_py_test(
    name = "layer_correctness_test",
    size = "medium",
    srcs = ["layer_correctness_test.py"],
@@ -39,7 +39,6 @@ from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.module import module
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import template
 from tensorflow.python.ops import variable_scope
@@ -273,7 +272,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
     # Optimizer slot variables are created when the original variable is
     # restored.
     self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
-    dummy_var = resource_variable_ops.ResourceVariable([1.])
+    dummy_var = variables_lib.Variable([1.])
     on_create_optimizer.minimize(loss=dummy_var.read_value,
                                  var_list=[dummy_var])
     status.assert_existing_objects_matched()
@@ -459,8 +458,8 @@ class CheckpointingTests(keras_parameterized.TestCase):
 
       def __init__(self):
         super(Model, self).__init__()
-        self.w = resource_variable_ops.ResourceVariable(0.0)
-        self.b = resource_variable_ops.ResourceVariable(0.0)
+        self.w = variables_lib.Variable(0.0)
+        self.b = variables_lib.Variable(0.0)
         self.vars = [self.w, self.b]
 
       def call(self, x):
@@ -874,8 +873,7 @@ class CheckpointCompatibilityTests(keras_parameterized.TestCase):
     self._check_sentinels(root)
     # Check that there is no error when keys are missing from the name-based
     # checkpoint.
-    root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable(
-        [1.])
+    root.not_in_name_checkpoint = variables_lib.Variable([1.])
     status = object_saver.restore(save_path)
     with self.assertRaises(AssertionError):
       status.assert_existing_objects_matched()
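The three test hunks above replace direct resource_variable_ops.ResourceVariable construction with variables_lib.Variable. Under v2 behavior the generic constructor produces a resource-backed variable, so the swap should be behavior-preserving; a hedged sanity check, assuming eager/v2 mode:

from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib

v = variables_lib.Variable([1.])
# With v2 behavior enabled, the generic Variable constructor dispatches to
# a ResourceVariable, so this assertion should hold.
assert isinstance(v, resource_variable_ops.ResourceVariable)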
@@ -55,6 +55,7 @@ from tensorflow.python.types import core
 from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
 
 acd.register_read_only_resource_op("ReadVariableOp")
 acd.register_read_only_resource_op("VariableShape")
@@ -2211,6 +2212,7 @@ ops.register_proto_function(
     from_proto=_from_proto_fn)
 
 
+@tf_export("__internal__.ops.is_resource_variable", v1=[])
 def is_resource_variable(var):
   """Returns True if `var` is to be considered a ResourceVariable."""
   return isinstance(var, BaseResourceVariable) or hasattr(
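The @tf_export decorator publishes the existing helper as tf.__internal__.ops.is_resource_variable; v1=[] keeps it out of the compat.v1 namespace. The return expression is truncated at the hunk boundary, but the visible part pairs an isinstance check with a hasattr fallback, so wrapper types can opt in by attribute. A hedged usage sketch, assuming a TF build that ships this export:

import tensorflow as tf

v = tf.Variable(3.0)
# tf.Variable is resource-backed under v2 behavior, so this should print True.
print(tf.__internal__.ops.is_resource_variable(v))
# A plain Python float is not a variable at all, so this should print False.
print(tf.__internal__.ops.is_resource_variable(3.0))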
@@ -4,4 +4,8 @@ tf_module {
     name: "broadcast_weights"
     argspec: "args=[\'weights\', \'values\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "is_resource_variable"
+    argspec: "args=[\'var\'], varargs=None, keywords=None, defaults=None"
+  }
 }