Retrace tf.function when using different variables
Since variable ids can be used to python data structures it can be dangerous to reuse the previous function trace if the shape and dtype is the same. We thus remove the optimization and ensure we always retrace. PiperOrigin-RevId: 317680370 Change-Id: I1dfa3a626074e623b735869aa724138072e2c274
This commit is contained in:
parent
36c56b0faa
commit
a6c009467a
|
@ -568,10 +568,10 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
|||
self.assertEqual(trace_count[0], 1)
|
||||
self.assertEqual(self.evaluate(v1), 2.0)
|
||||
double_variable(v2)
|
||||
self.assertEqual(trace_count[0], 1 if ops.Tensor._USE_EQUALITY else 2)
|
||||
self.assertEqual(trace_count[0], 2)
|
||||
self.assertEqual(self.evaluate(v2), 4.0)
|
||||
double_variable(v3)
|
||||
self.assertEqual(trace_count[0], 2 if ops.Tensor._USE_EQUALITY else 3)
|
||||
self.assertEqual(trace_count[0], 3)
|
||||
self.assertEqual(self.evaluate(v3), 8)
|
||||
|
||||
def testShapeCache(self):
|
||||
|
|
|
@ -92,7 +92,7 @@ IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
|
|||
SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"
|
||||
|
||||
|
||||
def _make_input_signature_hashable(elem, variable_map=None):
|
||||
def _make_input_signature_hashable(elem):
|
||||
"""Rewrite input signature to be hashable.
|
||||
|
||||
We replace nested variables in the input signature with TensorSpec in order to
|
||||
|
@ -100,18 +100,13 @@ def _make_input_signature_hashable(elem, variable_map=None):
|
|||
|
||||
Args:
|
||||
elem: Input signature element
|
||||
variable_map: Internal argument used for tracking variable aliases
|
||||
|
||||
Returns:
|
||||
A hashable object for the requested input signature
|
||||
"""
|
||||
if variable_map is None:
|
||||
variable_map = {}
|
||||
|
||||
# TODO(slebedev): consider using nest.
|
||||
if isinstance(elem, tuple):
|
||||
return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map),
|
||||
elem))
|
||||
return tuple(map(_make_input_signature_hashable, elem))
|
||||
|
||||
try:
|
||||
hash(elem)
|
||||
|
@ -122,15 +117,17 @@ def _make_input_signature_hashable(elem, variable_map=None):
|
|||
v = elem()
|
||||
|
||||
if resource_variable_ops.is_resource_variable(v):
|
||||
idx = variable_map.get(id(v))
|
||||
if idx is None:
|
||||
idx = len(variable_map)
|
||||
variable_map[id(v)] = idx
|
||||
|
||||
# We include the class name to avoid having different types of variables
|
||||
# having the same hash. We Also include the variable index which allows
|
||||
# us to return a different hash if variables have been aliased in a call.
|
||||
return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
|
||||
# We special case variables here to use unique_id as the cache key. This
|
||||
# ensures we have to retrace whenever a different variable is passed in.
|
||||
# This is needed to support cases where the user may use the id of a
|
||||
# variable in the function perhaps as a lookup in a dictionary.
|
||||
#
|
||||
# This choice leads to more retracing when we could have possibly used the
|
||||
# shape and dtype instead. However, we expect the number of variables in a
|
||||
# program to be bounded, and correspondingly the number of retraces.
|
||||
#
|
||||
# Note we also include the class name to avoid collisions with strings.
|
||||
return v.__class__, v._unique_id # pylint: disable=protected-access
|
||||
|
||||
if _is_ndarray(v):
|
||||
# Numpy arrays are not hashable, but when calling functions we treat them
|
||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
import itertools
|
||||
import multiprocessing.pool
|
||||
|
@ -2930,30 +2931,57 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||
# should only get a miss if the aliasing changed.
|
||||
defined(x, y, z)
|
||||
self.assertLen(total_function_cache(defined), 1)
|
||||
|
||||
# Calling again is a cache hit
|
||||
defined(x, y, z)
|
||||
self.assertLen(total_function_cache(defined), 1)
|
||||
|
||||
# Re-arranging arguments doesn't change signature
|
||||
# Re-arranging arguments causes cache miss
|
||||
defined(z, y, x)
|
||||
self.assertLen(total_function_cache(defined),
|
||||
1 if ops.Tensor._USE_EQUALITY else 2)
|
||||
self.assertLen(total_function_cache(defined), 2)
|
||||
defined(z, y, x)
|
||||
self.assertLen(total_function_cache(defined), 2)
|
||||
|
||||
# Aliasing causes cache miss
|
||||
defined(x, x, z)
|
||||
self.assertLen(total_function_cache(defined),
|
||||
2 if ops.Tensor._USE_EQUALITY else 3)
|
||||
self.assertLen(total_function_cache(defined), 3)
|
||||
defined(x, x, z)
|
||||
self.assertLen(total_function_cache(defined), 3)
|
||||
|
||||
# Re-arranging arguments doesn't change signature
|
||||
# Re-arranging arguments causes cache miss
|
||||
defined(y, y, z)
|
||||
self.assertLen(total_function_cache(defined),
|
||||
2 if ops.Tensor._USE_EQUALITY else 4)
|
||||
self.assertLen(total_function_cache(defined), 4)
|
||||
defined(y, y, z)
|
||||
self.assertLen(total_function_cache(defined), 4)
|
||||
|
||||
# Different alias positions causes cache miss
|
||||
defined(z, y, y)
|
||||
self.assertLen(total_function_cache(defined),
|
||||
3 if ops.Tensor._USE_EQUALITY else 5)
|
||||
self.assertLen(total_function_cache(defined), 5)
|
||||
defined(z, y, y)
|
||||
self.assertLen(total_function_cache(defined), 5)
|
||||
|
||||
x_copy = copy.deepcopy(x)
|
||||
|
||||
# Deep copy causes cache miss
|
||||
defined(x_copy, y, z)
|
||||
self.assertLen(total_function_cache(defined), 6)
|
||||
defined(x_copy, y, z)
|
||||
self.assertLen(total_function_cache(defined), 6)
|
||||
|
||||
def testVariableRetracing(self):
|
||||
v1 = variables.Variable(1.)
|
||||
v2 = variables.Variable(1.)
|
||||
v3 = copy.deepcopy(variables.Variable(1.))
|
||||
|
||||
var_dict = {id(v1): constant_op.constant(1),
|
||||
id(v2): constant_op.constant(2),
|
||||
id(v3): constant_op.constant(3)}
|
||||
|
||||
@function.defun
|
||||
def lookup_tensor(v):
|
||||
return var_dict[id(v)]
|
||||
|
||||
self.assertEqual(1, lookup_tensor(v1).numpy())
|
||||
self.assertEqual(2, lookup_tensor(v2).numpy())
|
||||
self.assertEqual(3, lookup_tensor(v3).numpy())
|
||||
|
||||
def testDecoratedMethodInspect(self):
|
||||
|
||||
|
|
Loading…
Reference in New Issue