From 0ccb3675d69cf1c277a59c17ef59ab3ba9b1a36c Mon Sep 17 00:00:00 2001
From: Taylor Robie
Date: Thu, 1 Aug 2019 11:27:26 -0700
Subject: [PATCH] tf functions capture small EagerTensors as constant ops
 rather than placeholders which are fed from the outer context. This enables
 more comprehensive shape inference during function construction.

PiperOrigin-RevId: 261162099
---
 tensorflow/python/eager/function_test.py  | 27 +++++++++++++++++++++++++++
 tensorflow/python/framework/dtypes.py     |  4 ++++
 tensorflow/python/framework/func_graph.py | 29 +++++++++++++++++++++++++++++
 3 files changed, 60 insertions(+)

diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 500997bdb5a..39062ef3910 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import function as tf_function
 from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
@@ -2125,6 +2126,32 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     # is added.
     self.assertLen(graph._functions, 6)
 
+  @parameterized.named_parameters(
+      dict(testcase_name='Defun',
+           function_decorator=function.defun),
+      dict(testcase_name='DefFunction',
+           function_decorator=def_function.function))
+  def testEagerCaptures(self, function_decorator):
+    with context.eager_mode():
+      large_tensor = array_ops.ones(shape=(256,))
+      self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
+
+      small_tensor = array_ops.ones(shape=(4,))
+      self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
+
+      v = resource_variable_ops.ResourceVariable(0.0)
+
+    for captured, op_type in [(large_tensor, 'Placeholder'),
+                              (small_tensor, 'Const'), (v, 'Placeholder')]:
+      @function_decorator
+      def test_fn():
+        return captured + 1  # pylint: disable=cell-var-from-loop
+
+      g = test_fn.get_concrete_function().graph
+      internal_captures = g.internal_captures
+      self.assertLen(internal_captures, 1)
+      self.assertEqual(internal_captures[0].op.type, op_type)
+
   def testRegisterFunctionWithInputSignature(self):
     def matmul(x, y):
       return math_ops.matmul(x, y)
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 16403b266ca..e817c3172ee 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -566,6 +566,10 @@ for pdt in [
     _NP_TO_TF[pdt] = next(
         _NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
 
+
+TF_VALUE_DTYPES = set(_NP_TO_TF.values())
+
+
 _TF_TO_NP = {
     types_pb2.DT_HALF:
         np.float16,
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index 9a65b5b2527..27feacf152b 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -22,12 +22,15 @@ import collections as py_collections
 import itertools
 import weakref
 
+import numpy as np
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
 from tensorflow.python.eager.graph_only_ops import graph_placeholder
 from tensorflow.python.framework import composite_tensor
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -64,6 +67,9 @@ WHITELIST_COLLECTIONS = [
 ]
 
 
+_EAGER_CONST_THRESHOLD = 128
+
+
 class UnknownArgument(object):
   """Signifies an argument which is not currently handled."""
   pass
@@ -569,6 +575,13 @@ class FuncGraph(ops.Graph):
     if isinstance(tensor, ops.EagerTensor):
       if name is None:
         name = str(ops.uid())
+
+      # Small EagerTensors are captured with Const ops
+      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
+          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
+        return self.capture_eager_tensor(tensor, name)
+
+      # Large EagerTensors and resources are captured with Placeholder ops
       return self._capture_helper(tensor, name)
     if tensor.graph is not self:
       if name is None:
@@ -643,6 +656,22 @@ class FuncGraph(ops.Graph):
     tape.record_operation("captured_value", [placeholder], [variable],
                           lambda x: [x])
 
+  def capture_eager_tensor(self, tensor, name):
+    capture = self._captures.get(ops.tensor_id(tensor))
+    if capture is None:
+      # We clear all control dependencies and place the Const op on the same
+      # device as the source tensor. The device placement may be relaxed at
+      # a later date.
+      with ops.control_dependencies(None), self.device(tensor.device):
+        graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype,
+                                           shape=tensor.shape, name=name)
+      self.add_capture(tensor, graph_const)
+    else:
+      graph_const = capture[1]
+    tape.record_operation("captured_value", [graph_const], [tensor],
+                          lambda x: [x])
+    return graph_const
+
   @property
   def external_captures(self):
     """External tensors captured by this function."""
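
As a quick illustration of the behaviour this patch introduces, here is a minimal
sketch using the public TF 2.x API, mirroring testEagerCaptures above. It assumes
a TensorFlow build that includes this change and that _EAGER_CONST_THRESHOLD is
still 128; the shapes (4,) and (256,) are chosen to fall on either side of that
threshold.

    import tensorflow as tf

    small = tf.ones(shape=(4,))    # 4 elements: at or below the threshold
    large = tf.ones(shape=(256,))  # 256 elements: above the threshold

    for captured in (small, large):

      @tf.function
      def fn():
        return captured + 1  # pylint: disable=cell-var-from-loop

      graph = fn.get_concrete_function().graph
      # internal_captures holds the in-graph tensors backing each capture.
      print([t.op.type for t in graph.internal_captures])
      # Expected: ['Const'] for the small tensor, ['Placeholder'] for the
      # large one.

Because the small tensor is baked into the function graph as a Const op, its
fully defined shape and value are visible during tracing, which is what enables
the additional shape inference described in the commit message.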