tf.functions capture small EagerTensors as constant ops rather than as placeholders fed from the outer context. This enables more comprehensive shape inference during function construction.
PiperOrigin-RevId: 261162099
parent d8563bdb14
commit 0ccb3675d6
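A minimal sketch of the resulting behavior, written against the public TF 2.x API rather than the internal test helpers used below (tf.function, tf.ones, and get_concrete_function are assumed to be available); the 128-element cutoff mirrors the _EAGER_CONST_THRESHOLD constant introduced in this change:

import tensorflow as tf

small = tf.ones(shape=(4,))    # at or under the 128-element threshold
large = tf.ones(shape=(256,))  # over the threshold

for captured, expected_op in [(small, 'Const'), (large, 'Placeholder')]:

  @tf.function
  def f():
    return captured + 1  # pylint: disable=cell-var-from-loop

  # internal_captures holds the in-graph stand-ins for captured tensors.
  graph = f.get_concrete_function().graph
  assert graph.internal_captures[0].op.type == expected_op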
@@ -42,6 +42,7 @@ from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import function as tf_function
 from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
@@ -2125,6 +2126,32 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     # is added.
     self.assertLen(graph._functions, 6)
 
+  @parameterized.named_parameters(
+      dict(testcase_name='Defun',
+           function_decorator=function.defun),
+      dict(testcase_name='DefFunction',
+           function_decorator=def_function.function))
+  def testEagerCaptures(self, function_decorator):
+    with context.eager_mode():
+      large_tensor = array_ops.ones(shape=(256,))
+      self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
+
+      small_tensor = array_ops.ones(shape=(4,))
+      self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
+
+      v = resource_variable_ops.ResourceVariable(0.0)
+
+      for captured, op_type in [(large_tensor, 'Placeholder'),
+                                (small_tensor, 'Const'), (v, 'Placeholder')]:
+        @function_decorator
+        def test_fn():
+          return captured + 1  # pylint: disable=cell-var-from-loop
+
+        g = test_fn.get_concrete_function().graph
+        internal_captures = g.internal_captures
+        self.assertLen(internal_captures, 1)
+        self.assertEqual(internal_captures[0].op.type, op_type)
+
   def testRegisterFunctionWithInputSignature(self):
     def matmul(x, y):
       return math_ops.matmul(x, y)
@@ -566,6 +566,10 @@ for pdt in [
   _NP_TO_TF[pdt] = next(
       _NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
 
+
+TF_VALUE_DTYPES = set(_NP_TO_TF.values())
+
+
 _TF_TO_NP = {
     types_pb2.DT_HALF:
         np.float16,
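A side note on the TF_VALUE_DTYPES set added above: since _NP_TO_TF maps NumPy scalar types to TF dtypes, set(_NP_TO_TF.values()) is exactly the set of dtypes whose values can be pulled out of an EagerTensor with .numpy() and re-embedded in a graph as a Const. Handle dtypes such as resource never appear in it, so they always take the placeholder path. An illustrative check, assuming the internal module path used in this diff:

from tensorflow.python.framework import dtypes

assert dtypes.float32 in dtypes.TF_VALUE_DTYPES       # has a NumPy equivalent
assert dtypes.resource not in dtypes.TF_VALUE_DTYPES  # handles are excluded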
@@ -22,12 +22,15 @@ import collections as py_collections
 import itertools
 import weakref
 
+import numpy as np
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
 from tensorflow.python.eager.graph_only_ops import graph_placeholder
 from tensorflow.python.framework import composite_tensor
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -64,6 +67,9 @@ WHITELIST_COLLECTIONS = [
 ]
 
 
+_EAGER_CONST_THRESHOLD = 128
+
+
 class UnknownArgument(object):
   """Signifies an argument which is not currently handled."""
   pass
@@ -569,6 +575,13 @@ class FuncGraph(ops.Graph):
     if isinstance(tensor, ops.EagerTensor):
       if name is None:
         name = str(ops.uid())
+
+      # Small EagerTensors are captured with Const ops
+      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
+          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
+        return self.capture_eager_tensor(tensor, name)
+
+      # Large EagerTensors and resources are captured with Placeholder ops
       return self._capture_helper(tensor, name)
     if tensor.graph is not self:
       if name is None:
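One detail of the size check above: np.prod of an empty shape is 1.0, so scalar EagerTensors count as one element and qualify for Const capture. A small numeric illustration (the 128 literal mirrors _EAGER_CONST_THRESHOLD):

import numpy as np

assert np.prod(()) == 1.0     # scalar shape: one element, const-captured
assert np.prod((4,)) <= 128   # small vector: const-captured
assert np.prod((256,)) > 128  # large vector: still fed via Placeholder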
@@ -643,6 +656,22 @@ class FuncGraph(ops.Graph):
     tape.record_operation("captured_value", [placeholder], [variable],
                           lambda x: [x])
 
+  def capture_eager_tensor(self, tensor, name):
+    capture = self._captures.get(ops.tensor_id(tensor))
+    if capture is None:
+      # We clear all control dependencies and place the Const op on the same
+      # device as the source tensor. The device placement may be relaxed at
+      # a later date.
+      with ops.control_dependencies(None), self.device(tensor.device):
+        graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype,
+                                           shape=tensor.shape, name=name)
+      self.add_capture(tensor, graph_const)
+    else:
+      graph_const = capture[1]
+    tape.record_operation("captured_value", [graph_const], [tensor],
+                          lambda x: [x])
+    return graph_const
+
   @property
   def external_captures(self):
     """External tensors captured by this function."""
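Because capture_eager_tensor first consults self._captures, capture is deduplicated per eager tensor: referencing the same EagerTensor several times in one function body yields a single Const, and the tape.record_operation call keeps gradients flowing from that Const back to the original tensor. A hedged sketch of the dedup behavior through the public API (assertion patterned on the test added in this change):

import tensorflow as tf

t = tf.constant([1.0, 2.0])  # small enough to be const-captured

@tf.function
def f():
  return t * t  # the same tensor is referenced twice

g = f.get_concrete_function().graph
assert len(g.internal_captures) == 1             # one shared capture
assert g.internal_captures[0].op.type == 'Const'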