Auto-uniquify ConcreteFunction kwargs, and allow passing anything as a positional argument
Same format as ConcreteFunction.inputs without the captures, nest.flatten(args, expand_composites=True) + nest.flatten(kwargs, expand_composities=True) filtered for Tensors only. ConcreteFunction won't complain about anything now. Actually calling the resulting function may be confusing, and we can in the future make a more restricted API which only accepts pure Tensor structures where the ConcreteFunction can be called with those structures. But taking mixed Python/Tensor structures is probably not a good idea. PiperOrigin-RevId: 249708123
This commit is contained in:
parent
81bb13c636
commit
318ed830c7
@ -448,6 +448,31 @@ class DefFunctionTest(test.TestCase):
|
||||
re.compile('An op outside of the function.*passed.*Const', re.DOTALL)):
|
||||
failing_function()
|
||||
|
||||
def testNonUniqueNamesGetConcreteFunction(self):
|
||||
@def_function.function
|
||||
def non_unique_arg_names(x, **kwargs):
|
||||
a, b, c = x
|
||||
d = kwargs['d']
|
||||
return a + b + c + d
|
||||
|
||||
concrete = non_unique_arg_names.get_concrete_function(
|
||||
(tensor_spec.TensorSpec(None, dtypes.float32),
|
||||
tensor_spec.TensorSpec(None, dtypes.float32),
|
||||
tensor_spec.TensorSpec(None, dtypes.float32)),
|
||||
d=tensor_spec.TensorSpec(None, dtypes.float32))
|
||||
self.assertAllClose(
|
||||
10.,
|
||||
concrete(x=constant_op.constant(1.),
|
||||
x_1=constant_op.constant(2.),
|
||||
x_2=constant_op.constant(3.),
|
||||
d=constant_op.constant(4.)))
|
||||
self.assertAllClose(
|
||||
10.,
|
||||
concrete(constant_op.constant(1.),
|
||||
constant_op.constant(2.),
|
||||
constant_op.constant(3.),
|
||||
constant_op.constant(4.)))
|
||||
|
||||
def testVariableCreatorScope(self):
|
||||
created_variables = []
|
||||
captured_variables = []
|
||||
|
@ -1406,40 +1406,25 @@ class Function(object):
|
||||
kwargs = {}
|
||||
seen_names = set()
|
||||
captured = frozenset(graph_function.graph.internal_captures)
|
||||
allowed_positional = 0
|
||||
if args:
|
||||
for outer_arg in args:
|
||||
# TODO(allenl): Consider allowing arguments with defaults in the Python
|
||||
# function's signature to be passed as positional arguments to the
|
||||
# concrete function.
|
||||
if not isinstance(
|
||||
outer_arg,
|
||||
(ops.Tensor, resource_variable_ops.ResourceVariable,
|
||||
tensor_spec.TensorSpec)):
|
||||
break
|
||||
allowed_positional += 1
|
||||
# pylint: disable=protected-access
|
||||
graph_function._num_positional_args = allowed_positional
|
||||
graph_function._arg_keywords = []
|
||||
prefix_counts = {}
|
||||
# pylint: enable=protected-access
|
||||
num_positional = 0
|
||||
for arg in graph_function.graph.inputs:
|
||||
if arg in captured:
|
||||
break
|
||||
user_arg_name = arg.op.get_attr("_user_specified_name")
|
||||
if user_arg_name in seen_names:
|
||||
raise ValueError(
|
||||
("Unable to construct a concrete function for {} since some "
|
||||
"arguments do not have unique names. Got two arguments named "
|
||||
"'{}'. When constructing a concrete TensorFlow function from a "
|
||||
"Python function which takes nested structures or variadic "
|
||||
"positional arguments, pass unique names to tf.TensorSpec objects "
|
||||
"used to identify these Tensor inputs. These names may then be "
|
||||
"used as keyword arguments to the concrete function.")
|
||||
.format(
|
||||
self._python_function,
|
||||
compat.as_str(arg.op.get_attr("_user_specified_name"))))
|
||||
seen_names.add(user_arg_name)
|
||||
graph_function._arg_keywords.append(user_arg_name) # pylint: disable=protected-access
|
||||
num_positional += 1
|
||||
user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
|
||||
proposal = user_arg_name
|
||||
while proposal in seen_names:
|
||||
index = prefix_counts.get(user_arg_name, 1)
|
||||
proposal = "{}_{}".format(user_arg_name, index)
|
||||
prefix_counts[user_arg_name] = index + 1
|
||||
seen_names.add(proposal)
|
||||
graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access
|
||||
# Anything can be a positional argument, in the same order as .inputs
|
||||
graph_function._num_positional_args = num_positional # pylint: disable=protected-access
|
||||
return graph_function
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
|
@ -100,13 +100,6 @@ class ArgumentNamingTests(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual({'alpha', 'beta'},
|
||||
set(fn_op.graph.structured_outputs.keys()))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "two arguments named 'z'"):
|
||||
fn.get_concrete_function(
|
||||
z=(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32),
|
||||
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)),
|
||||
y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32,
|
||||
name='custom'),
|
||||
x=4.)
|
||||
fn_op2 = fn.get_concrete_function(
|
||||
z=(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32,
|
||||
name='z_first'),
|
||||
|
@ -385,10 +385,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
((a, b),) = mats
|
||||
return matmul(a, b)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "two arguments named 'mats'"):
|
||||
sq.get_concrete_function(
|
||||
[(tensor_spec.TensorSpec((None, None), dtypes.float32),
|
||||
tensor_spec.TensorSpec((None, None), dtypes.float32))])
|
||||
sq_op_autonamed = sq.get_concrete_function(
|
||||
[(tensor_spec.TensorSpec((None, None), dtypes.float32),
|
||||
tensor_spec.TensorSpec((None, None), dtypes.float32))])
|
||||
self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())
|
||||
|
||||
sq_op = sq.get_concrete_function(
|
||||
[(tensor_spec.TensorSpec((None, None), dtypes.float32,
|
||||
name='first_mat'),
|
||||
@ -398,11 +399,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||
t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, 'bound to Tensors within nested structures'):
|
||||
sq_op(t1, t2)
|
||||
out = sq_op(first_mat=t1, second_mat=t2)
|
||||
self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
|
||||
self.assertAllEqual(sq_op_autonamed(t1, t2),
|
||||
math_ops.matmul(t1, t2).numpy())
|
||||
|
||||
def testExecutingStatelessDefunConcurrently(self):
|
||||
|
||||
|
@ -171,11 +171,6 @@ class SaveTest(test.TestCase):
|
||||
input_signature=([tensor_spec.TensorSpec(None, dtypes.float32),
|
||||
tensor_spec.TensorSpec(None, dtypes.float32)],))
|
||||
root.f([constant_op.constant(1.), constant_op.constant(1.)])
|
||||
# Concrete functions must always have uniquely named Tensor inputs. Save
|
||||
# relies on this.
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "two arguments named 'x'"):
|
||||
root.f.get_concrete_function()
|
||||
|
||||
def test_nested_outputs(self):
|
||||
root = tracking.AutoTrackable()
|
||||
|
Loading…
Reference in New Issue
Block a user