Make experimental_ref no longer experimental

Also clean up related documentation to utilize doctest.

PiperOrigin-RevId: 294256062
Change-Id: I9d7ff8dc1324bda64f3638ba3f8e48a0fcb42545
Gaurav Jain 2020-02-10 10:52:45 -08:00 committed by TensorFlower Gardener
parent e253ee6945
commit 9c7c7c9979
20 changed files with 187 additions and 186 deletions
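The rename is mechanical: every call site that keyed a set or dictionary on `tensor.experimental_ref()` now calls `tensor.ref()`, and `experimental_ref()` survives only as a deprecated alias that forwards to `ref()`. A minimal sketch of the surviving API under TF 2.x eager execution (where `Tensor.__hash__` is disabled), based on the behavior documented in the docstrings below:

```python
import tensorflow as tf

x = tf.constant(5)
y = tf.constant(10)

# Tensors themselves are unhashable when tensor equality is enabled,
# so sets and dicts are keyed on ref() instead.
tensor_dict = {x.ref(): 'five', y.ref(): 'ten'}
assert tensor_dict[y.ref()] == 'ten'

# References to the same tensor compare equal, and deref() recovers
# the original tensor object.
assert x.ref() == x.ref()
assert x.ref().deref() is x
```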


@@ -494,10 +494,8 @@ class _FetchHandler(object):
if (isinstance(fetch, ops.Tensor) and
(fetch.op.type == 'GetSessionHandle' or
fetch.op.type == 'GetSessionHandleV2')):
-self._fetch_handles[fetch.experimental_ref()] = fetch.op.inputs[0].dtype
-self._final_fetches = [
-x for x in self._fetches if x.experimental_ref() not in feeds
-]
+self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype
+self._final_fetches = [x for x in self._fetches if x.ref() not in feeds]
def _assert_fetchable(self, graph, op):
if not graph.is_fetchable(op):
@@ -553,16 +551,16 @@ class _FetchHandler(object):
else:
# If the fetch was in the feeds, use the fed value, otherwise
# use the returned value.
-if self._fetches[i].experimental_ref() in self._feed_handles:
+if self._fetches[i].ref() in self._feed_handles:
# A fetch had a corresponding direct TensorHandle feed. Call eval()
# to obtain the Tensor value from the TensorHandle.
-value = self._feed_handles[self._fetches[i].experimental_ref()].eval()
+value = self._feed_handles[self._fetches[i].ref()].eval()
else:
-value = self._feeds.get(self._fetches[i].experimental_ref())
+value = self._feeds.get(self._fetches[i].ref())
if value is None:
value = tensor_values[j]
j += 1
-dtype = self._fetch_handles.get(self._fetches[i].experimental_ref())
+dtype = self._fetch_handles.get(self._fetches[i].ref())
if dtype:
full_values.append(session_ops.TensorHandle(value, dtype, session))
else:
@@ -1147,7 +1145,7 @@ class BaseSession(SessionInterface):
session_ops.TensorHandle)
if is_tensor_handle_feed:
np_val = subfeed_val.to_numpy_array()
-feed_handles[subfeed_t.experimental_ref()] = subfeed_val
+feed_handles[subfeed_t.ref()] = subfeed_val
else:
np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
@@ -1160,7 +1158,7 @@ class BaseSession(SessionInterface):
if not self.graph.is_feedable(subfeed_t):
raise ValueError('Tensor %s may not be fed.' % subfeed_t)
-feed_dict_tensor[subfeed_t.experimental_ref()] = np_val
+feed_dict_tensor[subfeed_t.ref()] = np_val
feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val)
# Create a fetch handler to take care of the structure of fetches.
@@ -1435,7 +1433,7 @@ class BaseSession(SessionInterface):
np_val = np.array(handle.handle, dtype=np.object)
feed_name = handle_mover[0]
feed_tensor = feed_map[feed_name][0]
-feed_dict[feed_tensor.experimental_ref()] = np_val
+feed_dict[feed_tensor.ref()] = np_val
return handles
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,


@@ -659,7 +659,7 @@ def _zeros(shape, dtype):
device = ctx.device_name
if tensor_util.is_tensor(shape):
-shape_key = shape.experimental_ref()
+shape_key = shape.ref()
else:
shape_key = shape
cache_key = shape_key, dtype, device


@@ -361,11 +361,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
cf = f.get_concrete_function()
c = cc[0]
-captured_variables = {v.experimental_ref() for v in (a, b, c)}
-trainable_variables = {v.experimental_ref() for v in (b, c)}
-self.assertEqual({v.experimental_ref() for v in cf.variables},
-captured_variables)
-self.assertEqual({v.experimental_ref() for v in cf.trainable_variables},
+captured_variables = {v.ref() for v in (a, b, c)}
+trainable_variables = {v.ref() for v in (b, c)}
+self.assertEqual({v.ref() for v in cf.variables}, captured_variables)
+self.assertEqual({v.ref() for v in cf.trainable_variables},
trainable_variables)
self.assertEqual(cf.variables, cf.graph.variables)
self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)
@@ -2889,7 +2888,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def testDecoratedMethodVariableCleanup(self):
m = DefunnedMiniModel()
m(array_ops.ones([1, 2]))
-variable_refs = list({v.experimental_ref() for v in m.variables})
+variable_refs = list({v.ref() for v in m.variables})
self.assertLen(variable_refs, 2)
del m


@@ -426,8 +426,8 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
strong_x = constant_op.constant([[1.]])
strong_y = constant_op.constant([[2.]])
-strong_x_ref = strong_x.experimental_ref()
-strong_y_ref = strong_y.experimental_ref()
+strong_x_ref = strong_x.ref()
+strong_y_ref = strong_y.ref()
weak_key_dict[strong_x_ref] = constant_op.constant([[3.]])
weak_key_dict[strong_y_ref] = constant_op.constant([[4.]])
strong_y.a = constant_op.constant([[5.]])


@@ -813,9 +813,9 @@ class _FuncGraph(ops.Graph):
def capture(self, tensor, name=None):
"""Adds the given tensor to this graph and returns the captured tensor."""
-if tensor.experimental_ref() in self._captured:
+if tensor.ref() in self._captured:
# Captured already.
-return self._captured[tensor.experimental_ref()]
+return self._captured[tensor.ref()]
elif self._capture_by_value:
return self._add_tensor_and_parents(tensor)
else:
@@ -848,7 +848,7 @@ class _FuncGraph(ops.Graph):
compat.as_bytes(handle_data))
# pylint: enable=protected-access
self.inputs.append(ph)
-self._captured[tensor.experimental_ref()] = ph
+self._captured[tensor.ref()] = ph
self.extra_args.append(ph)
if _is_guaranteed_const(tensor):
with ops.control_dependencies(None):
@@ -881,7 +881,7 @@ class _FuncGraph(ops.Graph):
op_def=op_def)
for t, captured_t in zip(op.outputs, captured_op.outputs):
-self._captured[t.experimental_ref()] = captured_t
+self._captured[t.ref()] = captured_t
return captured_op


@@ -724,8 +724,8 @@ class Tensor(_TensorLike):
g = getattr(self, "graph", None)
if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
(g is None or g.building_function)):
-raise TypeError("Tensor is unhashable if Tensor equality is enabled. "
-"Instead, use tensor.experimental_ref() as the key.")
+raise TypeError("Tensor is unhashable. "
+"Instead, use tensor.ref() as the key.")
else:
return id(self)
@@ -814,56 +814,50 @@ class Tensor(_TensorLike):
"""
return _eval_using_default_session(self, feed_dict, self.graph, session)
+@deprecation.deprecated(None, "Use ref() instead.")
def experimental_ref(self):
-# tf.Variable also has the same experimental_ref() API. If you update the
-# documenation here, please update tf.Variable.experimental_ref() as well.
+return self.ref()
+def ref(self):
+# tf.Variable also has the same ref() API. If you update the
+# documentation here, please update tf.Variable.ref() as well.
"""Returns a hashable reference object to this Tensor.
-Warning: Experimental API that could be changed or removed.
-The primary usecase for this API is to put tensors in a set/dictionary.
+The primary use case for this API is to put tensors in a set/dictionary.
We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer
available starting Tensorflow 2.0.
-```python
-import tensorflow as tf
+The following will raise an exception starting 2.0
-x = tf.constant(5)
-y = tf.constant(10)
-z = tf.constant(10)
+>>> x = tf.constant(5)
+>>> y = tf.constant(10)
+>>> z = tf.constant(10)
+>>> tensor_set = {x, y, z}
+Traceback (most recent call last):
+...
+TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
+>>> tensor_dict = {x: 'five', y: 'ten'}
+Traceback (most recent call last):
+...
+TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
-# The followings will raise an exception starting 2.0
-# TypeError: Tensor is unhashable if Tensor equality is enabled.
-tensor_set = {x, y, z}
-tensor_dict = {x: 'five', y: 'ten', z: 'ten'}
-```
+Instead, we can use `tensor.ref()`.
-Instead, we can use `tensor.experimental_ref()`.
-```python
-tensor_set = {x.experimental_ref(),
-y.experimental_ref(),
-z.experimental_ref()}
-print(x.experimental_ref() in tensor_set)
-==> True
-tensor_dict = {x.experimental_ref(): 'five',
-y.experimental_ref(): 'ten',
-z.experimental_ref(): 'ten'}
-print(tensor_dict[y.experimental_ref()])
-==> ten
-```
+>>> tensor_set = {x.ref(), y.ref(), z.ref()}
+>>> x.ref() in tensor_set
+True
+>>> tensor_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
+>>> tensor_dict[y.ref()]
+'ten'
Also, the reference object provides `.deref()` function that returns the
original Tensor.
-```python
-x = tf.constant(5)
-print(x.experimental_ref().deref())
-==> tf.Tensor(5, shape=(), dtype=int32)
-```
+>>> x = tf.constant(5)
+>>> x.ref().deref()
+<tf.Tensor: shape=(), dtype=int32, numpy=5>
"""
return object_identity.Reference(self)
@@ -4425,12 +4419,12 @@ class Graph(object):
def add_op(self, op):
if isinstance(op, Tensor):
-op = op.experimental_ref()
+op = op.ref()
self._seen_nodes.add(op)
def op_in_group(self, op):
if isinstance(op, Tensor):
-op = op.experimental_ref()
+op = op.ref()
return op in self._seen_nodes
def _push_control_dependencies_controller(self, controller):


@@ -210,19 +210,19 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
z = constant_op.constant([6, 10])
w = variables.Variable(5)
-self.assertEqual(x1.experimental_ref(), x1.experimental_ref())
-self.assertEqual(x2.experimental_ref(), x2.experimental_ref())
-self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
-self.assertEqual(y.experimental_ref(), y.experimental_ref())
-self.assertEqual(z.experimental_ref(), z.experimental_ref())
-self.assertEqual(w.experimental_ref(), w.experimental_ref())
+self.assertEqual(x1.ref(), x1.ref())
+self.assertEqual(x2.ref(), x2.ref())
+self.assertEqual(x1.ref(), x2.ref())
+self.assertEqual(y.ref(), y.ref())
+self.assertEqual(z.ref(), z.ref())
+self.assertEqual(w.ref(), w.ref())
-self.assertNotEqual(x1.experimental_ref(), y.experimental_ref())
-self.assertNotEqual(x1.experimental_ref(), z.experimental_ref())
-self.assertNotEqual(x1.experimental_ref(), w.experimental_ref())
-self.assertNotEqual(y.experimental_ref(), z.experimental_ref())
-self.assertNotEqual(y.experimental_ref(), w.experimental_ref())
-self.assertNotEqual(z.experimental_ref(), w.experimental_ref())
+self.assertNotEqual(x1.ref(), y.ref())
+self.assertNotEqual(x1.ref(), z.ref())
+self.assertNotEqual(x1.ref(), w.ref())
+self.assertNotEqual(y.ref(), z.ref())
+self.assertNotEqual(y.ref(), w.ref())
+self.assertNotEqual(z.ref(), w.ref())
def testRefDeref(self):
x1 = constant_op.constant(3)
@@ -231,19 +231,19 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
z = constant_op.constant([6, 10])
w = variables.Variable(5)
-self.assertIs(x1, x1.experimental_ref().deref())
-self.assertIs(x2, x2.experimental_ref().deref())
-self.assertIs(x1, x2.experimental_ref().deref())
-self.assertIs(x2, x1.experimental_ref().deref())
-self.assertIs(y, y.experimental_ref().deref())
-self.assertIs(z, z.experimental_ref().deref())
+self.assertIs(x1, x1.ref().deref())
+self.assertIs(x2, x2.ref().deref())
+self.assertIs(x1, x2.ref().deref())
+self.assertIs(x2, x1.ref().deref())
+self.assertIs(y, y.ref().deref())
+self.assertIs(z, z.ref().deref())
-self.assertIsNot(x1, y.experimental_ref().deref())
-self.assertIsNot(x1, z.experimental_ref().deref())
-self.assertIsNot(x1, w.experimental_ref().deref())
-self.assertIsNot(y, z.experimental_ref().deref())
-self.assertIsNot(y, w.experimental_ref().deref())
-self.assertIsNot(z, w.experimental_ref().deref())
+self.assertIsNot(x1, y.ref().deref())
+self.assertIsNot(x1, z.ref().deref())
+self.assertIsNot(x1, w.ref().deref())
+self.assertIsNot(y, z.ref().deref())
+self.assertIsNot(y, w.ref().deref())
+self.assertIsNot(z, w.ref().deref())
def testRefInSet(self):
x1 = constant_op.constant(3)
@@ -252,22 +252,22 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
z = constant_op.constant([6, 10])
w = variables.Variable(5)
-self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
+self.assertEqual(x1.ref(), x2.ref())
tensor_set = {
-x1.experimental_ref(),
-x2.experimental_ref(),
-y.experimental_ref(),
-z.experimental_ref(),
-w.experimental_ref(),
+x1.ref(),
+x2.ref(),
+y.ref(),
+z.ref(),
+w.ref(),
}
self.assertEqual(len(tensor_set), 4)
-self.assertIn(x1.experimental_ref(), tensor_set)
-self.assertIn(x2.experimental_ref(), tensor_set)
-self.assertIn(y.experimental_ref(), tensor_set)
-self.assertIn(z.experimental_ref(), tensor_set)
-self.assertIn(w.experimental_ref(), tensor_set)
+self.assertIn(x1.ref(), tensor_set)
+self.assertIn(x2.ref(), tensor_set)
+self.assertIn(y.ref(), tensor_set)
+self.assertIn(z.ref(), tensor_set)
+self.assertIn(w.ref(), tensor_set)
def testRefInDict(self):
x1 = constant_op.constant(3)
@@ -276,36 +276,36 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
z = constant_op.constant([6, 10])
w = variables.Variable(5)
-self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
+self.assertEqual(x1.ref(), x2.ref())
tensor_dict = {
-x1.experimental_ref(): "x1",
-y.experimental_ref(): "y",
-z.experimental_ref(): "z",
-w.experimental_ref(): "w",
+x1.ref(): "x1",
+y.ref(): "y",
+z.ref(): "z",
+w.ref(): "w",
}
self.assertEqual(len(tensor_dict), 4)
# Overwriting x1
-tensor_dict[x2.experimental_ref()] = "x2"
+tensor_dict[x2.ref()] = "x2"
self.assertEqual(len(tensor_dict), 4)
-self.assertEqual(tensor_dict[x1.experimental_ref()], "x2")
-self.assertEqual(tensor_dict[x2.experimental_ref()], "x2")
-self.assertEqual(tensor_dict[y.experimental_ref()], "y")
-self.assertEqual(tensor_dict[z.experimental_ref()], "z")
-self.assertEqual(tensor_dict[w.experimental_ref()], "w")
+self.assertEqual(tensor_dict[x1.ref()], "x2")
+self.assertEqual(tensor_dict[x2.ref()], "x2")
+self.assertEqual(tensor_dict[y.ref()], "y")
+self.assertEqual(tensor_dict[z.ref()], "z")
+self.assertEqual(tensor_dict[w.ref()], "w")
def testTensorRefStrong(self):
x = constant_op.constant(1.)
-x_ref = x.experimental_ref()
+x_ref = x.ref()
del x
self.assertIsNotNone(x_ref.deref())
def testVariableRefStrong(self):
x = variables.Variable(1.)
-x_ref = x.experimental_ref()
+x_ref = x.ref()
del x
self.assertIsNotNone(x_ref.deref())


@@ -868,9 +868,10 @@ class Lambda(Layer):
# checking only to immediately discard it.
return
-tracked_weights = set(v.experimental_ref() for v in self.weights)
-untracked_new_vars = [v for v in created_variables
-if v.experimental_ref() not in tracked_weights]
+tracked_weights = set(v.ref() for v in self.weights)
+untracked_new_vars = [
+v for v in created_variables if v.ref() not in tracked_weights
+]
if untracked_new_vars:
variable_str = '\n'.join(' {}'.format(i) for i in untracked_new_vars)
error_str = textwrap.dedent(
@@ -886,8 +887,9 @@ class Lambda(Layer):
).format(name=self.name, variable_str=variable_str)
raise ValueError(error_str)
-untracked_used_vars = [v for v in accessed_variables
-if v.experimental_ref() not in tracked_weights]
+untracked_used_vars = [
+v for v in accessed_variables if v.ref() not in tracked_weights
+]
if untracked_used_vars and not self._already_warned:
variable_str = '\n'.join(' {}'.format(i) for i in untracked_used_vars)
self._warn(textwrap.dedent(


@@ -149,7 +149,7 @@ class AutoCastVariable(variables.Variable):
# reasons:
# * 'count_up_to': This method only applies to int variables, which cannot
# be wrapped with an AutoCastVariable.
-# * 'experimental_ref': Instead we inherit the definition from Variable.
+# * 'ref': Instead we inherit the definition from Variable.
# If we defined and delegated to Variable, the ref of an AutoCastVariable
# would be the same as the ref of the underlying variable, which would be
# strange as they are different Python objects.


@@ -539,8 +539,7 @@ class AdamOptimizerTest(test.TestCase):
opt = adam.Adam(1.)
opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
# There should be iteration, and two unique slot variables for v1 and v2.
-self.assertEqual(5,
-len(set(v.experimental_ref() for v in opt.variables())))
+self.assertEqual(5, len(set(v.ref() for v in opt.variables())))
self.assertEqual(
self.evaluate(opt.variables()[0]), self.evaluate(opt.iterations))


@@ -311,14 +311,16 @@ def _graph_mode_decorator(f, args, kwargs):
# Checking global and local variables attempts to ensure that no non-resource
# Variables are added to the graph.
current_var_scope = variable_scope.get_variable_scope()
-before_vars = set(
-[v.experimental_ref() for v in current_var_scope.global_variables() +
-current_var_scope.local_variables()])
+before_vars = set([
+v.ref() for v in current_var_scope.global_variables() +
+current_var_scope.local_variables()
+])
with backprop.GradientTape() as tape:
result, grad_fn = f(*args)
-after_vars = set(
-[v.experimental_ref() for v in current_var_scope.global_variables() +
-current_var_scope.local_variables()])
+after_vars = set([
+v.ref() for v in current_var_scope.global_variables() +
+current_var_scope.local_variables()
+])
new_vars = after_vars - before_vars
new_vars_list = [v.deref() for v in new_vars]
for v in new_vars_list:
@@ -330,11 +332,10 @@ def _graph_mode_decorator(f, args, kwargs):
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
inputs = args
-variables_in_tape = frozenset([
-v.experimental_ref() for v in tape.watched_variables()
-]) - frozenset(v.experimental_ref() for v in inputs)
+variables_in_tape = frozenset([v.ref() for v in tape.watched_variables()
+]) - frozenset(v.ref() for v in inputs)
variables_in_subgraph = frozenset([
-v.experimental_ref()
+v.ref()
for v in get_dependent_variables(input_ops=inputs, output_ops=result)
])
variables = list(
@@ -411,7 +412,7 @@ def _eager_mode_decorator(f, args, kwargs):
# variables used that are *not* part of the inputs.
variables = [
v.deref() # pylint: disable=g-complex-comprehension
-for v in set(v.experimental_ref() for v in tape.watched_variables())
+for v in set(v.ref() for v in tape.watched_variables())
if all(v.deref() is not i for i in all_inputs)
]
grad_argspec = tf_inspect.getfullargspec(grad_fn)


@@ -1705,19 +1705,18 @@ def _SparseMatMulGrad(op, grad):
t_a = op.get_attr("transpose_a")
t_b = op.get_attr("transpose_b")
is_sparse = {}
-is_sparse[op.inputs[0].experimental_ref()] = op.get_attr("a_is_sparse")
-is_sparse[op.inputs[1].experimental_ref()] = op.get_attr("b_is_sparse")
+is_sparse[op.inputs[0].ref()] = op.get_attr("a_is_sparse")
+is_sparse[op.inputs[1].ref()] = op.get_attr("b_is_sparse")
# Use heuristic to figure out if grad might be sparse
-is_sparse[grad.experimental_ref()] = not context.executing_eagerly() and (
+is_sparse[grad.ref()] = not context.executing_eagerly() and (
grad.op.type == "ReluGrad")
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
"""Helper function to create SparseMatMul op."""
-assert t1.experimental_ref() in is_sparse and t2.experimental_ref(
-) in is_sparse
-t1_sparse = is_sparse[t1.experimental_ref()]
-t2_sparse = is_sparse[t2.experimental_ref()]
+assert t1.ref() in is_sparse and t2.ref() in is_sparse
+t1_sparse = is_sparse[t1.ref()]
+t2_sparse = is_sparse[t2.ref()]
if transpose_b:
t2 = array_ops.transpose(t2)
transpose_b = False


@@ -1622,7 +1622,7 @@ def _channel_flatten_input(x, data_format):
"""
graph = ops.get_default_graph()
-cache_key = (graph, x.experimental_ref(), data_format)
+cache_key = (graph, x.ref(), data_format)
if cache_key not in _channel_flatten_input_cache:
x_shape = array_ops.shape(x)
if data_format == b"NCHW":


@@ -1076,8 +1076,8 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
def __hash__(self):
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
-raise TypeError("Variable is unhashable if Tensor equality is enabled. "
-"Instead, use tensor.experimental_ref() as the key.")
+raise TypeError("Variable is unhashable. "
+"Instead, use tensor.ref() as the key.")
else:
return id(self)
@@ -1209,56 +1209,50 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
def _get_save_slice_info(self):
return self._save_slice_info
+@deprecated(None, "Use ref() instead.")
def experimental_ref(self):
-# tf.Tensor also has the same experimental_ref() API. If you update the
-# documenation here, please update tf.Tensor.experimental_ref() as well.
+return self.ref()
+def ref(self):
+# tf.Tensor also has the same ref() API. If you update the
+# documentation here, please update tf.Tensor.ref() as well.
"""Returns a hashable reference object to this Variable.
-Warning: Experimental API that could be changed or removed.
-The primary usecase for this API is to put variables in a set/dictionary.
+The primary use case for this API is to put variables in a set/dictionary.
We can't put variables in a set/dictionary as `variable.__hash__()` is no
longer available starting Tensorflow 2.0.
-```python
-import tensorflow as tf
+The following will raise an exception starting 2.0
-x = tf.Variable(5)
-y = tf.Variable(10)
-z = tf.Variable(10)
+>>> x = tf.Variable(5)
+>>> y = tf.Variable(10)
+>>> z = tf.Variable(10)
+>>> variable_set = {x, y, z}
+Traceback (most recent call last):
+...
+TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
+>>> variable_dict = {x: 'five', y: 'ten'}
+Traceback (most recent call last):
+...
+TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
-# The followings will raise an exception starting 2.0
-# TypeError: Variable is unhashable if Variable equality is enabled.
-variable_set = {x, y, z}
-variable_dict = {x: 'five', y: 'ten'}
-```
+Instead, we can use `variable.ref()`.
-Instead, we can use `variable.experimental_ref()`.
-```python
-variable_set = {x.experimental_ref(),
-y.experimental_ref(),
-z.experimental_ref()}
-print(x.experimental_ref() in variable_set)
-==> True
-variable_dict = {x.experimental_ref(): 'five',
-y.experimental_ref(): 'ten',
-z.experimental_ref(): 'ten'}
-print(variable_dict[y.experimental_ref()])
-==> ten
-```
+>>> variable_set = {x.ref(), y.ref(), z.ref()}
+>>> x.ref() in variable_set
+True
+>>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
+>>> variable_dict[y.ref()]
+'ten'
Also, the reference object provides `.deref()` function that returns the
original Variable.
-```python
-x = tf.Variable(5)
-print(x.experimental_ref().deref())
-==> <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
-```
+>>> x = tf.Variable(5)
+>>> x.ref().deref()
+<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
"""
return object_identity.Reference(self)


@@ -460,8 +460,7 @@ def _get_intermediates(func_graph):
# 3. Do not accumulate loop vars that are returned as-is just like captured
# tensors.
intermediates = []
-reverse_captures = dict(
-(v.experimental_ref(), k) for k, v in func_graph.captures)
+reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)
for op in func_graph.get_operations():
if op.type == "Identity":
@@ -473,7 +472,7 @@ def _get_intermediates(func_graph):
if (o is not func_graph.inputs[0] and # Loop counter.
o.dtype != dtypes.resource and # Do not accumulate resource tensors.
_get_accumulator(o) is None and # Has existing accumulator.
-o.experimental_ref() not in reverse_captures
+o.ref() not in reverse_captures
): # Captured value, hence loop invariant.
intermediates.append(o)
return intermediates


@@ -433,7 +433,7 @@ class ExponentialMovingAverage(object):
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
-if var.experimental_ref() not in self._averages:
+if var.ref() not in self._averages:
# For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism.
@@ -455,8 +455,8 @@ class ExponentialMovingAverage(object):
"Variable", "VariableV2", "VarHandleOp"
]))
if self._zero_debias:
-zero_debias_true.add(avg.experimental_ref())
-self._averages[var.experimental_ref()] = avg
+zero_debias_true.add(avg.ref())
+self._averages[var.ref()] = avg
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
@@ -467,8 +467,8 @@ class ExponentialMovingAverage(object):
(1.0 + num_updates) / (10.0 + num_updates))
updates = []
for var in var_list:
-avg = self._averages[var.experimental_ref()]
-zero_debias = avg.experimental_ref() in zero_debias_true
+avg = self._averages[var.ref()]
+zero_debias = avg.ref() in zero_debias_true
updates.append(assign_moving_average(avg, var, decay, zero_debias))
return control_flow_ops.group(*updates, name=scope)
@@ -482,7 +482,7 @@ class ExponentialMovingAverage(object):
A `Variable` object or `None` if the moving average of `var`
is not maintained.
"""
-return self._averages.get(var.experimental_ref(), None)
+return self._averages.get(var.ref(), None)
def average_name(self, var):
"""Returns the name of the `Variable` holding the average for `var`.
@@ -506,8 +506,8 @@ class ExponentialMovingAverage(object):
by the `ExponentialMovingAverage class` to hold the moving average of
`var`.
"""
-if var.experimental_ref() in self._averages:
-return self._averages[var.experimental_ref()].op.name
+if var.ref() in self._averages:
+return self._averages[var.ref()].op.name
return ops.get_default_graph().unique_name(
var.op.name + "/" + self.name, mark_as_used=False)


@@ -55,6 +55,10 @@ tf_class {
name: "get_shape"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
+member_method {
+name: "ref"
+argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+}
member_method {
name: "set_shape"
argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None"


@@ -112,6 +112,10 @@ tf_class {
name: "read_value"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
+member_method {
+name: "ref"
+argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+}
member_method {
name: "scatter_add"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "


@@ -55,6 +55,10 @@ tf_class {
name: "get_shape"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
+member_method {
+name: "ref"
+argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+}
member_method {
name: "set_shape"
argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None"


@@ -111,6 +111,10 @@ tf_class {
name: "read_value"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
+member_method {
+name: "ref"
+argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+}
member_method {
name: "scatter_add"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "