tf functions capture small EagerTensors as constant ops rather than placeholders which are fed from the outer context. This enables more comprehensive shape inference during function construction.
PiperOrigin-RevId: 261162099
This commit is contained in:
parent
d8563bdb14
commit
0ccb3675d6
@ -42,6 +42,7 @@ from tensorflow.python.framework import config
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import func_graph
|
||||||
from tensorflow.python.framework import function as tf_function
|
from tensorflow.python.framework import function as tf_function
|
||||||
from tensorflow.python.framework import indexed_slices
|
from tensorflow.python.framework import indexed_slices
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -2125,6 +2126,32 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
# is added.
|
# is added.
|
||||||
self.assertLen(graph._functions, 6)
|
self.assertLen(graph._functions, 6)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(testcase_name='Defun',
|
||||||
|
function_decorator=function.defun),
|
||||||
|
dict(testcase_name='DefFunction',
|
||||||
|
function_decorator=def_function.function))
|
||||||
|
def testEagerCaptures(self, function_decorator):
|
||||||
|
with context.eager_mode():
|
||||||
|
large_tensor = array_ops.ones(shape=(256,))
|
||||||
|
self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
|
||||||
|
|
||||||
|
small_tensor = array_ops.ones(shape=(4,))
|
||||||
|
self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
|
||||||
|
|
||||||
|
v = resource_variable_ops.ResourceVariable(0.0)
|
||||||
|
|
||||||
|
for captured, op_type in [(large_tensor, 'Placeholder'),
|
||||||
|
(small_tensor, 'Const'), (v, 'Placeholder')]:
|
||||||
|
@function_decorator
|
||||||
|
def test_fn():
|
||||||
|
return captured + 1 # pylint: disable=cell-var-from-loop
|
||||||
|
|
||||||
|
g = test_fn.get_concrete_function().graph
|
||||||
|
internal_captures = g.internal_captures
|
||||||
|
self.assertLen(internal_captures, 1)
|
||||||
|
self.assertEqual(internal_captures[0].op.type, op_type)
|
||||||
|
|
||||||
def testRegisterFunctionWithInputSignature(self):
|
def testRegisterFunctionWithInputSignature(self):
|
||||||
def matmul(x, y):
|
def matmul(x, y):
|
||||||
return math_ops.matmul(x, y)
|
return math_ops.matmul(x, y)
|
||||||
|
@ -566,6 +566,10 @@ for pdt in [
|
|||||||
_NP_TO_TF[pdt] = next(
|
_NP_TO_TF[pdt] = next(
|
||||||
_NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
|
_NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
|
||||||
|
|
||||||
|
|
||||||
|
TF_VALUE_DTYPES = set(_NP_TO_TF.values())
|
||||||
|
|
||||||
|
|
||||||
_TF_TO_NP = {
|
_TF_TO_NP = {
|
||||||
types_pb2.DT_HALF:
|
types_pb2.DT_HALF:
|
||||||
np.float16,
|
np.float16,
|
||||||
|
@ -22,12 +22,15 @@ import collections as py_collections
|
|||||||
import itertools
|
import itertools
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import execute
|
from tensorflow.python.eager import execute
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -64,6 +67,9 @@ WHITELIST_COLLECTIONS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
_EAGER_CONST_THRESHOLD = 128
|
||||||
|
|
||||||
|
|
||||||
class UnknownArgument(object):
|
class UnknownArgument(object):
|
||||||
"""Signifies an argument which is not currently handled."""
|
"""Signifies an argument which is not currently handled."""
|
||||||
pass
|
pass
|
||||||
@ -569,6 +575,13 @@ class FuncGraph(ops.Graph):
|
|||||||
if isinstance(tensor, ops.EagerTensor):
|
if isinstance(tensor, ops.EagerTensor):
|
||||||
if name is None:
|
if name is None:
|
||||||
name = str(ops.uid())
|
name = str(ops.uid())
|
||||||
|
|
||||||
|
# Small EagerTensors are captured with Const ops
|
||||||
|
if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
|
||||||
|
np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
|
||||||
|
return self.capture_eager_tensor(tensor, name)
|
||||||
|
|
||||||
|
# Large EagerTensors and resources are captured with Placeholder ops
|
||||||
return self._capture_helper(tensor, name)
|
return self._capture_helper(tensor, name)
|
||||||
if tensor.graph is not self:
|
if tensor.graph is not self:
|
||||||
if name is None:
|
if name is None:
|
||||||
@ -643,6 +656,22 @@ class FuncGraph(ops.Graph):
|
|||||||
tape.record_operation("captured_value", [placeholder], [variable],
|
tape.record_operation("captured_value", [placeholder], [variable],
|
||||||
lambda x: [x])
|
lambda x: [x])
|
||||||
|
|
||||||
|
def capture_eager_tensor(self, tensor, name):
|
||||||
|
capture = self._captures.get(ops.tensor_id(tensor))
|
||||||
|
if capture is None:
|
||||||
|
# We clear all control dependencies and place the Const op on the same
|
||||||
|
# device as the source tensor. The device placement may be relaxed at
|
||||||
|
# a later date.
|
||||||
|
with ops.control_dependencies(None), self.device(tensor.device):
|
||||||
|
graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype,
|
||||||
|
shape=tensor.shape, name=name)
|
||||||
|
self.add_capture(tensor, graph_const)
|
||||||
|
else:
|
||||||
|
graph_const = capture[1]
|
||||||
|
tape.record_operation("captured_value", [graph_const], [tensor],
|
||||||
|
lambda x: [x])
|
||||||
|
return graph_const
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def external_captures(self):
|
def external_captures(self):
|
||||||
"""External tensors captured by this function."""
|
"""External tensors captured by this function."""
|
||||||
|
Loading…
Reference in New Issue
Block a user