From d8033ab10d2b59309106040bb8aefb95cfd75ca9 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 21 Feb 2019 13:14:20 -0800 Subject: [PATCH] Turn on shape relaxation retracing in @tf.function. PiperOrigin-RevId: 235054774 --- tensorflow/python/eager/def_function.py | 6 +- tensorflow/python/eager/function.py | 279 ++++++++++++++++++---- tensorflow/python/eager/function_test.py | 217 ++++++++++++----- tensorflow/python/eager/pywrap_tfe.h | 7 +- tensorflow/python/eager/pywrap_tfe_src.cc | 77 ++++-- tensorflow/python/framework/func_graph.py | 94 ++++++-- 6 files changed, 528 insertions(+), 152 deletions(-) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 23067cf1a6a..1d54973487c 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -578,9 +578,11 @@ class Function(object): concrete_functions = [] # pylint: disable=protected-access if self._stateful_fn: - concrete_functions.extend(self._stateful_fn._function_cache.values()) + concrete_functions.extend( + self._stateful_fn._function_cache.all_values()) if self._stateless_fn: - concrete_functions.extend(self._stateless_fn._function_cache.values()) + concrete_functions.extend( + self._stateless_fn._function_cache.all_values()) # pylint: enable=protected-access deduplicated_concrete_functions = list() seen_signatures = list() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 45e5e5cc818..c53cde692e3 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -43,6 +43,7 @@ from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import errors from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import functional_ops @@ -60,10 +61,94 @@ from tensorflow.python.util import tf_inspect FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name" BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name" -CacheKey = collections.namedtuple("CacheKey", [ - "input_signature", "parent_graph", "device_functions", "colocation_stack", - "uses_xla" -]) +class CacheKey( + collections.namedtuple("CacheKey", [ + "input_signature", "parent_graph", "device_functions", + "colocation_stack", "uses_xla"])): + + def replace(self, *args, **kwargs): + return self._replace(*args, **kwargs) + + +def _flat_shape_list(*params): + """Return a flat list of TensorShapes, one for each tensor[spec] in `*params`. + + Args: + *params: Set of nested entries containing Tensors, TensorSpec, and + non-tensors. + + Returns: + A list of entries containing either `None` or `TensorShape`. + """ + return [tensor_shape.TensorShape(x.shape) + if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else None + for x in nest.flatten(params)] + + +def _compatible_shapes(flat_x, flat_y): + """Check if lists of TensorShapes contain compatible shapes. + + Args: + flat_x: List of TensorShape or None. + flat_y: List of TensorShape or None. + + Returns: + A python bool. + + Raises: + RuntimeError: if `len(flat_x) != len(flat_y)`. + RuntimeError: if `flat_x[i] is None != flat_y[i] is None` for any `i`. 
+ """ + if len(flat_x) != len(flat_y): + raise RuntimeError("Expected shape lists of identical lengths, but saw: " + "%s and %s" % (flat_x, flat_y)) + def is_compatible(x, y): + """Internal help function. + + Args: + x: TensorShape or None. + y: TensorShape or None. + + Returns: + Python bool. + + Raises: + RuntimeError: If `x is None != y is None`. + """ + # If both x and y are None, there is no shape to compare. Otherwise check + # if they are compatible with each other. Either way, both input signatures + # must have have Tensors in the same entries. If not, raise an assertion + # error. + if x is None != y is None: + raise RuntimeError( + "Expected signature type matches between flattened input shapes " + "%s and %s; but saw that (%s is None) != (%s is None)" + % (flat_x, flat_y, x, y)) + return x is None or x.is_compatible_with(y) + return all(is_compatible(x, y) for x, y in zip(flat_x, flat_y)) + + +def _common_shape(x, y): + """Find a `TensorShape` that is compatible with both `x` and `y`.""" + if x is None != y is None: + raise RuntimeError( + "Cannot find a common shape when LHS shape is None but RHS shape " + "is not (or vice versa): %s vs. %s" % (x, y)) + if x is None: + return None # The associated input was not a Tensor, no shape generated. + if not isinstance(x, tensor_shape.TensorShape): + raise TypeError("Expected x to be a TensorShape but saw %s" % (x,)) + if not isinstance(y, tensor_shape.TensorShape): + raise TypeError("Expected y to be a TensorShape but saw %s" % (y,)) + if x.rank != y.rank or x.rank is None: + return tensor_shape.TensorShape(None) + dims = [] + for dim_x, dim_y in zip(x.dims, y.dims): + if dim_x != dim_y or tensor_shape.dimension_value(dim_x) is None: + dims.append(None) + else: + dims.append(tensor_shape.dimension_value(dim_x)) + return tensor_shape.TensorShape(dims) def is_same_structure(structure1, @@ -1073,6 +1158,34 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature): return inputs +class FunctionCache(object): + """A lightweight container for cached functions. + """ + + def __init__(self): + # The set of functions that have been missed; entries are CacheKey with + # input_signature `None` (e.g. a "call context key") + self.missed = set() + # The primary cache, mapping a fully shaped CacheKey to a function. + self.primary = collections.OrderedDict() + # A cache key lookup, mapping a CacheKey generated without shape info to a + # flat list of relaxed shapes (one for each argument). Arguments that are + # not Tensors contain a `None` for the corresponding relaxed shape. + self.arg_relaxed_shapes = collections.OrderedDict() + # The secondary cache, mapping a CacheKey generated without shape info to a + # function. + self.arg_relaxed = collections.OrderedDict() + # All OrderedDicts require manual garbage collection. + self._garbage_collectors = [ + _FunctionGarbageCollector(self.primary), + _FunctionGarbageCollector(self.arg_relaxed), + _FunctionGarbageCollector(self.arg_relaxed_shapes)] + + def all_values(self): + """A set of all `ConcreteFunction` instances held by this cache.""" + return set(self.primary.values()) | set(self.arg_relaxed.values()) + + class Function(object): """Wrapper class for the graph functions defined for a Python function. 
@@ -1126,8 +1239,7 @@ class Function(object): self._name = name self._autograph = autograph self._autograph_options = autograph_options - self._function_cache = collections.OrderedDict() - self._garbage_collector = _FunctionGarbageCollector(self._function_cache) + self._function_cache = FunctionCache() self._function_attributes = attributes or {} self._capture_by_value = capture_by_value @@ -1284,13 +1396,15 @@ class Function(object): # Return the cached `Function` for the instance return self._descriptor_cache[instance] - def _cache_key(self, args, kwargs): + def _cache_key(self, args, kwargs, include_tensor_ranks_only=False): """Computes the cache key given inputs and execution context.""" if self._input_signature is None: inputs = (args, kwargs) if kwargs else args - input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(inputs) + input_signature = pywrap_tensorflow.TFE_Py_EncodeArg( + inputs, include_tensor_ranks_only) else: del args, kwargs + assert not include_tensor_ranks_only input_signature = self._flat_input_signature ctx = context.context() @@ -1336,6 +1450,46 @@ class Function(object): return CacheKey(input_signature, parent_graph, device_functions, colocation_stack, uses_xla) + def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None): + """Create a `ConcreteFunction` from `args` and `kwargs`.""" + if self._input_signature is None: + arglen = len(args) + else: + arglen = len(self._input_signature) + base_arg_names = self._function_spec.arg_names[:arglen] + num_missing_args = arglen - len(self._function_spec.arg_names) + missing_arg_names = [self._function_spec.vararg_name] * num_missing_args + # Produce a list of missing args of the form ["arg_0", "arg_1", ...], + # where arg is based on the self._function_spec.vararg_name. + missing_arg_names = [ + "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names) + ] + arg_names = base_arg_names + missing_arg_names + graph_function = ConcreteFunction( + func_graph_module.func_graph_from_py_func( + self._name, + self._python_function, + args, + kwargs, + self._input_signature, + autograph=self._autograph, + autograph_options=self._autograph_options, + arg_names=arg_names, + override_flat_arg_shapes=override_flat_arg_shapes, + capture_by_value=self._capture_by_value), + self._function_attributes) + + # pylint: disable=protected-access + # Tell the ConcreteFunction to clean up its graph once it goes out of + # scope. ConcreteFunction does not do this in its constructor since it + # gets used in some places (like Keras) where the FuncGraph lives + # longer than the ConcreteFunction. + graph_function._garbage_collector = ConcreteFunctionGarbageCollector( + graph_function.graph) + # pylint: enable=protected-access + + return graph_function + def _maybe_define_function(self, args, kwargs): """Gets a function for these inputs, defining it if necessary. @@ -1353,57 +1507,76 @@ class Function(object): Raises: ValueError: If inputs are incompatible with the input signature. TypeError: If the function inputs include non-hashable objects + RuntimeError: If there's an internal bug (inconsistency) in handling + shape relaxation retracing. 
""" if self._input_signature is None or args is not None or kwargs is not None: args, kwargs = self._function_spec.canonicalize_function_inputs( *args, **kwargs) cache_key = self._cache_key(args, kwargs) - with self._lock: - try: - graph_function = self._function_cache.get(cache_key, None) - except TypeError as e: - raise TypeError( - "Arguments supplied to `defun`-generated functions must be" - " hashable. Original error: %s" % e) - if graph_function is None: - logging.vlog(1, - "Creating new FuncGraph for Python function %r (key: %r)", - self._python_function, cache_key) - if self._input_signature is None: - arglen = len(args) - else: - arglen = len(self._input_signature) - base_arg_names = self._function_spec.arg_names[:arglen] - num_missing_args = arglen - len(self._function_spec.arg_names) - missing_arg_names = [self._function_spec.vararg_name] * num_missing_args - # Produce a list of missing args of the form ["arg_0", "arg_1", ...], - # where arg is based on the self._function_spec.vararg_name. - missing_arg_names = [ - "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names) - ] - arg_names = base_arg_names + missing_arg_names - graph_function = ConcreteFunction( - func_graph_module.func_graph_from_py_func( - self._name, - self._python_function, - args, - kwargs, - self._input_signature, - autograph=self._autograph, - autograph_options=self._autograph_options, - arg_names=arg_names, - capture_by_value=self._capture_by_value), - self._function_attributes) - # pylint: disable=protected-access - # Tell the ConcreteFunction to clean up its graph once it goes out of - # scope. ConcreteFunction does not do this in its constructor since it - # gets used in some places (like Keras) where the FuncGraph lives - # longer than the ConcreteFunction. - graph_function._garbage_collector = ConcreteFunctionGarbageCollector( - graph_function.graph) - # pylint: enable=protected-access - self._function_cache[cache_key] = graph_function + try: + hash(cache_key) + except TypeError as e: + raise TypeError( + "Arguments supplied to `defun`-generated functions must be" + " hashable. Original error: %s" % e) + + with self._lock: + graph_function = self._function_cache.primary.get(cache_key, None) + if graph_function is not None: + return graph_function, args, kwargs + + logging.vlog(1, + "Creating new FuncGraph for Python function %r (key: %r)", + self._python_function, cache_key) + logging.vlog(2, + "Python function signature [args: %s] [kwargs: %s]", + str(args), + str(kwargs)) + + call_context_key = cache_key.replace(input_signature=None) + + # If there's a provided input signature, or XLA is being used, or + # there's no cache miss for this calling context so far, go ahead and + # build the function and bypass shape relaxation retracing. 
+ if (self._input_signature is not None + or cache_key.uses_xla + or call_context_key not in self._function_cache.missed): + self._function_cache.missed.add(call_context_key) + graph_function = self._create_graph_function(args, kwargs) + self._function_cache.primary[cache_key] = graph_function + return graph_function, args, kwargs + + rank_only_cache_key = self._cache_key( + args, kwargs, include_tensor_ranks_only=True) + + arg_shapes = _flat_shape_list(args, kwargs) + relaxed_arg_shapes = self._function_cache.arg_relaxed_shapes.get( + rank_only_cache_key, None) + relaxed_arg_function = self._function_cache.arg_relaxed.get( + rank_only_cache_key, None) + + if (relaxed_arg_function is not None + and _compatible_shapes(relaxed_arg_shapes, arg_shapes)): + return relaxed_arg_function, args, kwargs + + if relaxed_arg_shapes is None: + relaxed_arg_shapes = arg_shapes + else: + if len(arg_shapes) != len(relaxed_arg_shapes): + raise RuntimeError("Expected arg_shapes len to match " + "relaxed_arg_shapes len: %d vs. %d" + % (len(arg_shapes), len(relaxed_arg_shapes))) + relaxed_arg_shapes = [ + _common_shape(x, y) for (x, y) in zip( + arg_shapes, relaxed_arg_shapes)] + self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = ( + relaxed_arg_shapes) + graph_function = self._create_graph_function( + args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes) + self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function + return graph_function, args, kwargs diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 6f886766c27..270e420fa8c 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -66,6 +66,13 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect +def total_function_cache(defined): + # pylint: disable=protected-access + return (set(defined._function_cache.primary) + | set(defined._function_cache.arg_relaxed)) + # pylint: enable=protected-access + + class MiniModel(keras_training.Model): """Minimal model for mnist. @@ -99,6 +106,94 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108]) + def testVariable(self): + v1 = variables.Variable(1.0) + add = def_function.function(lambda x, v: x + v1 + v) + v2 = variables.Variable(1.0) + x = constant_op.constant(1.0) + r = add(x, v2) + self.assertEqual(3.0, self.evaluate(r)) + + def testInputShapeFunctionRelaxation(self): + unknown_dim = [False] + + @function.defun + def func(a): + if a._shape_tuple()[0] is None: + unknown_dim[0] = True + return a + 1 + + func(constant_op.constant([])) + self.assertFalse(unknown_dim[0]) + self.assertLen(total_function_cache(func), 1) + + func(constant_op.constant([1.0])) + self.assertFalse(unknown_dim[0]) + self.assertLen(total_function_cache(func), 2) + + func(constant_op.constant([1.0, 2.0])) + self.assertTrue(unknown_dim[0]) + self.assertLen(total_function_cache(func), 2) + + def testNestedInputShapeFunctionRelaxation(self): + unknown_dim = [False] + + @function.defun + def func(a_, b_=None): + del a_ # Only used to check which cache is used. 
+ self.assertEqual(b_[0]._shape_tuple(), ()) + if b_[1]._shape_tuple()[0] is None: + unknown_dim[0] = True + return b_[0] + 1 + + a = 'hi' + b0 = constant_op.constant(1.0) + func(a, b_=[b0, constant_op.constant([])]) + self.assertFalse(unknown_dim[0]) + self.assertLen(total_function_cache(func), 1) + + func(a, b_=[b0, constant_op.constant([1.0])]) + self.assertFalse(unknown_dim[0]) + self.assertLen(total_function_cache(func), 2) + + func(a, b_=[b0, constant_op.constant([1.0, 1.0])]) + self.assertTrue(unknown_dim[0]) + self.assertLen(total_function_cache(func), 2) + + unknown_dim[0] = False + + # Now do the same except with a new a which is not a tensor; this should + # change the cache key. + a = 'bye' + func(a, b_=[b0, constant_op.constant([])]) + self.assertFalse(unknown_dim[0]) + self.assertLen(total_function_cache(func), 3) + + # Since we already marked a cache miss for a function with the same + # non-input signatures, here we will immediately start relaxing shapes. + func(a, b_=[b0, constant_op.constant([1.0])]) + self.assertTrue(unknown_dim[0]) + self.assertLen(total_function_cache(func), 3) + + def testFunctionRelaxationLosesInnerDimWithKerasLayer(self): + layer = keras.layers.Dense(1) + fn = def_function.function()(layer) + + with self.captureWritesToStream(sys.stderr) as printed: + fn(array_ops.ones((3, 2))) + self.assertNotIn('ValueError', printed.contents()) + with self.captureWritesToStream(sys.stderr) as printed: + # Use batch size 2 to trigger a second cache miss on the shape. + fn(array_ops.ones((2, 2))) + self.assertNotIn('ValueError', printed.contents()) + + # Shape relaxation passes TensorShape([None, None]), which causes layer + # matmul to fail, due to incompatible dims. What would have been a graph + # build time error (layer would complain about the inner dim being 4). + with self.captureWritesToStream(sys.stderr) as printed: + with self.assertRaisesRegexp(errors.InvalidArgumentError, r'MatMul'): + fn(array_ops.ones((3, 4))) + def testWastedAdd(self): @def_function.function() @@ -382,13 +477,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase): x = random_ops.random_uniform([2, 2]).numpy() defined = function.defun(f) defined(x) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) x = random_ops.random_uniform([2, 2]).numpy() defined(x) # A NumPy array with different values but the same shape and dtype # shouldn't trigger another function definition. - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) # Test that the numpy array is properly an argument to the graph function. self.assertEqual(1., defined(numpy.ones([])).numpy()) @@ -1110,7 +1205,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): defined = function.defun(multi_device_fn) outputs = self.evaluate(defined()) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) @@ -1118,7 +1213,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with ops.device('/cpu:3'): outputs = self.evaluate(defined()) # All function definitions are agnostic to call site devices. 
- self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) @@ -1126,7 +1221,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with ops.device('/cpu:0'): outputs = self.evaluate(defined()) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) @@ -1211,10 +1306,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase): defined = function.defun(func) defined(Foo()) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) defined(Foo()) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) def testCacheTensorDtypeCollision(self): @@ -1224,11 +1319,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase): defined = function.defun(func) t = constant_op.constant([[1.0]], dtype=dtypes.complex64) defined(t) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) t = constant_op.constant([[1.0]], dtype=dtypes.complex128) defined(t) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) def testCacheTensorShapeCollision(self): @@ -1238,11 +1333,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase): defined = function.defun(func) t = constant_op.constant([[1.0]], dtype=dtypes.complex64) defined(t) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) t = constant_op.constant([1.0], dtype=dtypes.complex64) defined(t) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) def testCacheTensorShapeDtypeCollision(self): @@ -1252,11 +1347,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase): defined = function.defun(func) t = constant_op.constant([[1.0]], dtype=dtypes.complex64) defined(t) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) t = constant_op.constant([1.0], dtype=dtypes.complex128) defined(t) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) def testCacheTensorUnknownShapesCollision(self): @@ -1266,21 +1361,34 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with context.graph_mode(), self.cached_session(): defined = function.defun(func) - p = array_ops.placeholder(dtype=dtypes.float32, shape=None) + p = array_ops.placeholder(dtype=dtypes.float32, shape=[]) defined(p) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) - p = array_ops.placeholder(dtype=dtypes.float32, shape=[None]) + p = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) defined(p) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) - p = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None]) + p = array_ops.placeholder(dtype=dtypes.float32, shape=[2]) defined(p) - self.assertLen(defined._function_cache, 3) + # Gradual shape relaxation is performed; and the common shape between + # [1] and [2] is one containing unknown dimensions. 
+ self.assertLen(total_function_cache(defined), 2) - t = constant_op.constant(1.0, dtype=dtypes.float32) + # pylint: disable=protected-access + self.assertLen(defined._function_cache.arg_relaxed_shapes, 1) + relaxed_shapes = ( + list(defined._function_cache.arg_relaxed_shapes.values())[0]) + self.assertEqual(len(relaxed_shapes), 1) + relaxed_shape = relaxed_shapes[0] + # pylint: enable=protected-access + self.assertEqual(relaxed_shape.rank, 1) + self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None) + + t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32) defined(t) - self.assertLen(defined._function_cache, 4) + # Shape (3,) matches the relaxed shape TensorShape([None]) + self.assertLen(total_function_cache(defined), 2) def testPythonFunctionWithDefaultArgs(self): @@ -1295,7 +1403,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def cache_keys(): """Sanitizes cache keys of non-input metadata.""" - return tuple(key[0] for key in defined._function_cache) + return tuple(key[0] for key in total_function_cache(defined)) # `True` corresponds to the fact that we're executing eagerly self.assertIn(('URRRu', (0, 1, 20)), cache_keys()) @@ -1305,19 +1413,19 @@ class FunctionTest(test.TestCase, parameterized.TestCase): # This matches the previous call. defined(foo=1) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) defined(1, 2, 3) - self.assertLen(defined._function_cache, 3) + self.assertLen(total_function_cache(defined), 3) self.assertIn(('URRRu', (1, 2, 3)), cache_keys()) # This matches the previous call. defined(1, bar=2, baz=3) - self.assertLen(defined._function_cache, 3) + self.assertLen(total_function_cache(defined), 3) # This matches the previous call. defined(1, baz=3, bar=2) - self.assertLen(defined._function_cache, 3) + self.assertLen(total_function_cache(defined), 3) def testFunctoolsPartialUnwrappedCorrectly(self): @@ -1343,12 +1451,12 @@ class FunctionTest(test.TestCase, parameterized.TestCase): defined = function.defun(foo, input_signature=signature) a = array_ops.ones([2]) self.assertAllEqual(a, defined(a)) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) self.assertAllEqual(a, defined.get_concrete_function()(a)) self.assertAllEqual(a, defined.get_concrete_function(a)(a)) self.assertAllEqual(a, defined.get_concrete_function( tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a)) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) def bar(a): self.assertEqual(a._shape_tuple(), (2, None)) @@ -1358,13 +1466,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase): defined = function.defun(bar, input_signature=signature) a = array_ops.ones([2, 1]) out = defined(a) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) self.assertAllEqual(out, a) # Changing the second dimension shouldn't create a new function. 
b = array_ops.ones([2, 3]) out = defined(b) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) self.assertAllEqual(out, b) def testInputSignatureWithCompatibleInputs(self): @@ -1405,7 +1513,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): b = array_ops.ones([1]) expected = expected_foo([a, a], b) out = foo([a, a], b) - self.assertLen(foo._function_cache, 1) + self.assertLen(total_function_cache(foo), 1) nest.assert_same_structure(out, expected) self.assertAllEqual(out[0][0], a) self.assertAllEqual(out[0][1], a) @@ -1417,7 +1525,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): c = array_ops.ones([1]) expected = expected_foo([a, b], c) out = foo([a, b], c) - self.assertLen(foo._function_cache, 1) + self.assertLen(total_function_cache(foo), 1) nest.assert_same_structure(out, expected) self.assertAllEqual(out[0][0], a) self.assertAllEqual(out[0][1], b) @@ -1428,7 +1536,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): b = b.numpy().tolist() c = c.numpy().tolist() out = foo([a, b], c) - self.assertLen(foo._function_cache, 1) + self.assertLen(total_function_cache(foo), 1) nest.assert_same_structure(out, expected) self.assertAllEqual(out[0][0], a) self.assertAllEqual(out[0][1], b) @@ -1597,22 +1705,22 @@ class FunctionTest(test.TestCase, parameterized.TestCase): integer = constant_op.constant(2, dtypes.int64) out1, out2 = foo(flt, integer) - self.assertLen(foo._function_cache, 1) + self.assertLen(total_function_cache(foo), 1) self.assertEqual(out1.numpy(), 1.0) self.assertEqual(out2.numpy(), 2) out1, out2 = foo(flt=flt, integer=integer) - self.assertLen(foo._function_cache, 1) + self.assertLen(total_function_cache(foo), 1) self.assertEqual(out1.numpy(), 1.0) self.assertEqual(out2.numpy(), 2) out1, out2 = foo(integer=integer, flt=flt) - self.assertLen(foo._function_cache, 1) + self.assertLen(total_function_cache(foo), 1) self.assertEqual(out1.numpy(), 1.0) self.assertEqual(out2.numpy(), 2) out1, out2 = foo(flt, integer=integer) - self.assertLen(foo._function_cache, 1) + self.assertLen(total_function_cache(foo), 1) self.assertEqual(out1.numpy(), 1.0) self.assertEqual(out2.numpy(), 2) @@ -1642,27 +1750,27 @@ class FunctionTest(test.TestCase, parameterized.TestCase): a = constant_op.constant(2.0) b = constant_op.constant([1.0, 2.0]) one = defined(a, b) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) two = defined(a=a, b=b) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) three = defined(b=b, a=a) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) four = defined(a, b=b) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) # The next call corresponds to a new input signature, hence # we expect another function to be defined. 
five = defined(b, a) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) six = defined(a=b, b=a) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) seven = defined(b=a, a=b) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) self.assertAllEqual(one, [1.0, 2.0]) self.assertAllEqual(two, [1.0, 2.0]) @@ -2049,18 +2157,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with ops.Graph().as_default(): x = constant_op.constant(11) maybe_add(x, True) - self.assertLen(maybe_add._function_cache, 1) - self.assertLen(add._function_cache, 1) + self.assertLen(total_function_cache(maybe_add), 1) + self.assertLen(total_function_cache(add), 1) maybe_add(x, False) - self.assertLen(maybe_add._function_cache, 2) - self.assertLen(add._function_cache, 1) + self.assertLen(total_function_cache(maybe_add), 2) + self.assertLen(total_function_cache(add), 1) with ops.Graph().as_default(): x = constant_op.constant(11) maybe_add(x, True) - self.assertLen(maybe_add._function_cache, 3) - self.assertLen(add._function_cache, 2) + self.assertLen(total_function_cache(maybe_add), 3) + self.assertLen(total_function_cache(add), 2) def testCacheKeyOverlappingShapes(self): @function.defun @@ -2068,10 +2176,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase): return t defined(array_ops.zeros([12, 1])) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) defined(array_ops.zeros([1, 21])) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) def testCacheKeyNestedLists(self): @function.defun @@ -2082,10 +2190,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase): b = constant_op.constant(2.) c = constant_op.constant(3.) defined([[a], b, c]) - self.assertLen(defined._function_cache, 1) + self.assertLen(total_function_cache(defined), 1) defined([[a, b], c]) - self.assertLen(defined._function_cache, 2) + self.assertLen(total_function_cache(defined), 2) def testDecoratedMethod(self): m = DefunnedMiniModel() @@ -2647,6 +2755,7 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): result = func(g1, g2, c1, g3, c2) self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0) + if __name__ == '__main__': ops.enable_eager_execution( config=config_pb2.ConfigProto(device_count={'CPU': 4})) diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 63440c04b98..1db1b23d4c9 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -231,8 +231,11 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim); PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor); // Encodes the object as a tuple that is meant to be used as part of the key -// for the defun function cache. -PyObject* TFE_Py_EncodeArg(PyObject*); +// for the defun function cache. If `include_tensor_ranks_only` is true, +// then the encoding only stores tensor ranks, and the key is +// agnostic to dimension sizes. Otherwise, full tensor shape encodings are +// returned. 
+PyObject* TFE_Py_EncodeArg(PyObject*, bool include_tensor_ranks_only); void TFE_Py_EnableInteractivePythonLogging(); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 350c8d4746f..3286e1add81 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -2913,7 +2913,9 @@ struct EncodeResult { } }; -tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { +tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, + bool include_tensor_ranks_only, + EncodeResult* result) { if (EagerTensor_CheckExact(arg)) { TFE_TensorHandle* t = EagerTensor_Handle(arg); tensorflow::TensorShape tensor_shape; @@ -2922,10 +2924,13 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { absl::StrAppend(&result->str, kDType, t->handle->dtype); absl::StrAppend(&result->str, kShape); - for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) { - absl::StrAppend(&result->str, dim_size, kShapeDelim); + if (include_tensor_ranks_only) { + absl::StrAppend(&result->str, tensor_shape.dim_sizes().size()); + } else { + for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) { + absl::StrAppend(&result->str, dim_size, kShapeDelim); + } } - return tensorflow::Status::OK(); } @@ -2949,6 +2954,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { static_cast(MakeInt(dtype_enum.get())); absl::StrAppend(&result->str, kDType, dtype); + static char _shape_tuple[] = "_shape_tuple"; tensorflow::Safe_PyObjectPtr shape_tuple( PyObject_CallMethod(arg, _shape_tuple, nullptr)); @@ -2969,23 +2975,30 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { shape_tuple.get(), "shape_tuple didn't return a sequence")); int len = PySequence_Fast_GET_SIZE(shape_seq.get()); - for (int i = 0; i < len; ++i) { - PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i); - if (item == Py_None) { - absl::StrAppend(&result->str, kNone); - } else { - absl::StrAppend(&result->str, MakeInt(item)); + + if (include_tensor_ranks_only) { + absl::StrAppend(&result->str, len); + } else { + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i); + if (item == Py_None) { + absl::StrAppend(&result->str, kNone); + } else { + absl::StrAppend(&result->str, MakeInt(item)); + } } } - return tensorflow::Status::OK(); } -tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result); +tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, + bool include_tensor_ranks_only, + EncodeResult* result); // This function doesn't set the type of sequence before tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type, const char* end_type, + bool include_tensor_ranks_only, EncodeResult* result) { tensorflow::Safe_PyObjectPtr arg_seq( PySequence_Fast(arg, "unable to create seq from list/tuple")); @@ -2997,7 +3010,8 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type, if (item == Py_None) { absl::StrAppend(&result->str, kNone); } else { - TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result)); + TF_RETURN_IF_ERROR( + TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result)); } } absl::StrAppend(&result->str, end_type); @@ -3005,10 +3019,13 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type, return tensorflow::Status::OK(); } -tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) { +tensorflow::Status 
TFE_Py_EncodeArgHelper(PyObject* arg, + bool include_tensor_ranks_only, + EncodeResult* result) { if (tensorflow::swig::IsTensor(arg)) { absl::StrAppend(&result->str, kTensor); - TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(arg, result)); + TF_RETURN_IF_ERROR( + TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result)); } else if (tensorflow::swig::IsIndexedSlices(arg)) { absl::StrAppend(&result->str, kIndexedSlices); tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values")); @@ -3017,7 +3034,8 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) { return tensorflow::errors::InvalidArgument( "IndexedSlices does not have a values attr"); } - TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(values.get(), result)); + TF_RETURN_IF_ERROR( + TFE_Py_EncodeTensor(values.get(), include_tensor_ranks_only, result)); tensorflow::Safe_PyObjectPtr indices( PyObject_GetAttrString(arg, "indices")); @@ -3026,7 +3044,8 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) { return tensorflow::errors::InvalidArgument( "IndexedSlices does not have a indices attr"); } - TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(indices.get(), result)); + TF_RETURN_IF_ERROR( + TFE_Py_EncodeTensor(indices.get(), include_tensor_ranks_only, result)); tensorflow::Safe_PyObjectPtr dense_shape( PyObject_GetAttrString(arg, "dense_shape")); @@ -3036,12 +3055,15 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) { "IndexedSlices does not have a dense_shape attr"); } if (dense_shape.get() != Py_None) { - TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(dense_shape.get(), result)); + TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor( + dense_shape.get(), include_tensor_ranks_only, result)); } } else if (PyList_Check(arg)) { - TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, kListEnd, result)); + TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence( + arg, kList, kListEnd, include_tensor_ranks_only, result)); } else if (PyTuple_Check(arg)) { - TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, kTupleEnd, result)); + TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence( + arg, kTuple, kTupleEnd, include_tensor_ranks_only, result)); } else if (PyDict_Check(arg)) { tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg)); if (PyList_Sort(keys.get()) == -1) { @@ -3053,9 +3075,11 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) { for (int i = 0; i < len; i++) { PyObject* key = PyList_GetItem(keys.get(), i); - TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(key, result)); + TF_RETURN_IF_ERROR( + TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result)); PyObject* value = PyDict_GetItem(arg, key); - TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(value, result)); + TF_RETURN_IF_ERROR( + TFE_Py_EncodeArgHelper(value, include_tensor_ranks_only, result)); } } else { PyObject* object = PyWeakref_NewRef(arg, nullptr); @@ -3082,10 +3106,15 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) { // on known shapes to produce slimmer graphs, and correctness, as some // high-level APIs require shapes to be fully-known. // +// `include_tensor_ranks_only` allows caching on arguments excluding shape info, +// so that a slow path using relaxed shape can rely on a cache key that excludes +// shapes. +// // TODO(nareshmodi): Add support for sparse tensors. 
-PyObject* TFE_Py_EncodeArg(PyObject* arg) { +PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) { EncodeResult result; - const auto status = TFE_Py_EncodeArgHelper(arg, &result); + const auto status = + TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result); if (MaybeRaiseExceptionFromStatus(status, nullptr)) { return nullptr; } diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index 74f0ece2b40..9097a8dd1f0 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections as py_collections +import itertools import weakref from tensorflow.core.framework import attr_value_pb2 @@ -468,7 +469,8 @@ def func_graph_from_py_func(name, arg_names=None, op_return_value=None, collections=None, - capture_by_value=None): + capture_by_value=None, + override_flat_arg_shapes=None): """Returns a `FuncGraph` generated from `python_func`. Args: @@ -507,6 +509,12 @@ def func_graph_from_py_func(name, capture_by_value: An optional boolean. If True, the func graph will capture Variables by value instead of reference. By default inherit from outer graphs, and failing that will default to False. + override_flat_arg_shapes: An optional list of instances that are either + `None` or `TensorShape`. The length must match that of + `nest.flatten((args, kwargs))`. The entries containing value `None` + must match entries in flattened arguments containing non-tensors, while + entries containing a `TensorShape` must match entries in the flattened + arguments containing tensors. Returns: A FuncGraph. @@ -514,6 +522,8 @@ def func_graph_from_py_func(name, Raises: TypeError: If any of `python_func`'s return values is neither `None` nor a `Tensor`. + ValueError: If both `signature` and `override_flat_arg_shapes` are + passed in. """ if op_return_value is not None: assert isinstance(op_return_value, ops.Tensor), op_return_value @@ -530,13 +540,27 @@ def func_graph_from_py_func(name, default_use_recource = current_scope.use_resource current_scope.set_use_resource(True) + if signature is not None and override_flat_arg_shapes is not None: + raise ValueError( + "Passed both signature and override_flat_arg_shapes: %s and %s." + % (signature, override_flat_arg_shapes)) + if signature is not None: args = signature kwargs = {} # Creates and names placeholders for all arguments. - func_args = _get_defun_inputs_from_args(args, arg_names) - func_kwargs = _get_defun_inputs_from_kwargs(kwargs) + if override_flat_arg_shapes is not None: + flat_args = nest.flatten(args) + arg_shapes = override_flat_arg_shapes[:len(flat_args)] + kwarg_shapes = override_flat_arg_shapes[len(flat_args):] + else: + arg_shapes = None + kwarg_shapes = None + func_args = _get_defun_inputs_from_args( + args, arg_names, flat_shapes=arg_shapes) + func_kwargs = _get_defun_inputs_from_kwargs( + kwargs, flat_shapes=kwarg_shapes) # Convert all Tensors into TensorSpecs before saving the structured inputs. 
# If storing pure concrete functions that are not called through polymorphic @@ -761,43 +785,73 @@ def _create_substitute_placeholder(value, name=None, dtype=None): return placeholder -def _get_defun_inputs_from_args(args, names): +def _get_defun_inputs_from_args(args, names, flat_shapes=None): """Maps Python function positional args to graph-construction inputs.""" - return _get_defun_inputs(args, names, structure=args) + return _get_defun_inputs( + args, names, structure=args, flat_shapes=flat_shapes) -def _get_defun_inputs(flat_args, names, structure): +def _get_defun_inputs(args, names, structure, flat_shapes=None): """Maps python function args to graph-construction inputs. Args: - flat_args: A flat list of user-specified arguments. + args: A flat list of user-specified arguments. names: A list of strings with user-specified argument names, same length as - `flat_args`. May be `None`, in which case a generic name is used. + `args`. May be `None`, in which case a generic name is used. structure: The original argument list or dictionary. + flat_shapes: A flat list of values that are either `None` or + instances of `TensorShape`. If provided, then length must match + that of `nest.flatten(args)`; and locations where `args` are + instances of `Tensor` must have a corresponding `TensorShape` in + `flat_shapes`. May be `None`, in which case exact shapes are read + directly from the args. Returns: Placeholders with the same structure as `structure`. + + Raises: + RuntimeError: if `flat_shapes` is provided, but + `len(flat_shapes) != len(nest.flatten(args))`. + RuntimeError: if a shape from `flat_shapes` is not None + for an argument that is not a `Tensor`, `TensorSpec`, + or `ResourceVariable`. """ func_graph = ops.get_default_graph() function_inputs = [] if names is None: - names = [None] * len(flat_args) - for arg_value, name in zip(flat_args, names): + names = [None] * len(args) + if flat_shapes is None: + shapes_iter = itertools.repeat(None) + else: + len_flat_args = len(nest.flatten(args)) + if len_flat_args != len(flat_shapes): + raise RuntimeError( + "Length of fully flat shapes (%d) must match that of " + "flatten(args) (%d). args: %s, flat_shapes: %s" + % (len(flat_shapes), + len_flat_args, + args, + flat_shapes)) + shapes_iter = iter(flat_shapes) + for arg_value, name in zip(args, names): for arg in nest.flatten(arg_value): + # We have a shape entry for each arg, regadless of whether it's a real + # Tensor or not. For non-tensor entries it should be None. + shape = next(shapes_iter) if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)): if isinstance(arg, tensor_spec.TensorSpec) and arg.name: requested_name = arg.name else: requested_name = name - + placeholder_shape = shape if shape is not None else arg.shape try: placeholder = graph_placeholder( - arg.dtype, arg.shape, + arg.dtype, placeholder_shape, name=requested_name) except ValueError: # Sometimes parameter names are not valid op names, so fall back to # unnamed placeholders. - placeholder = graph_placeholder(arg.dtype, arg.shape) + placeholder = graph_placeholder(arg.dtype, placeholder_shape) if name is not None: # Record the requested/user-specified name in case it's different than # the uniquified name, for validation when exporting signatures. 
@@ -816,18 +870,24 @@ def _get_defun_inputs(flat_args, names, structure): attr_value_pb2.AttrValue(s=compat.as_bytes(name))) function_inputs.append(arg) else: + if shape is not None: + raise RuntimeError( + "Expected provided shape override to be None for arg that isn't " + "a Tensor, but saw arg: '%s', shape: '%s'. args: %s" + % (arg, shape, args)) function_inputs.append(arg) return nest.pack_sequence_as(structure, function_inputs) -def _get_defun_inputs_from_kwargs(kwargs): +def _get_defun_inputs_from_kwargs(kwargs, flat_shapes): """Maps Python function keyword args to graph-construction inputs.""" if kwargs: - names, flat_args = zip(*sorted(kwargs.items())) + names, args = zip(*sorted(kwargs.items())) else: names = [] - flat_args = [] - return _get_defun_inputs(flat_args, names, structure=kwargs) + args = [] + return _get_defun_inputs( + args, names, structure=kwargs, flat_shapes=flat_shapes) def dismantle_func_graph(func_graph):
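Note for reviewers (not part of the patch): a minimal sketch of the retracing behavior this change enables, assuming a TensorFlow build that includes this commit, no explicit input_signature, and XLA not in use. The `double` function and the `trace_count` side-effect counter are illustrative only; the counter increments only while a new concrete function is being traced, so it records how many traces were built.

import tensorflow as tf

trace_count = [0]

@tf.function
def double(x):
  # Python side effects run only during tracing, so this counts traces.
  trace_count[0] += 1
  return x * 2

double(tf.constant([1.0]))                 # shape (1,): first trace, stored in the primary cache
double(tf.constant([1.0, 2.0]))            # shape (2,): this calling context already saw a miss,
                                           # so a trace goes into the relaxed (rank-only) cache
double(tf.constant([1.0, 2.0, 3.0]))       # shape (3,): relaxed shape widens to (None,), one more trace
double(tf.constant([1.0, 2.0, 3.0, 4.0]))  # shape (4,): compatible with (None,), no new trace
print(trace_count[0])                      # should stop growing once the (None,) trace exists

Before this change, each distinct 1-D shape produced its own cache entry; with it, once a calling context has recorded a miss, later calls fall back to the rank-only cache key and reuse (or widen) the relaxed-shape trace.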