Retrace tf.function when using different variables

Since variable ids can be used as keys into Python data structures, it can be
dangerous to reuse a previous function trace just because the shape and dtype
are the same. We thus remove the optimization and ensure we always retrace.

PiperOrigin-RevId: 317680370
Change-Id: I1dfa3a626074e623b735869aa724138072e2c274
Author: Gaurav Jain (2020-06-22 10:20:00 -07:00), committed by TensorFlower Gardener
parent 36c56b0faa
commit a6c009467a
3 changed files with 55 additions and 30 deletions
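
As a rough illustration of the hazard the commit message describes (not part of this change), here is a minimal sketch using the public tf.function API; the new testVariableRetracing below exercises the same pattern through the internal function.defun. The traced function keys a Python dict by id(variable), so reusing the trace of one variable for another variable with the same shape and dtype would return the wrong entry.

    import tensorflow as tf

    v1 = tf.Variable(1.0)
    v2 = tf.Variable(1.0)  # same shape and dtype as v1

    # Side table keyed by the Python id of each variable; the id() lookup
    # happens while tracing, so its result is baked into the trace.
    side_table = {id(v1): tf.constant(10), id(v2): tf.constant(20)}

    @tf.function
    def lookup(v):
      return side_table[id(v)]

    print(lookup(v1).numpy())  # 10
    print(lookup(v2).numpy())  # 20, correct only because a new variable forces a retrace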


@@ -568,10 +568,10 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(trace_count[0], 1)
     self.assertEqual(self.evaluate(v1), 2.0)
     double_variable(v2)
-    self.assertEqual(trace_count[0], 1 if ops.Tensor._USE_EQUALITY else 2)
+    self.assertEqual(trace_count[0], 2)
     self.assertEqual(self.evaluate(v2), 4.0)
     double_variable(v3)
-    self.assertEqual(trace_count[0], 2 if ops.Tensor._USE_EQUALITY else 3)
+    self.assertEqual(trace_count[0], 3)
     self.assertEqual(self.evaluate(v3), 8)
 
   def testShapeCache(self):


@@ -92,7 +92,7 @@ IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
 SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"
 
 
-def _make_input_signature_hashable(elem, variable_map=None):
+def _make_input_signature_hashable(elem):
   """Rewrite input signature to be hashable.
 
   We replace nested variables in the input signature with TensorSpec in order to
@@ -100,18 +100,13 @@ def _make_input_signature_hashable(elem, variable_map=None):
 
   Args:
     elem: Input signature element
-    variable_map: Internal argument used for tracking variable aliases
 
   Returns:
     A hashable object for the requested input signature
   """
-  if variable_map is None:
-    variable_map = {}
-
   # TODO(slebedev): consider using nest.
   if isinstance(elem, tuple):
-    return tuple(map(lambda e: _make_input_signature_hashable(e, variable_map),
-                     elem))
+    return tuple(map(_make_input_signature_hashable, elem))
 
   try:
     hash(elem)
@@ -122,15 +117,17 @@
     v = elem()
 
     if resource_variable_ops.is_resource_variable(v):
-      idx = variable_map.get(id(v))
-      if idx is None:
-        idx = len(variable_map)
-        variable_map[id(v)] = idx
-
-      # We include the class name to avoid having different types of variables
-      # having the same hash. We Also include the variable index which allows
-      # us to return a different hash if variables have been aliased in a call.
-      return v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype), idx
+      # We special case variables here to use unique_id as the cache key. This
+      # ensures we have to retrace whenever a different variable is passed in.
+      # This is needed to support cases where the user may use the id of a
+      # variable in the function perhaps as a lookup in a dictionary.
+      #
+      # This choice leads to more retracing when we could have possibly used the
+      # shape and dtype instead. However, we expect the number of variables in a
+      # program to be bounded, and correspondingly the number of retraces.
+      #
+      # Note we also include the class name to avoid collisions with strings.
+      return v.__class__, v._unique_id  # pylint: disable=protected-access
 
     if _is_ndarray(v):
       # Numpy arrays are not hashable, but when calling functions we treat them
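
For intuition, a small sketch of what the new cache key buys. The old key is paraphrased from the removed code above and _unique_id is the internal attribute used by the new code; this is not the exact tuple the library builds. Two distinct variables with the same shape and dtype hash identically under a shape/dtype key, but differ under the unique-id key, so the function cache misses and a retrace happens.

    import tensorflow as tf

    v1 = tf.Variable(1.0)
    v2 = tf.Variable(1.0)

    # Roughly the old behavior: class plus shape/dtype, so v1 and v2 collide
    # and a previously built trace could be reused across them.
    old_key = lambda v: (type(v), tf.TensorSpec(v.shape, v.dtype))
    assert old_key(v1) == old_key(v2)

    # Roughly the new behavior: class plus the per-variable unique id, so the
    # keys differ and calling with a different variable forces a retrace.
    new_key = lambda v: (type(v), v._unique_id)  # pylint: disable=protected-access
    assert new_key(v1) != new_key(v2)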


@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import copy
 import functools
 import itertools
 import multiprocessing.pool
@@ -2930,30 +2931,57 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     # should only get a miss if the aliasing changed.
     defined(x, y, z)
     self.assertLen(total_function_cache(defined), 1)
 
     # Calling again is a cache hit
     defined(x, y, z)
     self.assertLen(total_function_cache(defined), 1)
 
-    # Re-arranging arguments doesn't change signature
+    # Re-arranging arguments causes cache miss
     defined(z, y, x)
-    self.assertLen(total_function_cache(defined),
-                   1 if ops.Tensor._USE_EQUALITY else 2)
+    self.assertLen(total_function_cache(defined), 2)
+    defined(z, y, x)
+    self.assertLen(total_function_cache(defined), 2)
 
     # Aliasing causes cache miss
     defined(x, x, z)
-    self.assertLen(total_function_cache(defined),
-                   2 if ops.Tensor._USE_EQUALITY else 3)
+    self.assertLen(total_function_cache(defined), 3)
+    defined(x, x, z)
+    self.assertLen(total_function_cache(defined), 3)
 
-    # Re-arranging arguments doesn't change signature
+    # Re-arranging arguments causes cache miss
     defined(y, y, z)
-    self.assertLen(total_function_cache(defined),
-                   2 if ops.Tensor._USE_EQUALITY else 4)
+    self.assertLen(total_function_cache(defined), 4)
+    defined(y, y, z)
+    self.assertLen(total_function_cache(defined), 4)
 
     # Different alias positions causes cache miss
     defined(z, y, y)
-    self.assertLen(total_function_cache(defined),
-                   3 if ops.Tensor._USE_EQUALITY else 5)
+    self.assertLen(total_function_cache(defined), 5)
+    defined(z, y, y)
+    self.assertLen(total_function_cache(defined), 5)
+
+    x_copy = copy.deepcopy(x)
+    # Deep copy causes cache miss
+    defined(x_copy, y, z)
+    self.assertLen(total_function_cache(defined), 6)
+    defined(x_copy, y, z)
+    self.assertLen(total_function_cache(defined), 6)
+
+  def testVariableRetracing(self):
+    v1 = variables.Variable(1.)
+    v2 = variables.Variable(1.)
+    v3 = copy.deepcopy(variables.Variable(1.))
+
+    var_dict = {id(v1): constant_op.constant(1),
+                id(v2): constant_op.constant(2),
+                id(v3): constant_op.constant(3)}
+
+    @function.defun
+    def lookup_tensor(v):
+      return var_dict[id(v)]
+
+    self.assertEqual(1, lookup_tensor(v1).numpy())
+    self.assertEqual(2, lookup_tensor(v2).numpy())
+    self.assertEqual(3, lookup_tensor(v3).numpy())
 
   def testDecoratedMethodInspect(self):