Make experimental_ref no longer experimental
Also clean up related documentation to utilize doctest. PiperOrigin-RevId: 294256062 Change-Id: I9d7ff8dc1324bda64f3638ba3f8e48a0fcb42545
This commit is contained in:
parent
e253ee6945
commit
9c7c7c9979
@ -494,10 +494,8 @@ class _FetchHandler(object):
|
||||
if (isinstance(fetch, ops.Tensor) and
|
||||
(fetch.op.type == 'GetSessionHandle' or
|
||||
fetch.op.type == 'GetSessionHandleV2')):
|
||||
self._fetch_handles[fetch.experimental_ref()] = fetch.op.inputs[0].dtype
|
||||
self._final_fetches = [
|
||||
x for x in self._fetches if x.experimental_ref() not in feeds
|
||||
]
|
||||
self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype
|
||||
self._final_fetches = [x for x in self._fetches if x.ref() not in feeds]
|
||||
|
||||
def _assert_fetchable(self, graph, op):
|
||||
if not graph.is_fetchable(op):
|
||||
@ -553,16 +551,16 @@ class _FetchHandler(object):
|
||||
else:
|
||||
# If the fetch was in the feeds, use the fed value, otherwise
|
||||
# use the returned value.
|
||||
if self._fetches[i].experimental_ref() in self._feed_handles:
|
||||
if self._fetches[i].ref() in self._feed_handles:
|
||||
# A fetch had a corresponding direct TensorHandle feed. Call eval()
|
||||
# to obtain the Tensor value from the TensorHandle.
|
||||
value = self._feed_handles[self._fetches[i].experimental_ref()].eval()
|
||||
value = self._feed_handles[self._fetches[i].ref()].eval()
|
||||
else:
|
||||
value = self._feeds.get(self._fetches[i].experimental_ref())
|
||||
value = self._feeds.get(self._fetches[i].ref())
|
||||
if value is None:
|
||||
value = tensor_values[j]
|
||||
j += 1
|
||||
dtype = self._fetch_handles.get(self._fetches[i].experimental_ref())
|
||||
dtype = self._fetch_handles.get(self._fetches[i].ref())
|
||||
if dtype:
|
||||
full_values.append(session_ops.TensorHandle(value, dtype, session))
|
||||
else:
|
||||
@ -1147,7 +1145,7 @@ class BaseSession(SessionInterface):
|
||||
session_ops.TensorHandle)
|
||||
if is_tensor_handle_feed:
|
||||
np_val = subfeed_val.to_numpy_array()
|
||||
feed_handles[subfeed_t.experimental_ref()] = subfeed_val
|
||||
feed_handles[subfeed_t.ref()] = subfeed_val
|
||||
else:
|
||||
np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
|
||||
|
||||
@ -1160,7 +1158,7 @@ class BaseSession(SessionInterface):
|
||||
if not self.graph.is_feedable(subfeed_t):
|
||||
raise ValueError('Tensor %s may not be fed.' % subfeed_t)
|
||||
|
||||
feed_dict_tensor[subfeed_t.experimental_ref()] = np_val
|
||||
feed_dict_tensor[subfeed_t.ref()] = np_val
|
||||
feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val)
|
||||
|
||||
# Create a fetch handler to take care of the structure of fetches.
|
||||
@ -1435,7 +1433,7 @@ class BaseSession(SessionInterface):
|
||||
np_val = np.array(handle.handle, dtype=np.object)
|
||||
feed_name = handle_mover[0]
|
||||
feed_tensor = feed_map[feed_name][0]
|
||||
feed_dict[feed_tensor.experimental_ref()] = np_val
|
||||
feed_dict[feed_tensor.ref()] = np_val
|
||||
return handles
|
||||
|
||||
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
|
||||
|
@ -659,7 +659,7 @@ def _zeros(shape, dtype):
|
||||
device = ctx.device_name
|
||||
|
||||
if tensor_util.is_tensor(shape):
|
||||
shape_key = shape.experimental_ref()
|
||||
shape_key = shape.ref()
|
||||
else:
|
||||
shape_key = shape
|
||||
cache_key = shape_key, dtype, device
|
||||
|
@ -361,11 +361,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
cf = f.get_concrete_function()
|
||||
c = cc[0]
|
||||
|
||||
captured_variables = {v.experimental_ref() for v in (a, b, c)}
|
||||
trainable_variables = {v.experimental_ref() for v in (b, c)}
|
||||
self.assertEqual({v.experimental_ref() for v in cf.variables},
|
||||
captured_variables)
|
||||
self.assertEqual({v.experimental_ref() for v in cf.trainable_variables},
|
||||
captured_variables = {v.ref() for v in (a, b, c)}
|
||||
trainable_variables = {v.ref() for v in (b, c)}
|
||||
self.assertEqual({v.ref() for v in cf.variables}, captured_variables)
|
||||
self.assertEqual({v.ref() for v in cf.trainable_variables},
|
||||
trainable_variables)
|
||||
self.assertEqual(cf.variables, cf.graph.variables)
|
||||
self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)
|
||||
@ -2889,7 +2888,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
def testDecoratedMethodVariableCleanup(self):
|
||||
m = DefunnedMiniModel()
|
||||
m(array_ops.ones([1, 2]))
|
||||
variable_refs = list({v.experimental_ref() for v in m.variables})
|
||||
variable_refs = list({v.ref() for v in m.variables})
|
||||
self.assertLen(variable_refs, 2)
|
||||
del m
|
||||
|
||||
|
@ -426,8 +426,8 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
strong_x = constant_op.constant([[1.]])
|
||||
strong_y = constant_op.constant([[2.]])
|
||||
strong_x_ref = strong_x.experimental_ref()
|
||||
strong_y_ref = strong_y.experimental_ref()
|
||||
strong_x_ref = strong_x.ref()
|
||||
strong_y_ref = strong_y.ref()
|
||||
weak_key_dict[strong_x_ref] = constant_op.constant([[3.]])
|
||||
weak_key_dict[strong_y_ref] = constant_op.constant([[4.]])
|
||||
strong_y.a = constant_op.constant([[5.]])
|
||||
|
@ -813,9 +813,9 @@ class _FuncGraph(ops.Graph):
|
||||
|
||||
def capture(self, tensor, name=None):
|
||||
"""Adds the given tensor to this graph and returns the captured tensor."""
|
||||
if tensor.experimental_ref() in self._captured:
|
||||
if tensor.ref() in self._captured:
|
||||
# Captured already.
|
||||
return self._captured[tensor.experimental_ref()]
|
||||
return self._captured[tensor.ref()]
|
||||
elif self._capture_by_value:
|
||||
return self._add_tensor_and_parents(tensor)
|
||||
else:
|
||||
@ -848,7 +848,7 @@ class _FuncGraph(ops.Graph):
|
||||
compat.as_bytes(handle_data))
|
||||
# pylint: enable=protected-access
|
||||
self.inputs.append(ph)
|
||||
self._captured[tensor.experimental_ref()] = ph
|
||||
self._captured[tensor.ref()] = ph
|
||||
self.extra_args.append(ph)
|
||||
if _is_guaranteed_const(tensor):
|
||||
with ops.control_dependencies(None):
|
||||
@ -881,7 +881,7 @@ class _FuncGraph(ops.Graph):
|
||||
op_def=op_def)
|
||||
|
||||
for t, captured_t in zip(op.outputs, captured_op.outputs):
|
||||
self._captured[t.experimental_ref()] = captured_t
|
||||
self._captured[t.ref()] = captured_t
|
||||
|
||||
return captured_op
|
||||
|
||||
|
@ -724,8 +724,8 @@ class Tensor(_TensorLike):
|
||||
g = getattr(self, "graph", None)
|
||||
if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
|
||||
(g is None or g.building_function)):
|
||||
raise TypeError("Tensor is unhashable if Tensor equality is enabled. "
|
||||
"Instead, use tensor.experimental_ref() as the key.")
|
||||
raise TypeError("Tensor is unhashable. "
|
||||
"Instead, use tensor.ref() as the key.")
|
||||
else:
|
||||
return id(self)
|
||||
|
||||
@ -814,56 +814,50 @@ class Tensor(_TensorLike):
|
||||
"""
|
||||
return _eval_using_default_session(self, feed_dict, self.graph, session)
|
||||
|
||||
@deprecation.deprecated(None, "Use ref() instead.")
|
||||
def experimental_ref(self):
|
||||
# tf.Variable also has the same experimental_ref() API. If you update the
|
||||
# documenation here, please update tf.Variable.experimental_ref() as well.
|
||||
return self.ref()
|
||||
|
||||
def ref(self):
|
||||
# tf.Variable also has the same ref() API. If you update the
|
||||
# documentation here, please update tf.Variable.ref() as well.
|
||||
"""Returns a hashable reference object to this Tensor.
|
||||
|
||||
Warning: Experimental API that could be changed or removed.
|
||||
|
||||
The primary usecase for this API is to put tensors in a set/dictionary.
|
||||
The primary use case for this API is to put tensors in a set/dictionary.
|
||||
We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer
|
||||
available starting Tensorflow 2.0.
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
The following will raise an exception starting 2.0
|
||||
|
||||
x = tf.constant(5)
|
||||
y = tf.constant(10)
|
||||
z = tf.constant(10)
|
||||
>>> x = tf.constant(5)
|
||||
>>> y = tf.constant(10)
|
||||
>>> z = tf.constant(10)
|
||||
>>> tensor_set = {x, y, z}
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
|
||||
>>> tensor_dict = {x: 'five', y: 'ten'}
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
|
||||
|
||||
# The followings will raise an exception starting 2.0
|
||||
# TypeError: Tensor is unhashable if Tensor equality is enabled.
|
||||
tensor_set = {x, y, z}
|
||||
tensor_dict = {x: 'five', y: 'ten', z: 'ten'}
|
||||
```
|
||||
Instead, we can use `tensor.ref()`.
|
||||
|
||||
Instead, we can use `tensor.experimental_ref()`.
|
||||
|
||||
```python
|
||||
tensor_set = {x.experimental_ref(),
|
||||
y.experimental_ref(),
|
||||
z.experimental_ref()}
|
||||
|
||||
print(x.experimental_ref() in tensor_set)
|
||||
==> True
|
||||
|
||||
tensor_dict = {x.experimental_ref(): 'five',
|
||||
y.experimental_ref(): 'ten',
|
||||
z.experimental_ref(): 'ten'}
|
||||
|
||||
print(tensor_dict[y.experimental_ref()])
|
||||
==> ten
|
||||
```
|
||||
>>> tensor_set = {x.ref(), y.ref(), z.ref()}
|
||||
>>> x.ref() in tensor_set
|
||||
True
|
||||
>>> tensor_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
|
||||
>>> tensor_dict[y.ref()]
|
||||
'ten'
|
||||
|
||||
Also, the reference object provides `.deref()` function that returns the
|
||||
original Tensor.
|
||||
|
||||
```python
|
||||
x = tf.constant(5)
|
||||
print(x.experimental_ref().deref())
|
||||
==> tf.Tensor(5, shape=(), dtype=int32)
|
||||
```
|
||||
>>> x = tf.constant(5)
|
||||
>>> x.ref().deref()
|
||||
<tf.Tensor: shape=(), dtype=int32, numpy=5>
|
||||
"""
|
||||
return object_identity.Reference(self)
|
||||
|
||||
@ -4425,12 +4419,12 @@ class Graph(object):
|
||||
|
||||
def add_op(self, op):
|
||||
if isinstance(op, Tensor):
|
||||
op = op.experimental_ref()
|
||||
op = op.ref()
|
||||
self._seen_nodes.add(op)
|
||||
|
||||
def op_in_group(self, op):
|
||||
if isinstance(op, Tensor):
|
||||
op = op.experimental_ref()
|
||||
op = op.ref()
|
||||
return op in self._seen_nodes
|
||||
|
||||
def _push_control_dependencies_controller(self, controller):
|
||||
|
@ -210,19 +210,19 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
||||
z = constant_op.constant([6, 10])
|
||||
w = variables.Variable(5)
|
||||
|
||||
self.assertEqual(x1.experimental_ref(), x1.experimental_ref())
|
||||
self.assertEqual(x2.experimental_ref(), x2.experimental_ref())
|
||||
self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
|
||||
self.assertEqual(y.experimental_ref(), y.experimental_ref())
|
||||
self.assertEqual(z.experimental_ref(), z.experimental_ref())
|
||||
self.assertEqual(w.experimental_ref(), w.experimental_ref())
|
||||
self.assertEqual(x1.ref(), x1.ref())
|
||||
self.assertEqual(x2.ref(), x2.ref())
|
||||
self.assertEqual(x1.ref(), x2.ref())
|
||||
self.assertEqual(y.ref(), y.ref())
|
||||
self.assertEqual(z.ref(), z.ref())
|
||||
self.assertEqual(w.ref(), w.ref())
|
||||
|
||||
self.assertNotEqual(x1.experimental_ref(), y.experimental_ref())
|
||||
self.assertNotEqual(x1.experimental_ref(), z.experimental_ref())
|
||||
self.assertNotEqual(x1.experimental_ref(), w.experimental_ref())
|
||||
self.assertNotEqual(y.experimental_ref(), z.experimental_ref())
|
||||
self.assertNotEqual(y.experimental_ref(), w.experimental_ref())
|
||||
self.assertNotEqual(z.experimental_ref(), w.experimental_ref())
|
||||
self.assertNotEqual(x1.ref(), y.ref())
|
||||
self.assertNotEqual(x1.ref(), z.ref())
|
||||
self.assertNotEqual(x1.ref(), w.ref())
|
||||
self.assertNotEqual(y.ref(), z.ref())
|
||||
self.assertNotEqual(y.ref(), w.ref())
|
||||
self.assertNotEqual(z.ref(), w.ref())
|
||||
|
||||
def testRefDeref(self):
|
||||
x1 = constant_op.constant(3)
|
||||
@ -231,19 +231,19 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
||||
z = constant_op.constant([6, 10])
|
||||
w = variables.Variable(5)
|
||||
|
||||
self.assertIs(x1, x1.experimental_ref().deref())
|
||||
self.assertIs(x2, x2.experimental_ref().deref())
|
||||
self.assertIs(x1, x2.experimental_ref().deref())
|
||||
self.assertIs(x2, x1.experimental_ref().deref())
|
||||
self.assertIs(y, y.experimental_ref().deref())
|
||||
self.assertIs(z, z.experimental_ref().deref())
|
||||
self.assertIs(x1, x1.ref().deref())
|
||||
self.assertIs(x2, x2.ref().deref())
|
||||
self.assertIs(x1, x2.ref().deref())
|
||||
self.assertIs(x2, x1.ref().deref())
|
||||
self.assertIs(y, y.ref().deref())
|
||||
self.assertIs(z, z.ref().deref())
|
||||
|
||||
self.assertIsNot(x1, y.experimental_ref().deref())
|
||||
self.assertIsNot(x1, z.experimental_ref().deref())
|
||||
self.assertIsNot(x1, w.experimental_ref().deref())
|
||||
self.assertIsNot(y, z.experimental_ref().deref())
|
||||
self.assertIsNot(y, w.experimental_ref().deref())
|
||||
self.assertIsNot(z, w.experimental_ref().deref())
|
||||
self.assertIsNot(x1, y.ref().deref())
|
||||
self.assertIsNot(x1, z.ref().deref())
|
||||
self.assertIsNot(x1, w.ref().deref())
|
||||
self.assertIsNot(y, z.ref().deref())
|
||||
self.assertIsNot(y, w.ref().deref())
|
||||
self.assertIsNot(z, w.ref().deref())
|
||||
|
||||
def testRefInSet(self):
|
||||
x1 = constant_op.constant(3)
|
||||
@ -252,22 +252,22 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
||||
z = constant_op.constant([6, 10])
|
||||
w = variables.Variable(5)
|
||||
|
||||
self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
|
||||
self.assertEqual(x1.ref(), x2.ref())
|
||||
|
||||
tensor_set = {
|
||||
x1.experimental_ref(),
|
||||
x2.experimental_ref(),
|
||||
y.experimental_ref(),
|
||||
z.experimental_ref(),
|
||||
w.experimental_ref(),
|
||||
x1.ref(),
|
||||
x2.ref(),
|
||||
y.ref(),
|
||||
z.ref(),
|
||||
w.ref(),
|
||||
}
|
||||
|
||||
self.assertEqual(len(tensor_set), 4)
|
||||
self.assertIn(x1.experimental_ref(), tensor_set)
|
||||
self.assertIn(x2.experimental_ref(), tensor_set)
|
||||
self.assertIn(y.experimental_ref(), tensor_set)
|
||||
self.assertIn(z.experimental_ref(), tensor_set)
|
||||
self.assertIn(w.experimental_ref(), tensor_set)
|
||||
self.assertIn(x1.ref(), tensor_set)
|
||||
self.assertIn(x2.ref(), tensor_set)
|
||||
self.assertIn(y.ref(), tensor_set)
|
||||
self.assertIn(z.ref(), tensor_set)
|
||||
self.assertIn(w.ref(), tensor_set)
|
||||
|
||||
def testRefInDict(self):
|
||||
x1 = constant_op.constant(3)
|
||||
@ -276,36 +276,36 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
||||
z = constant_op.constant([6, 10])
|
||||
w = variables.Variable(5)
|
||||
|
||||
self.assertEqual(x1.experimental_ref(), x2.experimental_ref())
|
||||
self.assertEqual(x1.ref(), x2.ref())
|
||||
|
||||
tensor_dict = {
|
||||
x1.experimental_ref(): "x1",
|
||||
y.experimental_ref(): "y",
|
||||
z.experimental_ref(): "z",
|
||||
w.experimental_ref(): "w",
|
||||
x1.ref(): "x1",
|
||||
y.ref(): "y",
|
||||
z.ref(): "z",
|
||||
w.ref(): "w",
|
||||
}
|
||||
|
||||
self.assertEqual(len(tensor_dict), 4)
|
||||
|
||||
# Overwriting x1
|
||||
tensor_dict[x2.experimental_ref()] = "x2"
|
||||
tensor_dict[x2.ref()] = "x2"
|
||||
self.assertEqual(len(tensor_dict), 4)
|
||||
|
||||
self.assertEqual(tensor_dict[x1.experimental_ref()], "x2")
|
||||
self.assertEqual(tensor_dict[x2.experimental_ref()], "x2")
|
||||
self.assertEqual(tensor_dict[y.experimental_ref()], "y")
|
||||
self.assertEqual(tensor_dict[z.experimental_ref()], "z")
|
||||
self.assertEqual(tensor_dict[w.experimental_ref()], "w")
|
||||
self.assertEqual(tensor_dict[x1.ref()], "x2")
|
||||
self.assertEqual(tensor_dict[x2.ref()], "x2")
|
||||
self.assertEqual(tensor_dict[y.ref()], "y")
|
||||
self.assertEqual(tensor_dict[z.ref()], "z")
|
||||
self.assertEqual(tensor_dict[w.ref()], "w")
|
||||
|
||||
def testTensorRefStrong(self):
|
||||
x = constant_op.constant(1.)
|
||||
x_ref = x.experimental_ref()
|
||||
x_ref = x.ref()
|
||||
del x
|
||||
self.assertIsNotNone(x_ref.deref())
|
||||
|
||||
def testVariableRefStrong(self):
|
||||
x = variables.Variable(1.)
|
||||
x_ref = x.experimental_ref()
|
||||
x_ref = x.ref()
|
||||
del x
|
||||
self.assertIsNotNone(x_ref.deref())
|
||||
|
||||
|
@ -868,9 +868,10 @@ class Lambda(Layer):
|
||||
# checking only to immediately discard it.
|
||||
return
|
||||
|
||||
tracked_weights = set(v.experimental_ref() for v in self.weights)
|
||||
untracked_new_vars = [v for v in created_variables
|
||||
if v.experimental_ref() not in tracked_weights]
|
||||
tracked_weights = set(v.ref() for v in self.weights)
|
||||
untracked_new_vars = [
|
||||
v for v in created_variables if v.ref() not in tracked_weights
|
||||
]
|
||||
if untracked_new_vars:
|
||||
variable_str = '\n'.join(' {}'.format(i) for i in untracked_new_vars)
|
||||
error_str = textwrap.dedent(
|
||||
@ -886,8 +887,9 @@ class Lambda(Layer):
|
||||
).format(name=self.name, variable_str=variable_str)
|
||||
raise ValueError(error_str)
|
||||
|
||||
untracked_used_vars = [v for v in accessed_variables
|
||||
if v.experimental_ref() not in tracked_weights]
|
||||
untracked_used_vars = [
|
||||
v for v in accessed_variables if v.ref() not in tracked_weights
|
||||
]
|
||||
if untracked_used_vars and not self._already_warned:
|
||||
variable_str = '\n'.join(' {}'.format(i) for i in untracked_used_vars)
|
||||
self._warn(textwrap.dedent(
|
||||
|
@ -149,7 +149,7 @@ class AutoCastVariable(variables.Variable):
|
||||
# reasons:
|
||||
# * 'count_up_to': This method only applies to int variables, which cannot
|
||||
# be wrapped with an AutoCastVariable.
|
||||
# * 'experimental_ref': Instead we inherit the definition from Variable.
|
||||
# * 'ref': Instead we inherit the definition from Variable.
|
||||
# If we defined and delegated to Variable, the ref of an AutoCastVariable
|
||||
# would be the same as the ref of the underlying variable, which would be
|
||||
# strange as they are different Python objects.
|
||||
|
@ -539,8 +539,7 @@ class AdamOptimizerTest(test.TestCase):
|
||||
opt = adam.Adam(1.)
|
||||
opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
|
||||
# There should be iteration, and two unique slot variables for v1 and v2.
|
||||
self.assertEqual(5,
|
||||
len(set(v.experimental_ref() for v in opt.variables())))
|
||||
self.assertEqual(5, len(set(v.ref() for v in opt.variables())))
|
||||
self.assertEqual(
|
||||
self.evaluate(opt.variables()[0]), self.evaluate(opt.iterations))
|
||||
|
||||
|
@ -311,14 +311,16 @@ def _graph_mode_decorator(f, args, kwargs):
|
||||
# Checking global and local variables attempts to ensure that no non-resource
|
||||
# Variables are added to the graph.
|
||||
current_var_scope = variable_scope.get_variable_scope()
|
||||
before_vars = set(
|
||||
[v.experimental_ref() for v in current_var_scope.global_variables() +
|
||||
current_var_scope.local_variables()])
|
||||
before_vars = set([
|
||||
v.ref() for v in current_var_scope.global_variables() +
|
||||
current_var_scope.local_variables()
|
||||
])
|
||||
with backprop.GradientTape() as tape:
|
||||
result, grad_fn = f(*args)
|
||||
after_vars = set(
|
||||
[v.experimental_ref() for v in current_var_scope.global_variables() +
|
||||
current_var_scope.local_variables()])
|
||||
after_vars = set([
|
||||
v.ref() for v in current_var_scope.global_variables() +
|
||||
current_var_scope.local_variables()
|
||||
])
|
||||
new_vars = after_vars - before_vars
|
||||
new_vars_list = [v.deref() for v in new_vars]
|
||||
for v in new_vars_list:
|
||||
@ -330,11 +332,10 @@ def _graph_mode_decorator(f, args, kwargs):
|
||||
# The variables that grad_fn needs to return gradients for are the set of
|
||||
# variables used that are *not* part of the inputs.
|
||||
inputs = args
|
||||
variables_in_tape = frozenset([
|
||||
v.experimental_ref() for v in tape.watched_variables()
|
||||
]) - frozenset(v.experimental_ref() for v in inputs)
|
||||
variables_in_tape = frozenset([v.ref() for v in tape.watched_variables()
|
||||
]) - frozenset(v.ref() for v in inputs)
|
||||
variables_in_subgraph = frozenset([
|
||||
v.experimental_ref()
|
||||
v.ref()
|
||||
for v in get_dependent_variables(input_ops=inputs, output_ops=result)
|
||||
])
|
||||
variables = list(
|
||||
@ -411,7 +412,7 @@ def _eager_mode_decorator(f, args, kwargs):
|
||||
# variables used that are *not* part of the inputs.
|
||||
variables = [
|
||||
v.deref() # pylint: disable=g-complex-comprehension
|
||||
for v in set(v.experimental_ref() for v in tape.watched_variables())
|
||||
for v in set(v.ref() for v in tape.watched_variables())
|
||||
if all(v.deref() is not i for i in all_inputs)
|
||||
]
|
||||
grad_argspec = tf_inspect.getfullargspec(grad_fn)
|
||||
|
@ -1705,19 +1705,18 @@ def _SparseMatMulGrad(op, grad):
|
||||
t_a = op.get_attr("transpose_a")
|
||||
t_b = op.get_attr("transpose_b")
|
||||
is_sparse = {}
|
||||
is_sparse[op.inputs[0].experimental_ref()] = op.get_attr("a_is_sparse")
|
||||
is_sparse[op.inputs[1].experimental_ref()] = op.get_attr("b_is_sparse")
|
||||
is_sparse[op.inputs[0].ref()] = op.get_attr("a_is_sparse")
|
||||
is_sparse[op.inputs[1].ref()] = op.get_attr("b_is_sparse")
|
||||
# Use heuristic to figure out if grad might be sparse
|
||||
is_sparse[grad.experimental_ref()] = not context.executing_eagerly() and (
|
||||
is_sparse[grad.ref()] = not context.executing_eagerly() and (
|
||||
grad.op.type == "ReluGrad")
|
||||
|
||||
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
|
||||
"""Helper function to create SparseMatMul op."""
|
||||
|
||||
assert t1.experimental_ref() in is_sparse and t2.experimental_ref(
|
||||
) in is_sparse
|
||||
t1_sparse = is_sparse[t1.experimental_ref()]
|
||||
t2_sparse = is_sparse[t2.experimental_ref()]
|
||||
assert t1.ref() in is_sparse and t2.ref() in is_sparse
|
||||
t1_sparse = is_sparse[t1.ref()]
|
||||
t2_sparse = is_sparse[t2.ref()]
|
||||
if transpose_b:
|
||||
t2 = array_ops.transpose(t2)
|
||||
transpose_b = False
|
||||
|
@ -1622,7 +1622,7 @@ def _channel_flatten_input(x, data_format):
|
||||
"""
|
||||
|
||||
graph = ops.get_default_graph()
|
||||
cache_key = (graph, x.experimental_ref(), data_format)
|
||||
cache_key = (graph, x.ref(), data_format)
|
||||
if cache_key not in _channel_flatten_input_cache:
|
||||
x_shape = array_ops.shape(x)
|
||||
if data_format == b"NCHW":
|
||||
|
@ -1076,8 +1076,8 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
||||
|
||||
def __hash__(self):
|
||||
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
|
||||
raise TypeError("Variable is unhashable if Tensor equality is enabled. "
|
||||
"Instead, use tensor.experimental_ref() as the key.")
|
||||
raise TypeError("Variable is unhashable. "
|
||||
"Instead, use tensor.ref() as the key.")
|
||||
else:
|
||||
return id(self)
|
||||
|
||||
@ -1209,56 +1209,50 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
||||
def _get_save_slice_info(self):
|
||||
return self._save_slice_info
|
||||
|
||||
@deprecated(None, "Use ref() instead.")
|
||||
def experimental_ref(self):
|
||||
# tf.Tensor also has the same experimental_ref() API. If you update the
|
||||
# documenation here, please update tf.Tensor.experimental_ref() as well.
|
||||
return self.ref()
|
||||
|
||||
def ref(self):
|
||||
# tf.Tensor also has the same ref() API. If you update the
|
||||
# documentation here, please update tf.Tensor.ref() as well.
|
||||
"""Returns a hashable reference object to this Variable.
|
||||
|
||||
Warning: Experimental API that could be changed or removed.
|
||||
|
||||
The primary usecase for this API is to put variables in a set/dictionary.
|
||||
The primary use case for this API is to put variables in a set/dictionary.
|
||||
We can't put variables in a set/dictionary as `variable.__hash__()` is no
|
||||
longer available starting Tensorflow 2.0.
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
The following will raise an exception starting 2.0
|
||||
|
||||
x = tf.Variable(5)
|
||||
y = tf.Variable(10)
|
||||
z = tf.Variable(10)
|
||||
>>> x = tf.Variable(5)
|
||||
>>> y = tf.Variable(10)
|
||||
>>> z = tf.Variable(10)
|
||||
>>> variable_set = {x, y, z}
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
|
||||
>>> variable_dict = {x: 'five', y: 'ten'}
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
|
||||
|
||||
# The followings will raise an exception starting 2.0
|
||||
# TypeError: Variable is unhashable if Variable equality is enabled.
|
||||
variable_set = {x, y, z}
|
||||
variable_dict = {x: 'five', y: 'ten'}
|
||||
```
|
||||
Instead, we can use `variable.ref()`.
|
||||
|
||||
Instead, we can use `variable.experimental_ref()`.
|
||||
|
||||
```python
|
||||
variable_set = {x.experimental_ref(),
|
||||
y.experimental_ref(),
|
||||
z.experimental_ref()}
|
||||
|
||||
print(x.experimental_ref() in variable_set)
|
||||
==> True
|
||||
|
||||
variable_dict = {x.experimental_ref(): 'five',
|
||||
y.experimental_ref(): 'ten',
|
||||
z.experimental_ref(): 'ten'}
|
||||
|
||||
print(variable_dict[y.experimental_ref()])
|
||||
==> ten
|
||||
```
|
||||
>>> variable_set = {x.ref(), y.ref(), z.ref()}
|
||||
>>> x.ref() in variable_set
|
||||
True
|
||||
>>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
|
||||
>>> variable_dict[y.ref()]
|
||||
'ten'
|
||||
|
||||
Also, the reference object provides `.deref()` function that returns the
|
||||
original Variable.
|
||||
|
||||
```python
|
||||
x = tf.Variable(5)
|
||||
print(x.experimental_ref().deref())
|
||||
==> <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
|
||||
```
|
||||
>>> x = tf.Variable(5)
|
||||
>>> x.ref().deref()
|
||||
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
|
||||
"""
|
||||
return object_identity.Reference(self)
|
||||
|
||||
|
@ -460,8 +460,7 @@ def _get_intermediates(func_graph):
|
||||
# 3. Do not accumulate loop vars that are returned as-is just like captured
|
||||
# tensors.
|
||||
intermediates = []
|
||||
reverse_captures = dict(
|
||||
(v.experimental_ref(), k) for k, v in func_graph.captures)
|
||||
reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)
|
||||
|
||||
for op in func_graph.get_operations():
|
||||
if op.type == "Identity":
|
||||
@ -473,7 +472,7 @@ def _get_intermediates(func_graph):
|
||||
if (o is not func_graph.inputs[0] and # Loop counter.
|
||||
o.dtype != dtypes.resource and # Do not accumulate resource tensors.
|
||||
_get_accumulator(o) is None and # Has existing accumulator.
|
||||
o.experimental_ref() not in reverse_captures
|
||||
o.ref() not in reverse_captures
|
||||
): # Captured value, hence loop invariant.
|
||||
intermediates.append(o)
|
||||
return intermediates
|
||||
|
@ -433,7 +433,7 @@ class ExponentialMovingAverage(object):
|
||||
raise TypeError("The variables must be half, float, or double: %s" %
|
||||
var.name)
|
||||
|
||||
if var.experimental_ref() not in self._averages:
|
||||
if var.ref() not in self._averages:
|
||||
# For variables: to lower communication bandwidth across devices we keep
|
||||
# the moving averages on the same device as the variables. For other
|
||||
# tensors, we rely on the existing device allocation mechanism.
|
||||
@ -455,8 +455,8 @@ class ExponentialMovingAverage(object):
|
||||
"Variable", "VariableV2", "VarHandleOp"
|
||||
]))
|
||||
if self._zero_debias:
|
||||
zero_debias_true.add(avg.experimental_ref())
|
||||
self._averages[var.experimental_ref()] = avg
|
||||
zero_debias_true.add(avg.ref())
|
||||
self._averages[var.ref()] = avg
|
||||
|
||||
with ops.name_scope(self.name) as scope:
|
||||
decay = ops.convert_to_tensor(self._decay, name="decay")
|
||||
@ -467,8 +467,8 @@ class ExponentialMovingAverage(object):
|
||||
(1.0 + num_updates) / (10.0 + num_updates))
|
||||
updates = []
|
||||
for var in var_list:
|
||||
avg = self._averages[var.experimental_ref()]
|
||||
zero_debias = avg.experimental_ref() in zero_debias_true
|
||||
avg = self._averages[var.ref()]
|
||||
zero_debias = avg.ref() in zero_debias_true
|
||||
updates.append(assign_moving_average(avg, var, decay, zero_debias))
|
||||
return control_flow_ops.group(*updates, name=scope)
|
||||
|
||||
@ -482,7 +482,7 @@ class ExponentialMovingAverage(object):
|
||||
A `Variable` object or `None` if the moving average of `var`
|
||||
is not maintained.
|
||||
"""
|
||||
return self._averages.get(var.experimental_ref(), None)
|
||||
return self._averages.get(var.ref(), None)
|
||||
|
||||
def average_name(self, var):
|
||||
"""Returns the name of the `Variable` holding the average for `var`.
|
||||
@ -506,8 +506,8 @@ class ExponentialMovingAverage(object):
|
||||
by the `ExponentialMovingAverage class` to hold the moving average of
|
||||
`var`.
|
||||
"""
|
||||
if var.experimental_ref() in self._averages:
|
||||
return self._averages[var.experimental_ref()].op.name
|
||||
if var.ref() in self._averages:
|
||||
return self._averages[var.ref()].op.name
|
||||
return ops.get_default_graph().unique_name(
|
||||
var.op.name + "/" + self.name, mark_as_used=False)
|
||||
|
||||
|
@ -55,6 +55,10 @@ tf_class {
|
||||
name: "get_shape"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "ref"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_shape"
|
||||
argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -112,6 +112,10 @@ tf_class {
|
||||
name: "read_value"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "ref"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "scatter_add"
|
||||
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
|
@ -55,6 +55,10 @@ tf_class {
|
||||
name: "get_shape"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "ref"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_shape"
|
||||
argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -111,6 +111,10 @@ tf_class {
|
||||
name: "read_value"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "ref"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "scatter_add"
|
||||
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user