tf functions capture small EagerTensors as constant ops rather than placeholders which are fed from the outer context. This enables more comprehensive shape inference during function construction.

PiperOrigin-RevId: 261162099
This commit is contained in:
Taylor Robie 2019-08-01 11:27:26 -07:00 committed by TensorFlower Gardener
parent d8563bdb14
commit 0ccb3675d6
3 changed files with 60 additions and 0 deletions

View File

@ -42,6 +42,7 @@ from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
@ -2125,6 +2126,32 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
# is added.
self.assertLen(graph._functions, 6)
@parameterized.named_parameters(
dict(testcase_name='Defun',
function_decorator=function.defun),
dict(testcase_name='DefFunction',
function_decorator=def_function.function))
def testEagerCaptures(self, function_decorator):
with context.eager_mode():
large_tensor = array_ops.ones(shape=(256,))
self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
small_tensor = array_ops.ones(shape=(4,))
self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
v = resource_variable_ops.ResourceVariable(0.0)
for captured, op_type in [(large_tensor, 'Placeholder'),
(small_tensor, 'Const'), (v, 'Placeholder')]:
@function_decorator
def test_fn():
return captured + 1 # pylint: disable=cell-var-from-loop
g = test_fn.get_concrete_function().graph
internal_captures = g.internal_captures
self.assertLen(internal_captures, 1)
self.assertEqual(internal_captures[0].op.type, op_type)
def testRegisterFunctionWithInputSignature(self):
def matmul(x, y):
return math_ops.matmul(x, y)

View File

@ -566,6 +566,10 @@ for pdt in [
_NP_TO_TF[pdt] = next(
_NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
TF_VALUE_DTYPES = set(_NP_TO_TF.values())
_TF_TO_NP = {
types_pb2.DT_HALF:
np.float16,

View File

@ -22,12 +22,15 @@ import collections as py_collections
import itertools
import weakref
import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -64,6 +67,9 @@ WHITELIST_COLLECTIONS = [
]
_EAGER_CONST_THRESHOLD = 128
class UnknownArgument(object):
"""Signifies an argument which is not currently handled."""
pass
@ -569,6 +575,13 @@ class FuncGraph(ops.Graph):
if isinstance(tensor, ops.EagerTensor):
if name is None:
name = str(ops.uid())
# Small EagerTensors are captured with Const ops
if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
return self.capture_eager_tensor(tensor, name)
# Large EagerTensors and resources are captured with Placeholder ops
return self._capture_helper(tensor, name)
if tensor.graph is not self:
if name is None:
@ -643,6 +656,22 @@ class FuncGraph(ops.Graph):
tape.record_operation("captured_value", [placeholder], [variable],
lambda x: [x])
def capture_eager_tensor(self, tensor, name):
capture = self._captures.get(ops.tensor_id(tensor))
if capture is None:
# We clear all control dependencies and place the Const op on the same
# device as the source tensor. The device placement may be relaxed at
# a later date.
with ops.control_dependencies(None), self.device(tensor.device):
graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype,
shape=tensor.shape, name=name)
self.add_capture(tensor, graph_const)
else:
graph_const = capture[1]
tape.record_operation("captured_value", [graph_const], [tensor],
lambda x: [x])
return graph_const
@property
def external_captures(self):
"""External tensors captured by this function."""