Make is_resource_variable() an tf.__internal__ API.
PiperOrigin-RevId: 351272430 Change-Id: Ic628427e5ecc15628ef4f5d63e99a8b0604f3a81
This commit is contained in:
parent
c999815aac
commit
68024de2a7
@ -37,6 +37,7 @@ from tensorflow.python.ops import control_flow_ops
|
|||||||
from tensorflow.python.ops import gen_cudnn_rnn_ops
|
from tensorflow.python.ops import gen_cudnn_rnn_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import sysconfig
|
from tensorflow.python.platform import sysconfig
|
||||||
@ -418,6 +419,19 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
|
|||||||
if _use_new_code():
|
if _use_new_code():
|
||||||
self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'gru')
|
self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'gru')
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
super(GRU, self).build(input_shape)
|
||||||
|
|
||||||
|
if not all(isinstance(v, resource_variable_ops.ResourceVariable)
|
||||||
|
for v in self.weights):
|
||||||
|
# Non-resource variables, such as DistributedVariables and
|
||||||
|
# AutoCastVariables, do not work properly with the implementation
|
||||||
|
# selector, which is used when cuDNN is used. However, by chance, such
|
||||||
|
# variables happen to work in LSTM, so this check is only needed for GRU.
|
||||||
|
# TODO(b/136512020): Make non-resource variables work with the
|
||||||
|
# implementation selector.
|
||||||
|
self._could_use_gpu_kernel = False
|
||||||
|
|
||||||
def call(self, inputs, mask=None, training=None, initial_state=None):
|
def call(self, inputs, mask=None, training=None, initial_state=None):
|
||||||
# The input should be dense, padded with zeros. If a ragged input is fed
|
# The input should be dense, padded with zeros. If a ragged input is fed
|
||||||
# into the layer, it is padded and the row lengths are used for masking.
|
# into the layer, it is padded and the row lengths are used for masking.
|
||||||
|
@ -280,7 +280,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
tf_py_test(
|
||||||
name = "layer_correctness_test",
|
name = "layer_correctness_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["layer_correctness_test.py"],
|
srcs = ["layer_correctness_test.py"],
|
||||||
|
@ -39,6 +39,7 @@ from tensorflow.python.keras.optimizer_v2 import adam
|
|||||||
from tensorflow.python.module import module
|
from tensorflow.python.module import module
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import template
|
from tensorflow.python.ops import template
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
@ -272,7 +273,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
|
|||||||
# Optimizer slot variables are created when the original variable is
|
# Optimizer slot variables are created when the original variable is
|
||||||
# restored.
|
# restored.
|
||||||
self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
|
self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
|
||||||
dummy_var = variables_lib.Variable([1.])
|
dummy_var = resource_variable_ops.ResourceVariable([1.])
|
||||||
on_create_optimizer.minimize(loss=dummy_var.read_value,
|
on_create_optimizer.minimize(loss=dummy_var.read_value,
|
||||||
var_list=[dummy_var])
|
var_list=[dummy_var])
|
||||||
status.assert_existing_objects_matched()
|
status.assert_existing_objects_matched()
|
||||||
@ -458,8 +459,8 @@ class CheckpointingTests(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Model, self).__init__()
|
super(Model, self).__init__()
|
||||||
self.w = variables_lib.Variable(0.0)
|
self.w = resource_variable_ops.ResourceVariable(0.0)
|
||||||
self.b = variables_lib.Variable(0.0)
|
self.b = resource_variable_ops.ResourceVariable(0.0)
|
||||||
self.vars = [self.w, self.b]
|
self.vars = [self.w, self.b]
|
||||||
|
|
||||||
def call(self, x):
|
def call(self, x):
|
||||||
@ -873,7 +874,8 @@ class CheckpointCompatibilityTests(keras_parameterized.TestCase):
|
|||||||
self._check_sentinels(root)
|
self._check_sentinels(root)
|
||||||
# Check that there is no error when keys are missing from the name-based
|
# Check that there is no error when keys are missing from the name-based
|
||||||
# checkpoint.
|
# checkpoint.
|
||||||
root.not_in_name_checkpoint = variables_lib.Variable([1.])
|
root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable(
|
||||||
|
[1.])
|
||||||
status = object_saver.restore(save_path)
|
status = object_saver.restore(save_path)
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
status.assert_existing_objects_matched()
|
status.assert_existing_objects_matched()
|
||||||
|
@ -55,7 +55,6 @@ from tensorflow.python.types import core
|
|||||||
from tensorflow.python.util import _pywrap_utils
|
from tensorflow.python.util import _pywrap_utils
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util.deprecation import deprecated
|
from tensorflow.python.util.deprecation import deprecated
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
|
||||||
|
|
||||||
acd.register_read_only_resource_op("ReadVariableOp")
|
acd.register_read_only_resource_op("ReadVariableOp")
|
||||||
acd.register_read_only_resource_op("VariableShape")
|
acd.register_read_only_resource_op("VariableShape")
|
||||||
@ -2212,7 +2211,6 @@ ops.register_proto_function(
|
|||||||
from_proto=_from_proto_fn)
|
from_proto=_from_proto_fn)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("__internal__.ops.is_resource_variable", v1=[])
|
|
||||||
def is_resource_variable(var):
|
def is_resource_variable(var):
|
||||||
""""Returns True if `var` is to be considered a ResourceVariable."""
|
""""Returns True if `var` is to be considered a ResourceVariable."""
|
||||||
return isinstance(var, BaseResourceVariable) or hasattr(
|
return isinstance(var, BaseResourceVariable) or hasattr(
|
||||||
|
@ -4,8 +4,4 @@ tf_module {
|
|||||||
name: "broadcast_weights"
|
name: "broadcast_weights"
|
||||||
argspec: "args=[\'weights\', \'values\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'weights\', \'values\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
member_method {
|
|
||||||
name: "is_resource_variable"
|
|
||||||
argspec: "args=[\'var\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user