Auto-uniquify ConcreteFunction kwargs, and allow passing anything as a positional argument

Same format as ConcreteFunction.inputs without the captures, nest.flatten(args, expand_composites=True) + nest.flatten(kwargs, expand_composities=True) filtered for Tensors only.

ConcreteFunction won't complain about anything now. Actually calling the resulting function may be confusing, and we can in the future make a more restricted API which only accepts pure Tensor structures where the ConcreteFunction can be called with those structures. But taking mixed Python/Tensor structures is probably not a good idea.

PiperOrigin-RevId: 249708123
This commit is contained in:
Allen Lavoie 2019-05-23 13:49:16 -07:00 committed by TensorFlower Gardener
parent 81bb13c636
commit 318ed830c7
5 changed files with 45 additions and 47 deletions

View File

@ -448,6 +448,31 @@ class DefFunctionTest(test.TestCase):
re.compile('An op outside of the function.*passed.*Const', re.DOTALL)):
failing_function()
def testNonUniqueNamesGetConcreteFunction(self):
@def_function.function
def non_unique_arg_names(x, **kwargs):
a, b, c = x
d = kwargs['d']
return a + b + c + d
concrete = non_unique_arg_names.get_concrete_function(
(tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32)),
d=tensor_spec.TensorSpec(None, dtypes.float32))
self.assertAllClose(
10.,
concrete(x=constant_op.constant(1.),
x_1=constant_op.constant(2.),
x_2=constant_op.constant(3.),
d=constant_op.constant(4.)))
self.assertAllClose(
10.,
concrete(constant_op.constant(1.),
constant_op.constant(2.),
constant_op.constant(3.),
constant_op.constant(4.)))
def testVariableCreatorScope(self):
created_variables = []
captured_variables = []

View File

@ -1406,40 +1406,25 @@ class Function(object):
kwargs = {}
seen_names = set()
captured = frozenset(graph_function.graph.internal_captures)
allowed_positional = 0
if args:
for outer_arg in args:
# TODO(allenl): Consider allowing arguments with defaults in the Python
# function's signature to be passed as positional arguments to the
# concrete function.
if not isinstance(
outer_arg,
(ops.Tensor, resource_variable_ops.ResourceVariable,
tensor_spec.TensorSpec)):
break
allowed_positional += 1
# pylint: disable=protected-access
graph_function._num_positional_args = allowed_positional
graph_function._arg_keywords = []
prefix_counts = {}
# pylint: enable=protected-access
num_positional = 0
for arg in graph_function.graph.inputs:
if arg in captured:
break
user_arg_name = arg.op.get_attr("_user_specified_name")
if user_arg_name in seen_names:
raise ValueError(
("Unable to construct a concrete function for {} since some "
"arguments do not have unique names. Got two arguments named "
"'{}'. When constructing a concrete TensorFlow function from a "
"Python function which takes nested structures or variadic "
"positional arguments, pass unique names to tf.TensorSpec objects "
"used to identify these Tensor inputs. These names may then be "
"used as keyword arguments to the concrete function.")
.format(
self._python_function,
compat.as_str(arg.op.get_attr("_user_specified_name"))))
seen_names.add(user_arg_name)
graph_function._arg_keywords.append(user_arg_name) # pylint: disable=protected-access
num_positional += 1
user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
proposal = user_arg_name
while proposal in seen_names:
index = prefix_counts.get(user_arg_name, 1)
proposal = "{}_{}".format(user_arg_name, index)
prefix_counts[user_arg_name] = index + 1
seen_names.add(proposal)
graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access
# Anything can be a positional argument, in the same order as .inputs
graph_function._num_positional_args = num_positional # pylint: disable=protected-access
return graph_function
def __get__(self, instance, owner):

View File

@ -100,13 +100,6 @@ class ArgumentNamingTests(test.TestCase, parameterized.TestCase):
self.assertEqual({'alpha', 'beta'},
set(fn_op.graph.structured_outputs.keys()))
with self.assertRaisesRegexp(ValueError, "two arguments named 'z'"):
fn.get_concrete_function(
z=(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32),
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)),
y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32,
name='custom'),
x=4.)
fn_op2 = fn.get_concrete_function(
z=(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32,
name='z_first'),

View File

@ -385,10 +385,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
((a, b),) = mats
return matmul(a, b)
with self.assertRaisesRegexp(ValueError, "two arguments named 'mats'"):
sq.get_concrete_function(
[(tensor_spec.TensorSpec((None, None), dtypes.float32),
tensor_spec.TensorSpec((None, None), dtypes.float32))])
sq_op_autonamed = sq.get_concrete_function(
[(tensor_spec.TensorSpec((None, None), dtypes.float32),
tensor_spec.TensorSpec((None, None), dtypes.float32))])
self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())
sq_op = sq.get_concrete_function(
[(tensor_spec.TensorSpec((None, None), dtypes.float32,
name='first_mat'),
@ -398,11 +399,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
with self.assertRaisesRegexp(
TypeError, 'bound to Tensors within nested structures'):
sq_op(t1, t2)
out = sq_op(first_mat=t1, second_mat=t2)
self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
self.assertAllEqual(sq_op_autonamed(t1, t2),
math_ops.matmul(t1, t2).numpy())
def testExecutingStatelessDefunConcurrently(self):

View File

@ -171,11 +171,6 @@ class SaveTest(test.TestCase):
input_signature=([tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32)],))
root.f([constant_op.constant(1.), constant_op.constant(1.)])
# Concrete functions must always have uniquely named Tensor inputs. Save
# relies on this.
with self.assertRaisesRegexp(
ValueError, "two arguments named 'x'"):
root.f.get_concrete_function()
def test_nested_outputs(self):
root = tracking.AutoTrackable()