Turn on shape relaxation retracing in @tf.function.

PiperOrigin-RevId: 235054774
Eugene Brevdo 2019-02-21 13:14:20 -08:00 committed by TensorFlower Gardener
parent 1d0ec3ec5d
commit d8033ab10d
6 changed files with 528 additions and 152 deletions
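In user terms, this change makes defun/@tf.function stop tracing a fresh graph for every new input shape: once a call context has already missed the cache (and no input_signature is given and XLA is not in use), later misses retrace with relaxed, partially unknown shapes that subsequent calls can reuse. A minimal sketch of the resulting behavior, assuming a build that includes this commit, eager execution, and that `tf.function` stands in for the def_function.function touched below:

```python
import tensorflow as tf

traces = []

@tf.function
def f(x):
  traces.append(x.shape)  # records the shape each new trace sees
  return x + 1

f(tf.constant([1.0]))                 # 1st trace, exact shape (1,)
f(tf.constant([1.0, 2.0]))            # 2nd trace, still an exact shape (2,)
f(tf.constant([1.0, 2.0, 3.0]))       # 3rd trace, relaxed to shape (None,)
f(tf.constant([1.0, 2.0, 3.0, 4.0]))  # no retrace: the relaxed graph is reused
print(len(traces))                    # expected: 3
```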

View File

@ -578,9 +578,11 @@ class Function(object):
concrete_functions = []
# pylint: disable=protected-access
if self._stateful_fn:
concrete_functions.extend(self._stateful_fn._function_cache.values())
concrete_functions.extend(
self._stateful_fn._function_cache.all_values())
if self._stateless_fn:
concrete_functions.extend(self._stateless_fn._function_cache.values())
concrete_functions.extend(
self._stateless_fn._function_cache.all_values())
# pylint: enable=protected-access
deduplicated_concrete_functions = list()
seen_signatures = list()

View File

@ -43,6 +43,7 @@ from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
@ -60,10 +61,94 @@ from tensorflow.python.util import tf_inspect
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
CacheKey = collections.namedtuple("CacheKey", [
"input_signature", "parent_graph", "device_functions", "colocation_stack",
"uses_xla"
])
class CacheKey(
collections.namedtuple("CacheKey", [
"input_signature", "parent_graph", "device_functions",
"colocation_stack", "uses_xla"])):
def replace(self, *args, **kwargs):
return self._replace(*args, **kwargs)
def _flat_shape_list(*params):
"""Return a flat list of TensorShapes, one for each tensor[spec] in `*params`.
Args:
*params: Set of nested entries containing Tensors, TensorSpec, and
non-tensors.
Returns:
A list of entries containing either `None` or `TensorShape`.
"""
return [tensor_shape.TensorShape(x.shape)
if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else None
for x in nest.flatten(params)]
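A hedged usage sketch of the contract above (the helper is private to this module and may move in later versions): Tensor and TensorSpec leaves map to a `TensorShape`, everything else maps to `None`, in `nest.flatten` order.

```python
import tensorflow as tf
from tensorflow.python.eager import function as eager_function

# The Tensor leaf yields its TensorShape; the non-tensor dict value yields None.
print(eager_function._flat_shape_list(tf.constant([[1., 2.]]), {"k": 3}))
# expected: [TensorShape([1, 2]), None]
```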
def _compatible_shapes(flat_x, flat_y):
"""Check if lists of TensorShapes contain compatible shapes.
Args:
flat_x: List of TensorShape or None.
flat_y: List of TensorShape or None.
Returns:
A python bool.
Raises:
RuntimeError: if `len(flat_x) != len(flat_y)`.
RuntimeError: if `flat_x[i] is None != flat_y[i] is None` for any `i`.
"""
if len(flat_x) != len(flat_y):
raise RuntimeError("Expected shape lists of identical lengths, but saw: "
"%s and %s" % (flat_x, flat_y))
def is_compatible(x, y):
"""Internal help function.
Args:
x: TensorShape or None.
y: TensorShape or None.
Returns:
Python bool.
Raises:
RuntimeError: If `x is None != y is None`.
"""
# If both x and y are None, there is no shape to compare. Otherwise check
# if they are compatible with each other. Either way, both input signatures
# must have Tensors in the same entries. If not, raise a RuntimeError.
if (x is None) != (y is None):
raise RuntimeError(
"Expected signature type matches between flattened input shapes "
"%s and %s; but saw that (%s is None) != (%s is None)"
% (flat_x, flat_y, x, y))
return x is None or x.is_compatible_with(y)
return all(is_compatible(x, y) for x, y in zip(flat_x, flat_y))
def _common_shape(x, y):
"""Find a `TensorShape` that is compatible with both `x` and `y`."""
if (x is None) != (y is None):
raise RuntimeError(
"Cannot find a common shape when LHS shape is None but RHS shape "
"is not (or vice versa): %s vs. %s" % (x, y))
if x is None:
return None # The associated input was not a Tensor, no shape generated.
if not isinstance(x, tensor_shape.TensorShape):
raise TypeError("Expected x to be a TensorShape but saw %s" % (x,))
if not isinstance(y, tensor_shape.TensorShape):
raise TypeError("Expected y to be a TensorShape but saw %s" % (y,))
if x.rank != y.rank or x.rank is None:
return tensor_shape.TensorShape(None)
dims = []
for dim_x, dim_y in zip(x.dims, y.dims):
if dim_x != dim_y or tensor_shape.dimension_value(dim_x) is None:
dims.append(None)
else:
dims.append(tensor_shape.dimension_value(dim_x))
return tensor_shape.TensorShape(dims)
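A hedged illustration of the relaxation rule these two helpers implement (private functions from the module modified above): equal dimensions are kept, differing or unknown dimensions become `None`, and differing ranks collapse to a fully unknown shape; the relaxed shape is then compatible with both inputs.

```python
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import tensor_shape

a = tensor_shape.TensorShape([2, 3])
b = tensor_shape.TensorShape([2, 5])
relaxed = eager_function._common_shape(a, b)
print(relaxed)                                            # (2, None)
print(eager_function._compatible_shapes([a], [b]))        # False
print(eager_function._compatible_shapes([relaxed], [b]))  # True
```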
def is_same_structure(structure1,
@ -1073,6 +1158,34 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
return inputs
class FunctionCache(object):
"""A lightweight container for cached functions.
"""
def __init__(self):
# The set of call contexts for which a cache miss has occurred; entries are
# CacheKey with input_signature `None` (i.e. a "call context key").
self.missed = set()
# The primary cache, mapping a fully shaped CacheKey to a function.
self.primary = collections.OrderedDict()
# A map from a CacheKey generated without shape info to a flat list of
# relaxed argument shapes (one for each argument). Arguments that are not
# Tensors get `None` for the corresponding relaxed shape.
self.arg_relaxed_shapes = collections.OrderedDict()
# The secondary cache, mapping a CacheKey generated without shape info to a
# function.
self.arg_relaxed = collections.OrderedDict()
# All OrderedDicts require manual garbage collection.
self._garbage_collectors = [
_FunctionGarbageCollector(self.primary),
_FunctionGarbageCollector(self.arg_relaxed),
_FunctionGarbageCollector(self.arg_relaxed_shapes)]
def all_values(self):
"""A set of all `ConcreteFunction` instances held by this cache."""
return set(self.primary.values()) | set(self.arg_relaxed.values())
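A hedged sketch of how the two caches are meant to be consulted; the real orchestration lives in `Function._maybe_define_function` further down in this diff, and the helper name here is hypothetical.

```python
from tensorflow.python.eager import function as eager_function

def lookup_sketch(cache, cache_key, rank_only_key, arg_shapes):
  """Returns a cached ConcreteFunction, or None if the caller must trace."""
  graph_function = cache.primary.get(cache_key)      # exact-shape hit
  if graph_function is not None:
    return graph_function
  relaxed_shapes = cache.arg_relaxed_shapes.get(rank_only_key)
  relaxed_fn = cache.arg_relaxed.get(rank_only_key)  # shape-relaxed hit
  if relaxed_fn is not None and eager_function._compatible_shapes(
      relaxed_shapes, arg_shapes):
    return relaxed_fn
  return None  # miss: trace exactly (first miss) or with relaxed shapes
```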
class Function(object):
"""Wrapper class for the graph functions defined for a Python function.
@ -1126,8 +1239,7 @@ class Function(object):
self._name = name
self._autograph = autograph
self._autograph_options = autograph_options
self._function_cache = collections.OrderedDict()
self._garbage_collector = _FunctionGarbageCollector(self._function_cache)
self._function_cache = FunctionCache()
self._function_attributes = attributes or {}
self._capture_by_value = capture_by_value
@ -1284,13 +1396,15 @@ class Function(object):
# Return the cached `Function` for the instance
return self._descriptor_cache[instance]
def _cache_key(self, args, kwargs):
def _cache_key(self, args, kwargs, include_tensor_ranks_only=False):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwargs) if kwargs else args
input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(
inputs, include_tensor_ranks_only)
else:
del args, kwargs
assert not include_tensor_ranks_only
input_signature = self._flat_input_signature
ctx = context.context()
@ -1336,6 +1450,46 @@ class Function(object):
return CacheKey(input_signature, parent_graph, device_functions,
colocation_stack, uses_xla)
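A hedged illustration of the two encoding modes that `include_tensor_ranks_only` selects, using the private pywrap binding exactly as `_cache_key` calls it above; the encoded string format is an implementation detail and may differ across versions.

```python
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

x = tf.ones((2, 3))
full_key = pywrap_tensorflow.TFE_Py_EncodeArg((x,), False)  # encodes dims 2 and 3
rank_key = pywrap_tensorflow.TFE_Py_EncodeArg((x,), True)   # encodes only rank 2
print(full_key == rank_key)  # False: dimension sizes are dropped in the second
```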
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
if self._input_signature is None:
arglen = len(args)
else:
arglen = len(self._input_signature)
base_arg_names = self._function_spec.arg_names[:arglen]
num_missing_args = arglen - len(self._function_spec.arg_names)
missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
# Produce a list of missing args of the form ["arg_0", "arg_1", ...],
# where arg is based on the self._function_spec.vararg_name.
missing_arg_names = [
"%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
]
arg_names = base_arg_names + missing_arg_names
graph_function = ConcreteFunction(
func_graph_module.func_graph_from_py_func(
self._name,
self._python_function,
args,
kwargs,
self._input_signature,
autograph=self._autograph,
autograph_options=self._autograph_options,
arg_names=arg_names,
override_flat_arg_shapes=override_flat_arg_shapes,
capture_by_value=self._capture_by_value),
self._function_attributes)
# pylint: disable=protected-access
# Tell the ConcreteFunction to clean up its graph once it goes out of
# scope. ConcreteFunction does not do this in its constructor since it
# gets used in some places (like Keras) where the FuncGraph lives
# longer than the ConcreteFunction.
graph_function._garbage_collector = ConcreteFunctionGarbageCollector(
graph_function.graph)
# pylint: enable=protected-access
return graph_function
def _maybe_define_function(self, args, kwargs):
"""Gets a function for these inputs, defining it if necessary.
@ -1353,57 +1507,76 @@ class Function(object):
Raises:
ValueError: If inputs are incompatible with the input signature.
TypeError: If the function inputs include non-hashable objects
RuntimeError: If there's an internal bug (inconsistency) in handling
shape relaxation retracing.
"""
if self._input_signature is None or args is not None or kwargs is not None:
args, kwargs = self._function_spec.canonicalize_function_inputs(
*args, **kwargs)
cache_key = self._cache_key(args, kwargs)
with self._lock:
try:
graph_function = self._function_cache.get(cache_key, None)
except TypeError as e:
raise TypeError(
"Arguments supplied to `defun`-generated functions must be"
" hashable. Original error: %s" % e)
if graph_function is None:
logging.vlog(1,
"Creating new FuncGraph for Python function %r (key: %r)",
self._python_function, cache_key)
if self._input_signature is None:
arglen = len(args)
else:
arglen = len(self._input_signature)
base_arg_names = self._function_spec.arg_names[:arglen]
num_missing_args = arglen - len(self._function_spec.arg_names)
missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
# Produce a list of missing args of the form ["arg_0", "arg_1", ...],
# where arg is based on the self._function_spec.vararg_name.
missing_arg_names = [
"%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
]
arg_names = base_arg_names + missing_arg_names
graph_function = ConcreteFunction(
func_graph_module.func_graph_from_py_func(
self._name,
self._python_function,
args,
kwargs,
self._input_signature,
autograph=self._autograph,
autograph_options=self._autograph_options,
arg_names=arg_names,
capture_by_value=self._capture_by_value),
self._function_attributes)
# pylint: disable=protected-access
# Tell the ConcreteFunction to clean up its graph once it goes out of
# scope. ConcreteFunction does not do this in its constructor since it
# gets used in some places (like Keras) where the FuncGraph lives
# longer than the ConcreteFunction.
graph_function._garbage_collector = ConcreteFunctionGarbageCollector(
graph_function.graph)
# pylint: enable=protected-access
self._function_cache[cache_key] = graph_function
try:
hash(cache_key)
except TypeError as e:
raise TypeError(
"Arguments supplied to `defun`-generated functions must be"
" hashable. Original error: %s" % e)
with self._lock:
graph_function = self._function_cache.primary.get(cache_key, None)
if graph_function is not None:
return graph_function, args, kwargs
logging.vlog(1,
"Creating new FuncGraph for Python function %r (key: %r)",
self._python_function, cache_key)
logging.vlog(2,
"Python function signature [args: %s] [kwargs: %s]",
str(args),
str(kwargs))
call_context_key = cache_key.replace(input_signature=None)
# If there's a provided input signature, or XLA is being used, or
# there's no cache miss for this calling context so far, go ahead and
# build the function and bypass shape relaxation retracing.
if (self._input_signature is not None
or cache_key.uses_xla
or call_context_key not in self._function_cache.missed):
self._function_cache.missed.add(call_context_key)
graph_function = self._create_graph_function(args, kwargs)
self._function_cache.primary[cache_key] = graph_function
return graph_function, args, kwargs
rank_only_cache_key = self._cache_key(
args, kwargs, include_tensor_ranks_only=True)
arg_shapes = _flat_shape_list(args, kwargs)
relaxed_arg_shapes = self._function_cache.arg_relaxed_shapes.get(
rank_only_cache_key, None)
relaxed_arg_function = self._function_cache.arg_relaxed.get(
rank_only_cache_key, None)
if (relaxed_arg_function is not None
and _compatible_shapes(relaxed_arg_shapes, arg_shapes)):
return relaxed_arg_function, args, kwargs
if relaxed_arg_shapes is None:
relaxed_arg_shapes = arg_shapes
else:
if len(arg_shapes) != len(relaxed_arg_shapes):
raise RuntimeError("Expected arg_shapes len to match "
"relaxed_arg_shapes len: %d vs. %d"
% (len(arg_shapes), len(relaxed_arg_shapes)))
relaxed_arg_shapes = [
_common_shape(x, y) for (x, y) in zip(
arg_shapes, relaxed_arg_shapes)]
self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = (
relaxed_arg_shapes)
graph_function = self._create_graph_function(
args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function
return graph_function, args, kwargs
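A hedged, runnable trace of the policy above using the internal defun API; it mirrors the new testInputShapeFunctionRelaxation test added in this commit. The `_function_cache` fields are private and version-specific, and eager execution is assumed.

```python
import tensorflow as tf
from tensorflow.python.eager import function as eager_function

tf.compat.v1.enable_eager_execution()  # no-op where eager is already the default

traced = eager_function.defun(lambda x: x + 1)
traced(tf.constant([1.0]))                 # 1st miss for the context: exact trace
traced(tf.constant([1.0, 2.0]))            # relaxed path, shapes still exact
traced(tf.constant([1.0, 2.0, 3.0]))       # widened to (None,), relaxed retrace
traced(tf.constant([1.0, 2.0, 3.0, 4.0]))  # (4,) fits (None,): reuse
cache = traced._function_cache
print(len(cache.primary), len(cache.arg_relaxed))  # expected: 1 1
```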

View File

@ -66,6 +66,13 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
def total_function_cache(defined):
# pylint: disable=protected-access
return (set(defined._function_cache.primary)
| set(defined._function_cache.arg_relaxed))
# pylint: enable=protected-access
class MiniModel(keras_training.Model):
"""Minimal model for mnist.
@ -99,6 +106,94 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])
def testVariable(self):
v1 = variables.Variable(1.0)
add = def_function.function(lambda x, v: x + v1 + v)
v2 = variables.Variable(1.0)
x = constant_op.constant(1.0)
r = add(x, v2)
self.assertEqual(3.0, self.evaluate(r))
def testInputShapeFunctionRelaxation(self):
unknown_dim = [False]
@function.defun
def func(a):
if a._shape_tuple()[0] is None:
unknown_dim[0] = True
return a + 1
func(constant_op.constant([]))
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 1)
func(constant_op.constant([1.0]))
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
func(constant_op.constant([1.0, 2.0]))
self.assertTrue(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
def testNestedInputShapeFunctionRelaxation(self):
unknown_dim = [False]
@function.defun
def func(a_, b_=None):
del a_ # Only used to check which cache is used.
self.assertEqual(b_[0]._shape_tuple(), ())
if b_[1]._shape_tuple()[0] is None:
unknown_dim[0] = True
return b_[0] + 1
a = 'hi'
b0 = constant_op.constant(1.0)
func(a, b_=[b0, constant_op.constant([])])
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 1)
func(a, b_=[b0, constant_op.constant([1.0])])
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
self.assertTrue(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
unknown_dim[0] = False
# Now do the same except with a new value for the non-tensor argument `a`;
# this should change the cache key.
a = 'bye'
func(a, b_=[b0, constant_op.constant([])])
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 3)
# Since a cache miss was already recorded for this call context (the context
# key ignores the input signature), shape relaxation starts immediately.
func(a, b_=[b0, constant_op.constant([1.0])])
self.assertTrue(unknown_dim[0])
self.assertLen(total_function_cache(func), 3)
def testFunctionRelaxationLosesInnerDimWithKerasLayer(self):
layer = keras.layers.Dense(1)
fn = def_function.function()(layer)
with self.captureWritesToStream(sys.stderr) as printed:
fn(array_ops.ones((3, 2)))
self.assertNotIn('ValueError', printed.contents())
with self.captureWritesToStream(sys.stderr) as printed:
# Use batch size 2 to trigger a second cache miss on the shape.
fn(array_ops.ones((2, 2)))
self.assertNotIn('ValueError', printed.contents())
# Shape relaxation passes TensorShape([None, None]), so the layer's matmul
# fails at runtime due to incompatible dims. What would have been a graph
# build time error (the layer would complain about the inner dim being 4)
# instead surfaces as an InvalidArgumentError when the function runs.
with self.captureWritesToStream(sys.stderr) as printed:
with self.assertRaisesRegexp(errors.InvalidArgumentError, r'MatMul'):
fn(array_ops.ones((3, 4)))
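Where this runtime failure mode is undesirable, an explicit input_signature pins the traced shapes and bypasses relaxed retracing entirely (the `self._input_signature is not None` branch in `_maybe_define_function` above). A hedged sketch, assuming the 2.x-style `tf.function` / `tf.TensorSpec` exports:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32)])
def row_sums(x):
  return tf.reduce_sum(x, axis=1)

row_sums(tf.ones((3, 2)))  # single trace; the batch dimension is already None
row_sums(tf.ones((5, 2)))  # reuses it; shape relaxation never comes into play
```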
def testWastedAdd(self):
@def_function.function()
@ -382,13 +477,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
x = random_ops.random_uniform([2, 2]).numpy()
defined = function.defun(f)
defined(x)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
x = random_ops.random_uniform([2, 2]).numpy()
defined(x)
# A NumPy array with different values but the same shape and dtype
# shouldn't trigger another function definition.
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
# Test that the numpy array is properly an argument to the graph function.
self.assertEqual(1., defined(numpy.ones([])).numpy())
@ -1110,7 +1205,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined = function.defun(multi_device_fn)
outputs = self.evaluate(defined())
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
@ -1118,7 +1213,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with ops.device('/cpu:3'):
outputs = self.evaluate(defined())
# All function definitions are agnostic to call site devices.
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
@ -1126,7 +1221,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with ops.device('/cpu:0'):
outputs = self.evaluate(defined())
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
@ -1211,10 +1306,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined = function.defun(func)
defined(Foo())
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
defined(Foo())
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorDtypeCollision(self):
@ -1224,11 +1319,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined = function.defun(func)
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
defined(t)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
defined(t)
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorShapeCollision(self):
@ -1238,11 +1333,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined = function.defun(func)
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
defined(t)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([1.0], dtype=dtypes.complex64)
defined(t)
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorShapeDtypeCollision(self):
@ -1252,11 +1347,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined = function.defun(func)
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
defined(t)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([1.0], dtype=dtypes.complex128)
defined(t)
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorUnknownShapesCollision(self):
@ -1266,21 +1361,34 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with context.graph_mode(), self.cached_session():
defined = function.defun(func)
p = array_ops.placeholder(dtype=dtypes.float32, shape=None)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
defined(p)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[None])
p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
defined(p)
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
defined(p)
self.assertLen(defined._function_cache, 3)
# Gradual shape relaxation is performed: the common shape of [1] and [2]
# contains an unknown dimension.
self.assertLen(total_function_cache(defined), 2)
t = constant_op.constant(1.0, dtype=dtypes.float32)
# pylint: disable=protected-access
self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
relaxed_shapes = (
list(defined._function_cache.arg_relaxed_shapes.values())[0])
self.assertEqual(len(relaxed_shapes), 1)
relaxed_shape = relaxed_shapes[0]
# pylint: enable=protected-access
self.assertEqual(relaxed_shape.rank, 1)
self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)
t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
defined(t)
self.assertLen(defined._function_cache, 4)
# Shape (3,) matches the relaxed shape TensorShape([None])
self.assertLen(total_function_cache(defined), 2)
def testPythonFunctionWithDefaultArgs(self):
@ -1295,7 +1403,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def cache_keys():
"""Sanitizes cache keys of non-input metadata."""
return tuple(key[0] for key in defined._function_cache)
return tuple(key[0] for key in total_function_cache(defined))
# `True` corresponds to the fact that we're executing eagerly
self.assertIn(('URRRu', (0, 1, 20)), cache_keys())
@ -1305,19 +1413,19 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
# This matches the previous call.
defined(foo=1)
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
defined(1, 2, 3)
self.assertLen(defined._function_cache, 3)
self.assertLen(total_function_cache(defined), 3)
self.assertIn(('URRRu', (1, 2, 3)), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
self.assertLen(defined._function_cache, 3)
self.assertLen(total_function_cache(defined), 3)
# This matches the previous call.
defined(1, baz=3, bar=2)
self.assertLen(defined._function_cache, 3)
self.assertLen(total_function_cache(defined), 3)
def testFunctoolsPartialUnwrappedCorrectly(self):
@ -1343,12 +1451,12 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined = function.defun(foo, input_signature=signature)
a = array_ops.ones([2])
self.assertAllEqual(a, defined(a))
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(a, defined.get_concrete_function()(a))
self.assertAllEqual(a, defined.get_concrete_function(a)(a))
self.assertAllEqual(a, defined.get_concrete_function(
tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
def bar(a):
self.assertEqual(a._shape_tuple(), (2, None))
@ -1358,13 +1466,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined = function.defun(bar, input_signature=signature)
a = array_ops.ones([2, 1])
out = defined(a)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(out, a)
# Changing the second dimension shouldn't create a new function.
b = array_ops.ones([2, 3])
out = defined(b)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(out, b)
def testInputSignatureWithCompatibleInputs(self):
@ -1405,7 +1513,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
b = array_ops.ones([1])
expected = expected_foo([a, a], b)
out = foo([a, a], b)
self.assertLen(foo._function_cache, 1)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], a)
@ -1417,7 +1525,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
c = array_ops.ones([1])
expected = expected_foo([a, b], c)
out = foo([a, b], c)
self.assertLen(foo._function_cache, 1)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
@ -1428,7 +1536,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
b = b.numpy().tolist()
c = c.numpy().tolist()
out = foo([a, b], c)
self.assertLen(foo._function_cache, 1)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
@ -1597,22 +1705,22 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
integer = constant_op.constant(2, dtypes.int64)
out1, out2 = foo(flt, integer)
self.assertLen(foo._function_cache, 1)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt=flt, integer=integer)
self.assertLen(foo._function_cache, 1)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(integer=integer, flt=flt)
self.assertLen(foo._function_cache, 1)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt, integer=integer)
self.assertLen(foo._function_cache, 1)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
@ -1642,27 +1750,27 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
a = constant_op.constant(2.0)
b = constant_op.constant([1.0, 2.0])
one = defined(a, b)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
two = defined(a=a, b=b)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
three = defined(b=b, a=a)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
four = defined(a, b=b)
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
# The next call corresponds to a new input signature, hence
# we expect another function to be defined.
five = defined(b, a)
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
six = defined(a=b, b=a)
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
seven = defined(b=a, a=b)
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
self.assertAllEqual(one, [1.0, 2.0])
self.assertAllEqual(two, [1.0, 2.0])
@ -2049,18 +2157,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with ops.Graph().as_default():
x = constant_op.constant(11)
maybe_add(x, True)
self.assertLen(maybe_add._function_cache, 1)
self.assertLen(add._function_cache, 1)
self.assertLen(total_function_cache(maybe_add), 1)
self.assertLen(total_function_cache(add), 1)
maybe_add(x, False)
self.assertLen(maybe_add._function_cache, 2)
self.assertLen(add._function_cache, 1)
self.assertLen(total_function_cache(maybe_add), 2)
self.assertLen(total_function_cache(add), 1)
with ops.Graph().as_default():
x = constant_op.constant(11)
maybe_add(x, True)
self.assertLen(maybe_add._function_cache, 3)
self.assertLen(add._function_cache, 2)
self.assertLen(total_function_cache(maybe_add), 3)
self.assertLen(total_function_cache(add), 2)
def testCacheKeyOverlappingShapes(self):
@function.defun
@ -2068,10 +2176,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
return t
defined(array_ops.zeros([12, 1]))
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
defined(array_ops.zeros([1, 21]))
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
def testCacheKeyNestedLists(self):
@function.defun
@ -2082,10 +2190,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
b = constant_op.constant(2.)
c = constant_op.constant(3.)
defined([[a], b, c])
self.assertLen(defined._function_cache, 1)
self.assertLen(total_function_cache(defined), 1)
defined([[a, b], c])
self.assertLen(defined._function_cache, 2)
self.assertLen(total_function_cache(defined), 2)
def testDecoratedMethod(self):
m = DefunnedMiniModel()
@ -2647,6 +2755,7 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
result = func(g1, g2, c1, g3, c2)
self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)
if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(device_count={'CPU': 4}))

View File

@ -231,8 +231,11 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);
// Encodes the object as a tuple that is meant to be used as part of the key
// for the defun function cache.
PyObject* TFE_Py_EncodeArg(PyObject*);
// for the defun function cache. If `include_tensor_ranks_only` is true,
// then the encoding only stores tensor ranks, and the key is
// agnostic to dimension sizes. Otherwise, full tensor shape encodings are
// returned.
PyObject* TFE_Py_EncodeArg(PyObject*, bool include_tensor_ranks_only);
void TFE_Py_EnableInteractivePythonLogging();

View File

@ -2913,7 +2913,9 @@ struct EncodeResult {
}
};
tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
bool include_tensor_ranks_only,
EncodeResult* result) {
if (EagerTensor_CheckExact(arg)) {
TFE_TensorHandle* t = EagerTensor_Handle(arg);
tensorflow::TensorShape tensor_shape;
@ -2922,10 +2924,13 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
absl::StrAppend(&result->str, kDType, t->handle->dtype);
absl::StrAppend(&result->str, kShape);
for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
absl::StrAppend(&result->str, dim_size, kShapeDelim);
if (include_tensor_ranks_only) {
absl::StrAppend(&result->str, tensor_shape.dim_sizes().size());
} else {
for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
absl::StrAppend(&result->str, dim_size, kShapeDelim);
}
}
return tensorflow::Status::OK();
}
@ -2949,6 +2954,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
absl::StrAppend(&result->str, kDType, dtype);
static char _shape_tuple[] = "_shape_tuple";
tensorflow::Safe_PyObjectPtr shape_tuple(
PyObject_CallMethod(arg, _shape_tuple, nullptr));
@ -2969,23 +2975,30 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
shape_tuple.get(), "shape_tuple didn't return a sequence"));
int len = PySequence_Fast_GET_SIZE(shape_seq.get());
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
if (item == Py_None) {
absl::StrAppend(&result->str, kNone);
} else {
absl::StrAppend(&result->str, MakeInt(item));
if (include_tensor_ranks_only) {
absl::StrAppend(&result->str, len);
} else {
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
if (item == Py_None) {
absl::StrAppend(&result->str, kNone);
} else {
absl::StrAppend(&result->str, MakeInt(item));
}
}
}
return tensorflow::Status::OK();
}
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result);
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
bool include_tensor_ranks_only,
EncodeResult* result);
// This function doesn't set the type of sequence before
tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
const char* end_type,
bool include_tensor_ranks_only,
EncodeResult* result) {
tensorflow::Safe_PyObjectPtr arg_seq(
PySequence_Fast(arg, "unable to create seq from list/tuple"));
@ -2997,7 +3010,8 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
if (item == Py_None) {
absl::StrAppend(&result->str, kNone);
} else {
TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result));
TF_RETURN_IF_ERROR(
TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
}
}
absl::StrAppend(&result->str, end_type);
@ -3005,10 +3019,13 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
return tensorflow::Status::OK();
}
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
bool include_tensor_ranks_only,
EncodeResult* result) {
if (tensorflow::swig::IsTensor(arg)) {
absl::StrAppend(&result->str, kTensor);
TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(arg, result));
TF_RETURN_IF_ERROR(
TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
} else if (tensorflow::swig::IsIndexedSlices(arg)) {
absl::StrAppend(&result->str, kIndexedSlices);
tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values"));
@ -3017,7 +3034,8 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
return tensorflow::errors::InvalidArgument(
"IndexedSlices does not have a values attr");
}
TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(values.get(), result));
TF_RETURN_IF_ERROR(
TFE_Py_EncodeTensor(values.get(), include_tensor_ranks_only, result));
tensorflow::Safe_PyObjectPtr indices(
PyObject_GetAttrString(arg, "indices"));
@ -3026,7 +3044,8 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
return tensorflow::errors::InvalidArgument(
"IndexedSlices does not have a indices attr");
}
TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(indices.get(), result));
TF_RETURN_IF_ERROR(
TFE_Py_EncodeTensor(indices.get(), include_tensor_ranks_only, result));
tensorflow::Safe_PyObjectPtr dense_shape(
PyObject_GetAttrString(arg, "dense_shape"));
@ -3036,12 +3055,15 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
"IndexedSlices does not have a dense_shape attr");
}
if (dense_shape.get() != Py_None) {
TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(dense_shape.get(), result));
TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(
dense_shape.get(), include_tensor_ranks_only, result));
}
} else if (PyList_Check(arg)) {
TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, kListEnd, result));
TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
arg, kList, kListEnd, include_tensor_ranks_only, result));
} else if (PyTuple_Check(arg)) {
TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, kTupleEnd, result));
TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
} else if (PyDict_Check(arg)) {
tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
if (PyList_Sort(keys.get()) == -1) {
@ -3053,9 +3075,11 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
for (int i = 0; i < len; i++) {
PyObject* key = PyList_GetItem(keys.get(), i);
TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(key, result));
TF_RETURN_IF_ERROR(
TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
PyObject* value = PyDict_GetItem(arg, key);
TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(value, result));
TF_RETURN_IF_ERROR(
TFE_Py_EncodeArgHelper(value, include_tensor_ranks_only, result));
}
} else {
PyObject* object = PyWeakref_NewRef(arg, nullptr);
@ -3082,10 +3106,15 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
// on known shapes to produce slimmer graphs, and correctness, as some
// high-level APIs require shapes to be fully-known.
//
// `include_tensor_ranks_only` allows caching on arguments excluding shape info,
// so that the slow path using relaxed shapes can rely on a cache key that
// excludes shapes.
//
// TODO(nareshmodi): Add support for sparse tensors.
PyObject* TFE_Py_EncodeArg(PyObject* arg) {
PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
EncodeResult result;
const auto status = TFE_Py_EncodeArgHelper(arg, &result);
const auto status =
TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
return nullptr;
}

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections as py_collections
import itertools
import weakref
from tensorflow.core.framework import attr_value_pb2
@ -468,7 +469,8 @@ def func_graph_from_py_func(name,
arg_names=None,
op_return_value=None,
collections=None,
capture_by_value=None):
capture_by_value=None,
override_flat_arg_shapes=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@ -507,6 +509,12 @@ def func_graph_from_py_func(name,
capture_by_value: An optional boolean. If True, the func graph will capture
Variables by value instead of reference. By default inherit from outer
graphs, and failing that will default to False.
override_flat_arg_shapes: An optional list of instances that are either
`None` or `TensorShape`. The length must match that of
`nest.flatten((args, kwargs))`. The entries containing value `None`
must match entries in flattened arguments containing non-tensors, while
entries containing a `TensorShape` must match entries in the flattened
arguments containing tensors.
Returns:
A FuncGraph.
@ -514,6 +522,8 @@ def func_graph_from_py_func(name,
Raises:
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
ValueError: If both `signature` and `override_flat_arg_shapes` are
passed in.
"""
if op_return_value is not None:
assert isinstance(op_return_value, ops.Tensor), op_return_value
@ -530,13 +540,27 @@ def func_graph_from_py_func(name,
default_use_recource = current_scope.use_resource
current_scope.set_use_resource(True)
if signature is not None and override_flat_arg_shapes is not None:
raise ValueError(
"Passed both signature and override_flat_arg_shapes: %s and %s."
% (signature, override_flat_arg_shapes))
if signature is not None:
args = signature
kwargs = {}
# Creates and names placeholders for all arguments.
func_args = _get_defun_inputs_from_args(args, arg_names)
func_kwargs = _get_defun_inputs_from_kwargs(kwargs)
if override_flat_arg_shapes is not None:
flat_args = nest.flatten(args)
arg_shapes = override_flat_arg_shapes[:len(flat_args)]
kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
else:
arg_shapes = None
kwarg_shapes = None
func_args = _get_defun_inputs_from_args(
args, arg_names, flat_shapes=arg_shapes)
func_kwargs = _get_defun_inputs_from_kwargs(
kwargs, flat_shapes=kwarg_shapes)
# Convert all Tensors into TensorSpecs before saving the structured inputs.
# If storing pure concrete functions that are not called through polymorphic
@ -761,43 +785,73 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
return placeholder
def _get_defun_inputs_from_args(args, names):
def _get_defun_inputs_from_args(args, names, flat_shapes=None):
"""Maps Python function positional args to graph-construction inputs."""
return _get_defun_inputs(args, names, structure=args)
return _get_defun_inputs(
args, names, structure=args, flat_shapes=flat_shapes)
def _get_defun_inputs(flat_args, names, structure):
def _get_defun_inputs(args, names, structure, flat_shapes=None):
"""Maps python function args to graph-construction inputs.
Args:
flat_args: A flat list of user-specified arguments.
args: A flat list of user-specified arguments.
names: A list of strings with user-specified argument names, same length as
`flat_args`. May be `None`, in which case a generic name is used.
`args`. May be `None`, in which case a generic name is used.
structure: The original argument list or dictionary.
flat_shapes: A flat list of values that are either `None` or
instances of `TensorShape`. If provided, its length must match
that of `nest.flatten(args)`, and locations where `args` are
instances of `Tensor` must have a corresponding `TensorShape` in
`flat_shapes`. May be `None`, in which case exact shapes are read
directly from the args.
Returns:
Placeholders with the same structure as `structure`.
Raises:
RuntimeError: if `flat_shapes` is provided, but
`len(flat_shapes) != len(nest.flatten(args))`.
RuntimeError: if a shape from `flat_shapes` is not None
for an argument that is not a `Tensor`, `TensorSpec`,
or `ResourceVariable`.
"""
func_graph = ops.get_default_graph()
function_inputs = []
if names is None:
names = [None] * len(flat_args)
for arg_value, name in zip(flat_args, names):
names = [None] * len(args)
if flat_shapes is None:
shapes_iter = itertools.repeat(None)
else:
len_flat_args = len(nest.flatten(args))
if len_flat_args != len(flat_shapes):
raise RuntimeError(
"Length of fully flat shapes (%d) must match that of "
"flatten(args) (%d). args: %s, flat_shapes: %s"
% (len(flat_shapes),
len_flat_args,
args,
flat_shapes))
shapes_iter = iter(flat_shapes)
for arg_value, name in zip(args, names):
for arg in nest.flatten(arg_value):
# We have a shape entry for each arg, regardless of whether it's a real
# Tensor or not. For non-tensor entries it should be None.
shape = next(shapes_iter)
if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
if isinstance(arg, tensor_spec.TensorSpec) and arg.name:
requested_name = arg.name
else:
requested_name = name
placeholder_shape = shape if shape is not None else arg.shape
try:
placeholder = graph_placeholder(
arg.dtype, arg.shape,
arg.dtype, placeholder_shape,
name=requested_name)
except ValueError:
# Sometimes parameter names are not valid op names, so fall back to
# unnamed placeholders.
placeholder = graph_placeholder(arg.dtype, arg.shape)
placeholder = graph_placeholder(arg.dtype, placeholder_shape)
if name is not None:
# Record the requested/user-specified name in case it's different than
# the uniquified name, for validation when exporting signatures.
@ -816,18 +870,24 @@ def _get_defun_inputs(flat_args, names, structure):
attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
function_inputs.append(arg)
else:
if shape is not None:
raise RuntimeError(
"Expected provided shape override to be None for arg that isn't "
"a Tensor, but saw arg: '%s', shape: '%s'. args: %s"
% (arg, shape, args))
function_inputs.append(arg)
return nest.pack_sequence_as(structure, function_inputs)
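A hedged illustration of what the placeholder-shape override amounts to: without an override the placeholder takes the argument's exact shape, with a relaxed override it takes the partially unknown shape instead (graph_placeholder above is the internal analogue of a graph-mode placeholder).

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  exact = tf.compat.v1.placeholder(tf.float32, shape=[2])       # no override
  relaxed = tf.compat.v1.placeholder(tf.float32, shape=[None])  # relaxed shape
print(exact.shape, relaxed.shape)  # first dim known vs. unknown
```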
def _get_defun_inputs_from_kwargs(kwargs):
def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
"""Maps Python function keyword args to graph-construction inputs."""
if kwargs:
names, flat_args = zip(*sorted(kwargs.items()))
names, args = zip(*sorted(kwargs.items()))
else:
names = []
flat_args = []
return _get_defun_inputs(flat_args, names, structure=kwargs)
args = []
return _get_defun_inputs(
args, names, structure=kwargs, flat_shapes=flat_shapes)
def dismantle_func_graph(func_graph):