Turn on shape relaxation retracing in @tf.function.
PiperOrigin-RevId: 235054774
This commit is contained in:
parent 1d0ec3ec5d
commit d8033ab10d
@@ -578,9 +578,11 @@ class Function(object):
     concrete_functions = []
     # pylint: disable=protected-access
     if self._stateful_fn:
-      concrete_functions.extend(self._stateful_fn._function_cache.values())
+      concrete_functions.extend(
+          self._stateful_fn._function_cache.all_values())
     if self._stateless_fn:
-      concrete_functions.extend(self._stateless_fn._function_cache.values())
+      concrete_functions.extend(
+          self._stateless_fn._function_cache.all_values())
     # pylint: enable=protected-access
     deduplicated_concrete_functions = list()
     seen_signatures = list()
@@ -43,6 +43,7 @@ from tensorflow.python.framework import error_interpolation
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import functional_ops
@@ -60,10 +61,94 @@ from tensorflow.python.util import tf_inspect
 FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
 BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
 
-CacheKey = collections.namedtuple("CacheKey", [
-    "input_signature", "parent_graph", "device_functions", "colocation_stack",
-    "uses_xla"
-])
+
+class CacheKey(
+    collections.namedtuple("CacheKey", [
+        "input_signature", "parent_graph", "device_functions",
+        "colocation_stack", "uses_xla"])):
+
+  def replace(self, *args, **kwargs):
+    return self._replace(*args, **kwargs)
+
+
+def _flat_shape_list(*params):
+  """Return a flat list of TensorShapes, one for each tensor[spec] in `*params`.
+
+  Args:
+    *params: Set of nested entries containing Tensors, TensorSpec, and
+      non-tensors.
+
+  Returns:
+    A list of entries containing either `None` or `TensorShape`.
+  """
+  return [tensor_shape.TensorShape(x.shape)
+          if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else None
+          for x in nest.flatten(params)]
+
+
+def _compatible_shapes(flat_x, flat_y):
+  """Check if lists of TensorShapes contain compatible shapes.
+
+  Args:
+    flat_x: List of TensorShape or None.
+    flat_y: List of TensorShape or None.
+
+  Returns:
+    A python bool.
+
+  Raises:
+    RuntimeError: if `len(flat_x) != len(flat_y)`.
+    RuntimeError: if `flat_x[i] is None != flat_y[i] is None` for any `i`.
+  """
+  if len(flat_x) != len(flat_y):
+    raise RuntimeError("Expected shape lists of identical lengths, but saw: "
+                       "%s and %s" % (flat_x, flat_y))
+  def is_compatible(x, y):
+    """Internal help function.
+
+    Args:
+      x: TensorShape or None.
+      y: TensorShape or None.
+
+    Returns:
+      Python bool.
+
+    Raises:
+      RuntimeError: If `x is None != y is None`.
+    """
+    # If both x and y are None, there is no shape to compare.  Otherwise check
+    # if they are compatible with each other.  Either way, both input signatures
+    # must have have Tensors in the same entries.  If not, raise an assertion
+    # error.
+    if x is None != y is None:
+      raise RuntimeError(
+          "Expected signature type matches between flattened input shapes "
+          "%s and %s; but saw that (%s is None) != (%s is None)"
+          % (flat_x, flat_y, x, y))
+    return x is None or x.is_compatible_with(y)
+  return all(is_compatible(x, y) for x, y in zip(flat_x, flat_y))
+
+
+def _common_shape(x, y):
+  """Find a `TensorShape` that is compatible with both `x` and `y`."""
+  if x is None != y is None:
+    raise RuntimeError(
+        "Cannot find a common shape when LHS shape is None but RHS shape "
+        "is not (or vice versa): %s vs. %s" % (x, y))
+  if x is None:
+    return None  # The associated input was not a Tensor, no shape generated.
+  if not isinstance(x, tensor_shape.TensorShape):
+    raise TypeError("Expected x to be a TensorShape but saw %s" % (x,))
+  if not isinstance(y, tensor_shape.TensorShape):
+    raise TypeError("Expected y to be a TensorShape but saw %s" % (y,))
+  if x.rank != y.rank or x.rank is None:
+    return tensor_shape.TensorShape(None)
+  dims = []
+  for dim_x, dim_y in zip(x.dims, y.dims):
+    if dim_x != dim_y or tensor_shape.dimension_value(dim_x) is None:
+      dims.append(None)
+    else:
+      dims.append(tensor_shape.dimension_value(dim_x))
+  return tensor_shape.TensorShape(dims)
+
+
 def is_same_structure(structure1,
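Note: `_common_shape` computes, per argument, a shape that both the cached trace and the incoming call can satisfy: matching known dimensions are kept, everything else becomes unknown. A rough pure-Python sketch of that elementwise relaxation (plain lists of ints/None instead of `TensorShape`, so this is only an illustration of the logic, not the TensorFlow API; rank mismatches fall back to a fully unknown shape in the real helper):

    def common_shape(x, y):
        """Relax two equal-rank shapes (lists of int or None) elementwise."""
        assert len(x) == len(y)
        return [dx if dx == dy and dx is not None else None
                for dx, dy in zip(x, y)]

    assert common_shape([2], [3]) == [None]
    assert common_shape([2, 5], [3, 5]) == [None, 5]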
@@ -1073,6 +1158,34 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
   return inputs
 
 
+class FunctionCache(object):
+  """A lightweight container for cached functions.
+  """
+
+  def __init__(self):
+    # The set of functions that have been missed; entries are CacheKey with
+    # input_signature `None` (e.g. a "call context key")
+    self.missed = set()
+    # The primary cache, mapping a fully shaped CacheKey to a function.
+    self.primary = collections.OrderedDict()
+    # A cache key lookup, mapping a CacheKey generated without shape info to a
+    # flat list of relaxed shapes (one for each argument).  Arguments that are
+    # not Tensors contain a `None` for the corresponding relaxed shape.
+    self.arg_relaxed_shapes = collections.OrderedDict()
+    # The secondary cache, mapping a CacheKey generated without shape info to a
+    # function.
+    self.arg_relaxed = collections.OrderedDict()
+    # All OrderedDicts require manual garbage collection.
+    self._garbage_collectors = [
+        _FunctionGarbageCollector(self.primary),
+        _FunctionGarbageCollector(self.arg_relaxed),
+        _FunctionGarbageCollector(self.arg_relaxed_shapes)]
+
+  def all_values(self):
+    """A set of all `ConcreteFunction` instances held by this cache."""
+    return set(self.primary.values()) | set(self.arg_relaxed.values())
+
+
 class Function(object):
   """Wrapper class for the graph functions defined for a Python function.
 
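Note: the cache layout introduced here is easiest to read as two plain mappings plus a "seen call contexts" set. A minimal, hypothetical sketch (field names mirror the class above; this is not the real implementation and omits garbage collection):

    class SketchCache(object):
        def __init__(self):
            self.missed = set()           # call contexts that already missed once
            self.primary = {}             # full key (exact shapes) -> concrete function
            self.arg_relaxed = {}         # rank-only key -> relaxed concrete function
            self.arg_relaxed_shapes = {}  # rank-only key -> relaxed flat shapes

        def all_values(self):
            return set(self.primary.values()) | set(self.arg_relaxed.values())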
@@ -1126,8 +1239,7 @@ class Function(object):
     self._name = name
     self._autograph = autograph
     self._autograph_options = autograph_options
-    self._function_cache = collections.OrderedDict()
-    self._garbage_collector = _FunctionGarbageCollector(self._function_cache)
+    self._function_cache = FunctionCache()
     self._function_attributes = attributes or {}
     self._capture_by_value = capture_by_value
 
@@ -1284,13 +1396,15 @@ class Function(object):
       # Return the cached `Function` for the instance
       return self._descriptor_cache[instance]
 
-  def _cache_key(self, args, kwargs):
+  def _cache_key(self, args, kwargs, include_tensor_ranks_only=False):
     """Computes the cache key given inputs and execution context."""
     if self._input_signature is None:
       inputs = (args, kwargs) if kwargs else args
-      input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
+      input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(
+          inputs, include_tensor_ranks_only)
     else:
       del args, kwargs
+      assert not include_tensor_ranks_only
       input_signature = self._flat_input_signature
 
     ctx = context.context()
@@ -1336,6 +1450,46 @@ class Function(object):
     return CacheKey(input_signature, parent_graph, device_functions,
                     colocation_stack, uses_xla)
 
+  def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
+    """Create a `ConcreteFunction` from `args` and `kwargs`."""
+    if self._input_signature is None:
+      arglen = len(args)
+    else:
+      arglen = len(self._input_signature)
+    base_arg_names = self._function_spec.arg_names[:arglen]
+    num_missing_args = arglen - len(self._function_spec.arg_names)
+    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
+    # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
+    # where arg is based on the self._function_spec.vararg_name.
+    missing_arg_names = [
+        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
+    ]
+    arg_names = base_arg_names + missing_arg_names
+    graph_function = ConcreteFunction(
+        func_graph_module.func_graph_from_py_func(
+            self._name,
+            self._python_function,
+            args,
+            kwargs,
+            self._input_signature,
+            autograph=self._autograph,
+            autograph_options=self._autograph_options,
+            arg_names=arg_names,
+            override_flat_arg_shapes=override_flat_arg_shapes,
+            capture_by_value=self._capture_by_value),
+        self._function_attributes)
+
+    # pylint: disable=protected-access
+    # Tell the ConcreteFunction to clean up its graph once it goes out of
+    # scope. ConcreteFunction does not do this in its constructor since it
+    # gets used in some places (like Keras) where the FuncGraph lives
+    # longer than the ConcreteFunction.
+    graph_function._garbage_collector = ConcreteFunctionGarbageCollector(
+        graph_function.graph)
+    # pylint: enable=protected-access
+
+    return graph_function
+
   def _maybe_define_function(self, args, kwargs):
     """Gets a function for these inputs, defining it if necessary.
 
@@ -1353,57 +1507,76 @@ class Function(object):
     Raises:
       ValueError: If inputs are incompatible with the input signature.
       TypeError: If the function inputs include non-hashable objects
+      RuntimeError: If there's an internal bug (inconsistency) in handling
+        shape relaxation retracing.
     """
     if self._input_signature is None or args is not None or kwargs is not None:
      args, kwargs = self._function_spec.canonicalize_function_inputs(
          *args, **kwargs)
     cache_key = self._cache_key(args, kwargs)
-    with self._lock:
-      try:
-        graph_function = self._function_cache.get(cache_key, None)
-      except TypeError as e:
-        raise TypeError(
-            "Arguments supplied to `defun`-generated functions must be"
-            " hashable. Original error: %s" % e)
-
-      if graph_function is None:
-        logging.vlog(1,
-                     "Creating new FuncGraph for Python function %r (key: %r)",
-                     self._python_function, cache_key)
-        if self._input_signature is None:
-          arglen = len(args)
-        else:
-          arglen = len(self._input_signature)
-        base_arg_names = self._function_spec.arg_names[:arglen]
-        num_missing_args = arglen - len(self._function_spec.arg_names)
-        missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
-        # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
-        # where arg is based on the self._function_spec.vararg_name.
-        missing_arg_names = [
-            "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
-        ]
-        arg_names = base_arg_names + missing_arg_names
-        graph_function = ConcreteFunction(
-            func_graph_module.func_graph_from_py_func(
-                self._name,
-                self._python_function,
-                args,
-                kwargs,
-                self._input_signature,
-                autograph=self._autograph,
-                autograph_options=self._autograph_options,
-                arg_names=arg_names,
-                capture_by_value=self._capture_by_value),
-            self._function_attributes)
-        # pylint: disable=protected-access
-        # Tell the ConcreteFunction to clean up its graph once it goes out of
-        # scope. ConcreteFunction does not do this in its constructor since it
-        # gets used in some places (like Keras) where the FuncGraph lives
-        # longer than the ConcreteFunction.
-        graph_function._garbage_collector = ConcreteFunctionGarbageCollector(
-            graph_function.graph)
-        # pylint: enable=protected-access
-        self._function_cache[cache_key] = graph_function
+    try:
+      hash(cache_key)
+    except TypeError as e:
+      raise TypeError(
+          "Arguments supplied to `defun`-generated functions must be"
+          " hashable. Original error: %s" % e)
+
+    with self._lock:
+      graph_function = self._function_cache.primary.get(cache_key, None)
+      if graph_function is not None:
+        return graph_function, args, kwargs
+
+      logging.vlog(1,
+                   "Creating new FuncGraph for Python function %r (key: %r)",
+                   self._python_function, cache_key)
+      logging.vlog(2,
+                   "Python function signature [args: %s] [kwargs: %s]",
+                   str(args),
+                   str(kwargs))
+
+      call_context_key = cache_key.replace(input_signature=None)
+
+      # If there's a provided input signature, or XLA is being used, or
+      # there's no cache miss for this calling context so far, go ahead and
+      # build the function and bypass shape relaxation retracing.
+      if (self._input_signature is not None
+          or cache_key.uses_xla
+          or call_context_key not in self._function_cache.missed):
+        self._function_cache.missed.add(call_context_key)
+        graph_function = self._create_graph_function(args, kwargs)
+        self._function_cache.primary[cache_key] = graph_function
+        return graph_function, args, kwargs
+
+      rank_only_cache_key = self._cache_key(
+          args, kwargs, include_tensor_ranks_only=True)
+
+      arg_shapes = _flat_shape_list(args, kwargs)
+      relaxed_arg_shapes = self._function_cache.arg_relaxed_shapes.get(
+          rank_only_cache_key, None)
+      relaxed_arg_function = self._function_cache.arg_relaxed.get(
+          rank_only_cache_key, None)
+
+      if (relaxed_arg_function is not None
+          and _compatible_shapes(relaxed_arg_shapes, arg_shapes)):
+        return relaxed_arg_function, args, kwargs
+
+      if relaxed_arg_shapes is None:
+        relaxed_arg_shapes = arg_shapes
+      else:
+        if len(arg_shapes) != len(relaxed_arg_shapes):
+          raise RuntimeError("Expected arg_shapes len to match "
+                             "relaxed_arg_shapes len: %d vs. %d"
+                             % (len(arg_shapes), len(relaxed_arg_shapes)))
+        relaxed_arg_shapes = [
+            _common_shape(x, y) for (x, y) in zip(
+                arg_shapes, relaxed_arg_shapes)]
+      self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = (
+          relaxed_arg_shapes)
+      graph_function = self._create_graph_function(
+          args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
+      self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function
+
       return graph_function, args, kwargs
 
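Note: the lookup order above amounts to "exact hit, else trace exactly on the first miss for this calling context, else relax and retrace". A compressed sketch of that control flow (plain Python, hypothetical names, duck-typed cache like the sketch further up; locking, XLA and input-signature special cases omitted; the real code relaxes per-dimension via _common_shape rather than per-argument):

    def compatible(relaxed, shapes):
        return all(r is None or r == s for r, s in zip(relaxed, shapes))

    def lookup_or_trace(cache, key, context_key, rank_key, shapes, trace):
        fn = cache.primary.get(key)
        if fn is not None:
            return fn                          # exact-shape hit
        if context_key not in cache.missed:
            cache.missed.add(context_key)      # first miss: trace with exact shapes
            cache.primary[key] = trace(shapes)
            return cache.primary[key]
        relaxed = cache.arg_relaxed_shapes.get(rank_key)
        fn = cache.arg_relaxed.get(rank_key)
        if fn is not None and compatible(relaxed, shapes):
            return fn                          # relaxed graph already covers this call
        if relaxed is None:
            relaxed = shapes
        else:
            relaxed = [r if r == s else None for r, s in zip(relaxed, shapes)]
        cache.arg_relaxed_shapes[rank_key] = relaxed
        cache.arg_relaxed[rank_key] = trace(relaxed)
        return cache.arg_relaxed[rank_key]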
@@ -66,6 +66,13 @@ from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 
 
+def total_function_cache(defined):
+  # pylint: disable=protected-access
+  return (set(defined._function_cache.primary)
+          | set(defined._function_cache.arg_relaxed))
+  # pylint: enable=protected-access
+
+
 class MiniModel(keras_training.Model):
   """Minimal model for mnist.
 
@@ -99,6 +106,94 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
     self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])
 
+  def testVariable(self):
+    v1 = variables.Variable(1.0)
+    add = def_function.function(lambda x, v: x + v1 + v)
+    v2 = variables.Variable(1.0)
+    x = constant_op.constant(1.0)
+    r = add(x, v2)
+    self.assertEqual(3.0, self.evaluate(r))
+
+  def testInputShapeFunctionRelaxation(self):
+    unknown_dim = [False]
+
+    @function.defun
+    def func(a):
+      if a._shape_tuple()[0] is None:
+        unknown_dim[0] = True
+      return a + 1
+
+    func(constant_op.constant([]))
+    self.assertFalse(unknown_dim[0])
+    self.assertLen(total_function_cache(func), 1)
+
+    func(constant_op.constant([1.0]))
+    self.assertFalse(unknown_dim[0])
+    self.assertLen(total_function_cache(func), 2)
+
+    func(constant_op.constant([1.0, 2.0]))
+    self.assertTrue(unknown_dim[0])
+    self.assertLen(total_function_cache(func), 2)
+
+  def testNestedInputShapeFunctionRelaxation(self):
+    unknown_dim = [False]
+
+    @function.defun
+    def func(a_, b_=None):
+      del a_  # Only used to check which cache is used.
+      self.assertEqual(b_[0]._shape_tuple(), ())
+      if b_[1]._shape_tuple()[0] is None:
+        unknown_dim[0] = True
+      return b_[0] + 1
+
+    a = 'hi'
+    b0 = constant_op.constant(1.0)
+    func(a, b_=[b0, constant_op.constant([])])
+    self.assertFalse(unknown_dim[0])
+    self.assertLen(total_function_cache(func), 1)
+
+    func(a, b_=[b0, constant_op.constant([1.0])])
+    self.assertFalse(unknown_dim[0])
+    self.assertLen(total_function_cache(func), 2)
+
+    func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
+    self.assertTrue(unknown_dim[0])
+    self.assertLen(total_function_cache(func), 2)
+
+    unknown_dim[0] = False
+
+    # Now do the same except with a new a which is not a tensor; this should
+    # change the cache key.
+    a = 'bye'
+    func(a, b_=[b0, constant_op.constant([])])
+    self.assertFalse(unknown_dim[0])
+    self.assertLen(total_function_cache(func), 3)
+
+    # Since we already marked a cache miss for a function with the same
+    # non-input signatures, here we will immediately start relaxing shapes.
+    func(a, b_=[b0, constant_op.constant([1.0])])
+    self.assertTrue(unknown_dim[0])
+    self.assertLen(total_function_cache(func), 3)
+
+  def testFunctionRelaxationLosesInnerDimWithKerasLayer(self):
+    layer = keras.layers.Dense(1)
+    fn = def_function.function()(layer)
+
+    with self.captureWritesToStream(sys.stderr) as printed:
+      fn(array_ops.ones((3, 2)))
+      self.assertNotIn('ValueError', printed.contents())
+    with self.captureWritesToStream(sys.stderr) as printed:
+      # Use batch size 2 to trigger a second cache miss on the shape.
+      fn(array_ops.ones((2, 2)))
+      self.assertNotIn('ValueError', printed.contents())
+
+    # Shape relaxation passes TensorShape([None, None]), which causes layer
+    # matmul to fail, due to incompatible dims.  What would have been a graph
+    # build time error (layer would complain about the inner dim being 4).
+    with self.captureWritesToStream(sys.stderr) as printed:
+      with self.assertRaisesRegexp(errors.InvalidArgumentError, r'MatMul'):
+        fn(array_ops.ones((3, 4)))
+
   def testWastedAdd(self):
 
     @def_function.function()
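Note: the behavior the new tests exercise can also be phrased against the public API (a rough, hypothetical repro; the tests above drive the internal function.defun module directly, and whether the third call below reuses a relaxed graph depends on the retracing policy introduced in this change):

    import tensorflow as tf

    traces = []

    @tf.function
    def f(x):
        traces.append(x.shape)   # Python side effect runs only while tracing
        return x + 1

    f(tf.constant([1.0]))             # shape (1,): first trace
    f(tf.constant([1.0, 2.0]))        # shape (2,): second miss for this call context
    f(tf.constant([1.0, 2.0, 3.0]))   # with relaxation, can reuse a (None,) graph
    print(traces)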
@@ -382,13 +477,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     x = random_ops.random_uniform([2, 2]).numpy()
     defined = function.defun(f)
     defined(x)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     x = random_ops.random_uniform([2, 2]).numpy()
     defined(x)
     # A NumPy array with different values but the same shape and dtype
     # shouldn't trigger another function definition.
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     # Test that the numpy array is properly an argument to the graph function.
     self.assertEqual(1., defined(numpy.ones([])).numpy())
@@ -1110,7 +1205,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
 
     defined = function.defun(multi_device_fn)
     outputs = self.evaluate(defined())
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
     self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
     self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
     self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
@@ -1118,7 +1213,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     with ops.device('/cpu:3'):
       outputs = self.evaluate(defined())
     # All function definitions are agnostic to call site devices.
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
     self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
     self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
     self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
@@ -1126,7 +1221,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
 
     with ops.device('/cpu:0'):
       outputs = self.evaluate(defined())
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
     self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
     self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
     self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
@@ -1211,10 +1306,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
 
     defined = function.defun(func)
     defined(Foo())
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     defined(Foo())
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
   def testCacheTensorDtypeCollision(self):
 
@@ -1224,11 +1319,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     defined = function.defun(func)
     t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
     defined(t)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
     defined(t)
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
   def testCacheTensorShapeCollision(self):
 
@@ -1238,11 +1333,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     defined = function.defun(func)
     t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
     defined(t)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     t = constant_op.constant([1.0], dtype=dtypes.complex64)
    defined(t)
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
   def testCacheTensorShapeDtypeCollision(self):
 
@@ -1252,11 +1347,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     defined = function.defun(func)
     t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
     defined(t)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     t = constant_op.constant([1.0], dtype=dtypes.complex128)
     defined(t)
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
   def testCacheTensorUnknownShapesCollision(self):
 
@@ -1266,21 +1361,34 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     with context.graph_mode(), self.cached_session():
       defined = function.defun(func)
 
-      p = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+      p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
       defined(p)
-      self.assertLen(defined._function_cache, 1)
+      self.assertLen(total_function_cache(defined), 1)
 
-      p = array_ops.placeholder(dtype=dtypes.float32, shape=[None])
+      p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
       defined(p)
-      self.assertLen(defined._function_cache, 2)
+      self.assertLen(total_function_cache(defined), 2)
 
-      p = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
+      p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
       defined(p)
-      self.assertLen(defined._function_cache, 3)
+      # Gradual shape relaxation is performed; and the common shape between
+      # [1] and [2] is one containing unknown dimensions.
+      self.assertLen(total_function_cache(defined), 2)
 
       t = constant_op.constant(1.0, dtype=dtypes.float32)
+      # pylint: disable=protected-access
+      self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
+      relaxed_shapes = (
+          list(defined._function_cache.arg_relaxed_shapes.values())[0])
+      self.assertEqual(len(relaxed_shapes), 1)
+      relaxed_shape = relaxed_shapes[0]
+      # pylint: enable=protected-access
+      self.assertEqual(relaxed_shape.rank, 1)
+      self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)
+
+      t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
       defined(t)
-      self.assertLen(defined._function_cache, 4)
+      # Shape (3,) matches the relaxed shape TensorShape([None])
+      self.assertLen(total_function_cache(defined), 2)
 
   def testPythonFunctionWithDefaultArgs(self):
 
@@ -1295,7 +1403,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
 
     def cache_keys():
       """Sanitizes cache keys of non-input metadata."""
-      return tuple(key[0] for key in defined._function_cache)
+      return tuple(key[0] for key in total_function_cache(defined))
 
     # `True` corresponds to the fact that we're executing eagerly
     self.assertIn(('URRRu', (0, 1, 20)), cache_keys())
@@ -1305,19 +1413,19 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
 
     # This matches the previous call.
     defined(foo=1)
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
     defined(1, 2, 3)
-    self.assertLen(defined._function_cache, 3)
+    self.assertLen(total_function_cache(defined), 3)
     self.assertIn(('URRRu', (1, 2, 3)), cache_keys())
 
     # This matches the previous call.
     defined(1, bar=2, baz=3)
-    self.assertLen(defined._function_cache, 3)
+    self.assertLen(total_function_cache(defined), 3)
 
     # This matches the previous call.
     defined(1, baz=3, bar=2)
-    self.assertLen(defined._function_cache, 3)
+    self.assertLen(total_function_cache(defined), 3)
 
   def testFunctoolsPartialUnwrappedCorrectly(self):
 
@@ -1343,12 +1451,12 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     defined = function.defun(foo, input_signature=signature)
     a = array_ops.ones([2])
     self.assertAllEqual(a, defined(a))
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
     self.assertAllEqual(a, defined.get_concrete_function()(a))
     self.assertAllEqual(a, defined.get_concrete_function(a)(a))
     self.assertAllEqual(a, defined.get_concrete_function(
         tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     def bar(a):
       self.assertEqual(a._shape_tuple(), (2, None))
@@ -1358,13 +1466,13 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     defined = function.defun(bar, input_signature=signature)
     a = array_ops.ones([2, 1])
     out = defined(a)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
     self.assertAllEqual(out, a)
 
     # Changing the second dimension shouldn't create a new function.
     b = array_ops.ones([2, 3])
     out = defined(b)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
     self.assertAllEqual(out, b)
 
   def testInputSignatureWithCompatibleInputs(self):
@@ -1405,7 +1513,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     b = array_ops.ones([1])
     expected = expected_foo([a, a], b)
     out = foo([a, a], b)
-    self.assertLen(foo._function_cache, 1)
+    self.assertLen(total_function_cache(foo), 1)
     nest.assert_same_structure(out, expected)
     self.assertAllEqual(out[0][0], a)
     self.assertAllEqual(out[0][1], a)
@@ -1417,7 +1525,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     c = array_ops.ones([1])
     expected = expected_foo([a, b], c)
     out = foo([a, b], c)
-    self.assertLen(foo._function_cache, 1)
+    self.assertLen(total_function_cache(foo), 1)
     nest.assert_same_structure(out, expected)
     self.assertAllEqual(out[0][0], a)
     self.assertAllEqual(out[0][1], b)
@@ -1428,7 +1536,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     b = b.numpy().tolist()
     c = c.numpy().tolist()
     out = foo([a, b], c)
-    self.assertLen(foo._function_cache, 1)
+    self.assertLen(total_function_cache(foo), 1)
     nest.assert_same_structure(out, expected)
     self.assertAllEqual(out[0][0], a)
     self.assertAllEqual(out[0][1], b)
@@ -1597,22 +1705,22 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     integer = constant_op.constant(2, dtypes.int64)
 
     out1, out2 = foo(flt, integer)
-    self.assertLen(foo._function_cache, 1)
+    self.assertLen(total_function_cache(foo), 1)
     self.assertEqual(out1.numpy(), 1.0)
     self.assertEqual(out2.numpy(), 2)
 
     out1, out2 = foo(flt=flt, integer=integer)
-    self.assertLen(foo._function_cache, 1)
+    self.assertLen(total_function_cache(foo), 1)
     self.assertEqual(out1.numpy(), 1.0)
     self.assertEqual(out2.numpy(), 2)
 
     out1, out2 = foo(integer=integer, flt=flt)
-    self.assertLen(foo._function_cache, 1)
+    self.assertLen(total_function_cache(foo), 1)
     self.assertEqual(out1.numpy(), 1.0)
     self.assertEqual(out2.numpy(), 2)
 
     out1, out2 = foo(flt, integer=integer)
-    self.assertLen(foo._function_cache, 1)
+    self.assertLen(total_function_cache(foo), 1)
     self.assertEqual(out1.numpy(), 1.0)
     self.assertEqual(out2.numpy(), 2)
 
@@ -1642,27 +1750,27 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     a = constant_op.constant(2.0)
     b = constant_op.constant([1.0, 2.0])
     one = defined(a, b)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     two = defined(a=a, b=b)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     three = defined(b=b, a=a)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     four = defined(a, b=b)
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     # The next call corresponds to a new input signature, hence
     # we expect another function to be defined.
     five = defined(b, a)
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
     six = defined(a=b, b=a)
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
     seven = defined(b=a, a=b)
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
     self.assertAllEqual(one, [1.0, 2.0])
     self.assertAllEqual(two, [1.0, 2.0])
@@ -2049,18 +2157,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     with ops.Graph().as_default():
       x = constant_op.constant(11)
       maybe_add(x, True)
-      self.assertLen(maybe_add._function_cache, 1)
-      self.assertLen(add._function_cache, 1)
+      self.assertLen(total_function_cache(maybe_add), 1)
+      self.assertLen(total_function_cache(add), 1)
 
       maybe_add(x, False)
-      self.assertLen(maybe_add._function_cache, 2)
-      self.assertLen(add._function_cache, 1)
+      self.assertLen(total_function_cache(maybe_add), 2)
+      self.assertLen(total_function_cache(add), 1)
 
     with ops.Graph().as_default():
       x = constant_op.constant(11)
       maybe_add(x, True)
-      self.assertLen(maybe_add._function_cache, 3)
-      self.assertLen(add._function_cache, 2)
+      self.assertLen(total_function_cache(maybe_add), 3)
+      self.assertLen(total_function_cache(add), 2)
 
   def testCacheKeyOverlappingShapes(self):
     @function.defun
@@ -2068,10 +2176,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       return t
 
     defined(array_ops.zeros([12, 1]))
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     defined(array_ops.zeros([1, 21]))
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
   def testCacheKeyNestedLists(self):
     @function.defun
@@ -2082,10 +2190,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     b = constant_op.constant(2.)
     c = constant_op.constant(3.)
     defined([[a], b, c])
-    self.assertLen(defined._function_cache, 1)
+    self.assertLen(total_function_cache(defined), 1)
 
     defined([[a, b], c])
-    self.assertLen(defined._function_cache, 2)
+    self.assertLen(total_function_cache(defined), 2)
 
   def testDecoratedMethod(self):
     m = DefunnedMiniModel()
@@ -2647,6 +2755,7 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
     result = func(g1, g2, c1, g3, c2)
     self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)
 
+
 if __name__ == '__main__':
   ops.enable_eager_execution(
       config=config_pb2.ConfigProto(device_count={'CPU': 4}))
@@ -231,8 +231,11 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
 PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);
 
 // Encodes the object as a tuple that is meant to be used as part of the key
-// for the defun function cache.
-PyObject* TFE_Py_EncodeArg(PyObject*);
+// for the defun function cache.  If `include_tensor_ranks_only` is true,
+// then the encoding only stores tensor ranks, and the key is
+// agnostic to dimension sizes.  Otherwise, full tensor shape encodings are
+// returned.
+PyObject* TFE_Py_EncodeArg(PyObject*, bool include_tensor_ranks_only);
 
 void TFE_Py_EnableInteractivePythonLogging();
 
@@ -2913,7 +2913,9 @@ struct EncodeResult {
   }
 };
 
-tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
+tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
+                                       bool include_tensor_ranks_only,
+                                       EncodeResult* result) {
   if (EagerTensor_CheckExact(arg)) {
     TFE_TensorHandle* t = EagerTensor_Handle(arg);
     tensorflow::TensorShape tensor_shape;
@@ -2922,10 +2924,13 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
     absl::StrAppend(&result->str, kDType, t->handle->dtype);
 
     absl::StrAppend(&result->str, kShape);
-    for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
-      absl::StrAppend(&result->str, dim_size, kShapeDelim);
+    if (include_tensor_ranks_only) {
+      absl::StrAppend(&result->str, tensor_shape.dim_sizes().size());
+    } else {
+      for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
+        absl::StrAppend(&result->str, dim_size, kShapeDelim);
+      }
     }
 
     return tensorflow::Status::OK();
   }
 
@@ -2949,6 +2954,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
       static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
 
   absl::StrAppend(&result->str, kDType, dtype);
+
   static char _shape_tuple[] = "_shape_tuple";
   tensorflow::Safe_PyObjectPtr shape_tuple(
       PyObject_CallMethod(arg, _shape_tuple, nullptr));
@@ -2969,23 +2975,30 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
       shape_tuple.get(), "shape_tuple didn't return a sequence"));
 
   int len = PySequence_Fast_GET_SIZE(shape_seq.get());
-  for (int i = 0; i < len; ++i) {
-    PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
-    if (item == Py_None) {
-      absl::StrAppend(&result->str, kNone);
-    } else {
-      absl::StrAppend(&result->str, MakeInt(item));
+
+  if (include_tensor_ranks_only) {
+    absl::StrAppend(&result->str, len);
+  } else {
+    for (int i = 0; i < len; ++i) {
+      PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
+      if (item == Py_None) {
+        absl::StrAppend(&result->str, kNone);
+      } else {
+        absl::StrAppend(&result->str, MakeInt(item));
+      }
     }
   }
 
   return tensorflow::Status::OK();
 }
 
-tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result);
+tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
+                                          bool include_tensor_ranks_only,
+                                          EncodeResult* result);
 
 // This function doesn't set the type of sequence before
 tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
                                          const char* end_type,
+                                         bool include_tensor_ranks_only,
                                          EncodeResult* result) {
   tensorflow::Safe_PyObjectPtr arg_seq(
       PySequence_Fast(arg, "unable to create seq from list/tuple"));
@@ -2997,7 +3010,8 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
     if (item == Py_None) {
       absl::StrAppend(&result->str, kNone);
     } else {
-      TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result));
+      TF_RETURN_IF_ERROR(
+          TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
     }
   }
   absl::StrAppend(&result->str, end_type);
@@ -3005,10 +3019,13 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
   return tensorflow::Status::OK();
 }
 
-tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
+tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
+                                          bool include_tensor_ranks_only,
+                                          EncodeResult* result) {
   if (tensorflow::swig::IsTensor(arg)) {
     absl::StrAppend(&result->str, kTensor);
-    TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(arg, result));
+    TF_RETURN_IF_ERROR(
+        TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
   } else if (tensorflow::swig::IsIndexedSlices(arg)) {
     absl::StrAppend(&result->str, kIndexedSlices);
     tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values"));
@@ -3017,7 +3034,8 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
      return tensorflow::errors::InvalidArgument(
          "IndexedSlices does not have a values attr");
    }
-    TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(values.get(), result));
+    TF_RETURN_IF_ERROR(
+        TFE_Py_EncodeTensor(values.get(), include_tensor_ranks_only, result));
 
     tensorflow::Safe_PyObjectPtr indices(
         PyObject_GetAttrString(arg, "indices"));
@@ -3026,7 +3044,8 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
      return tensorflow::errors::InvalidArgument(
          "IndexedSlices does not have a indices attr");
    }
-    TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(indices.get(), result));
+    TF_RETURN_IF_ERROR(
+        TFE_Py_EncodeTensor(indices.get(), include_tensor_ranks_only, result));
 
     tensorflow::Safe_PyObjectPtr dense_shape(
         PyObject_GetAttrString(arg, "dense_shape"));
@@ -3036,12 +3055,15 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
          "IndexedSlices does not have a dense_shape attr");
    }
    if (dense_shape.get() != Py_None) {
-      TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(dense_shape.get(), result));
+      TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(
+          dense_shape.get(), include_tensor_ranks_only, result));
    }
   } else if (PyList_Check(arg)) {
-    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, kListEnd, result));
+    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
+        arg, kList, kListEnd, include_tensor_ranks_only, result));
   } else if (PyTuple_Check(arg)) {
-    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, kTupleEnd, result));
+    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
+        arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
   } else if (PyDict_Check(arg)) {
     tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
     if (PyList_Sort(keys.get()) == -1) {
@@ -3053,9 +3075,11 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
 
    for (int i = 0; i < len; i++) {
      PyObject* key = PyList_GetItem(keys.get(), i);
-      TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(key, result));
+      TF_RETURN_IF_ERROR(
+          TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
      PyObject* value = PyDict_GetItem(arg, key);
-      TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(value, result));
+      TF_RETURN_IF_ERROR(
+          TFE_Py_EncodeArgHelper(value, include_tensor_ranks_only, result));
    }
   } else {
     PyObject* object = PyWeakref_NewRef(arg, nullptr);
@@ -3082,10 +3106,15 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
 // on known shapes to produce slimmer graphs, and correctness, as some
 // high-level APIs require shapes to be fully-known.
 //
+// `include_tensor_ranks_only` allows caching on arguments excluding shape info,
+// so that a slow path using relaxed shape can rely on a cache key that excludes
+// shapes.
+//
 // TODO(nareshmodi): Add support for sparse tensors.
-PyObject* TFE_Py_EncodeArg(PyObject* arg) {
+PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
   EncodeResult result;
-  const auto status = TFE_Py_EncodeArgHelper(arg, &result);
+  const auto status =
+      TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
   if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
     return nullptr;
   }
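Note: the C++ changes above thread a single boolean through the argument encoder so the cache key can either carry every dimension or only the rank. The effect can be illustrated with a toy Python encoding (this is purely illustrative; it is not the real TFE_Py_EncodeArg byte format):

    def encode_shape(shape, include_tensor_ranks_only=False):
        if include_tensor_ranks_only:
            return "rank:%d" % len(shape)                      # dimension sizes dropped
        return "shape:" + ",".join(str(d) for d in shape)

    assert encode_shape([2, 5]) != encode_shape([3, 5])                 # full keys differ
    assert encode_shape([2, 5], True) == encode_shape([3, 5], True)     # rank-only keys match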
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections as py_collections
+import itertools
 import weakref
 
 from tensorflow.core.framework import attr_value_pb2
@@ -468,7 +469,8 @@ def func_graph_from_py_func(name,
                             arg_names=None,
                             op_return_value=None,
                             collections=None,
-                            capture_by_value=None):
+                            capture_by_value=None,
+                            override_flat_arg_shapes=None):
   """Returns a `FuncGraph` generated from `python_func`.
 
   Args:
@@ -507,6 +509,12 @@ def func_graph_from_py_func(name,
     capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.
+    override_flat_arg_shapes: An optional list of instances that are either
+      `None` or `TensorShape`.  The length must match that of
+      `nest.flatten((args, kwargs))`.  The entries containing value `None`
+      must match entries in flattened arguments containing non-tensors, while
+      entries containing a `TensorShape` must match entries in the flattened
+      arguments containing tensors.
 
   Returns:
     A FuncGraph.
@@ -514,6 +522,8 @@ def func_graph_from_py_func(name,
   Raises:
     TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
+    ValueError: If both `signature` and `override_flat_arg_shapes` are
+      passed in.
   """
   if op_return_value is not None:
     assert isinstance(op_return_value, ops.Tensor), op_return_value
@@ -530,13 +540,27 @@ def func_graph_from_py_func(name,
     default_use_recource = current_scope.use_resource
     current_scope.set_use_resource(True)
 
+    if signature is not None and override_flat_arg_shapes is not None:
+      raise ValueError(
+          "Passed both signature and override_flat_arg_shapes: %s and %s."
+          % (signature, override_flat_arg_shapes))
+
     if signature is not None:
      args = signature
      kwargs = {}
 
     # Creates and names placeholders for all arguments.
-    func_args = _get_defun_inputs_from_args(args, arg_names)
-    func_kwargs = _get_defun_inputs_from_kwargs(kwargs)
+    if override_flat_arg_shapes is not None:
+      flat_args = nest.flatten(args)
+      arg_shapes = override_flat_arg_shapes[:len(flat_args)]
+      kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
+    else:
+      arg_shapes = None
+      kwarg_shapes = None
+    func_args = _get_defun_inputs_from_args(
+        args, arg_names, flat_shapes=arg_shapes)
+    func_kwargs = _get_defun_inputs_from_kwargs(
+        kwargs, flat_shapes=kwarg_shapes)
 
     # Convert all Tensors into TensorSpecs before saving the structured inputs.
     # If storing pure concrete functions that are not called through polymorphic
@@ -761,43 +785,73 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
   return placeholder
 
 
-def _get_defun_inputs_from_args(args, names):
+def _get_defun_inputs_from_args(args, names, flat_shapes=None):
   """Maps Python function positional args to graph-construction inputs."""
-  return _get_defun_inputs(args, names, structure=args)
+  return _get_defun_inputs(
+      args, names, structure=args, flat_shapes=flat_shapes)
 
 
-def _get_defun_inputs(flat_args, names, structure):
+def _get_defun_inputs(args, names, structure, flat_shapes=None):
   """Maps python function args to graph-construction inputs.
 
   Args:
-    flat_args: A flat list of user-specified arguments.
+    args: A flat list of user-specified arguments.
     names: A list of strings with user-specified argument names, same length as
-      `flat_args`. May be `None`, in which case a generic name is used.
+      `args`. May be `None`, in which case a generic name is used.
     structure: The original argument list or dictionary.
+    flat_shapes: A flat list of values that are either `None` or
+      instances of `TensorShape`.  If provided, then length must match
+      that of `nest.flatten(args)`; and locations where `args` are
+      instances of `Tensor` must have a corresponding `TensorShape` in
+      `flat_shapes`.  May be `None`, in which case exact shapes are read
+      directly from the args.
 
   Returns:
     Placeholders with the same structure as `structure`.
+
+  Raises:
+    RuntimeError: if `flat_shapes` is provided, but
+      `len(flat_shapes) != len(nest.flatten(args))`.
+    RuntimeError: if a shape from `flat_shapes` is not None
+      for an argument that is not a `Tensor`, `TensorSpec`,
+      or `ResourceVariable`.
   """
   func_graph = ops.get_default_graph()
   function_inputs = []
   if names is None:
-    names = [None] * len(flat_args)
-  for arg_value, name in zip(flat_args, names):
+    names = [None] * len(args)
+  if flat_shapes is None:
+    shapes_iter = itertools.repeat(None)
+  else:
+    len_flat_args = len(nest.flatten(args))
+    if len_flat_args != len(flat_shapes):
+      raise RuntimeError(
+          "Length of fully flat shapes (%d) must match that of "
+          "flatten(args) (%d). args: %s, flat_shapes: %s"
+          % (len(flat_shapes),
+             len_flat_args,
+             args,
+             flat_shapes))
+    shapes_iter = iter(flat_shapes)
+  for arg_value, name in zip(args, names):
     for arg in nest.flatten(arg_value):
+      # We have a shape entry for each arg, regadless of whether it's a real
+      # Tensor or not.  For non-tensor entries it should be None.
+      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        if isinstance(arg, tensor_spec.TensorSpec) and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
 
+        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
-              arg.dtype, arg.shape,
+              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
-          placeholder = graph_placeholder(arg.dtype, arg.shape)
+          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
@@ -816,18 +870,24 @@ def _get_defun_inputs(flat_args, names, structure):
              attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
+        if shape is not None:
+          raise RuntimeError(
+              "Expected provided shape override to be None for arg that isn't "
+              "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s"
+              % (arg, shape, args))
        function_inputs.append(arg)
   return nest.pack_sequence_as(structure, function_inputs)
 
 
-def _get_defun_inputs_from_kwargs(kwargs):
+def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
   """Maps Python function keyword args to graph-construction inputs."""
   if kwargs:
-    names, flat_args = zip(*sorted(kwargs.items()))
+    names, args = zip(*sorted(kwargs.items()))
   else:
     names = []
-    flat_args = []
-  return _get_defun_inputs(flat_args, names, structure=kwargs)
+    args = []
+  return _get_defun_inputs(
+      args, names, structure=kwargs, flat_shapes=flat_shapes)
 
 
 def dismantle_func_graph(func_graph):
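Note: the func_graph changes walk the flattened arguments in lockstep with the flat list of override shapes; non-tensor positions must carry None, and tensor positions take the relaxed shape instead of the argument's exact shape. A small illustration of that pairing (plain Python with hypothetical stand-in data, not the real placeholder-building code):

    flat_args = [("tensor", (2, 5)), 7, ("tensor", (3,))]   # toy stand-ins for flattened inputs
    flat_shapes = [(None, 5), None, (None,)]                 # one entry per flattened input

    if len(flat_args) != len(flat_shapes):
        raise RuntimeError("override_flat_arg_shapes must match flatten(args)")

    placeholders = []
    for arg, shape in zip(flat_args, flat_shapes):
        if isinstance(arg, tuple):                 # "tensor": placeholder uses the override shape
            placeholders.append(shape if shape is not None else arg[1])
        elif shape is not None:                    # non-tensor paired with a shape is an error
            raise RuntimeError("non-tensor args must have a None shape entry")
    assert placeholders == [(None, 5), (None,)]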