diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index fb248459527..2fcfde86bad 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -448,6 +448,31 @@ class DefFunctionTest(test.TestCase): re.compile('An op outside of the function.*passed.*Const', re.DOTALL)): failing_function() + def testNonUniqueNamesGetConcreteFunction(self): + @def_function.function + def non_unique_arg_names(x, **kwargs): + a, b, c = x + d = kwargs['d'] + return a + b + c + d + + concrete = non_unique_arg_names.get_concrete_function( + (tensor_spec.TensorSpec(None, dtypes.float32), + tensor_spec.TensorSpec(None, dtypes.float32), + tensor_spec.TensorSpec(None, dtypes.float32)), + d=tensor_spec.TensorSpec(None, dtypes.float32)) + self.assertAllClose( + 10., + concrete(x=constant_op.constant(1.), + x_1=constant_op.constant(2.), + x_2=constant_op.constant(3.), + d=constant_op.constant(4.))) + self.assertAllClose( + 10., + concrete(constant_op.constant(1.), + constant_op.constant(2.), + constant_op.constant(3.), + constant_op.constant(4.))) + def testVariableCreatorScope(self): created_variables = [] captured_variables = [] diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 52f6ad908ab..24377614031 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1406,40 +1406,25 @@ class Function(object): kwargs = {} seen_names = set() captured = frozenset(graph_function.graph.internal_captures) - allowed_positional = 0 - if args: - for outer_arg in args: - # TODO(allenl): Consider allowing arguments with defaults in the Python - # function's signature to be passed as positional arguments to the - # concrete function. - if not isinstance( - outer_arg, - (ops.Tensor, resource_variable_ops.ResourceVariable, - tensor_spec.TensorSpec)): - break - allowed_positional += 1 # pylint: disable=protected-access - graph_function._num_positional_args = allowed_positional graph_function._arg_keywords = [] + prefix_counts = {} # pylint: enable=protected-access + num_positional = 0 for arg in graph_function.graph.inputs: if arg in captured: break - user_arg_name = arg.op.get_attr("_user_specified_name") - if user_arg_name in seen_names: - raise ValueError( - ("Unable to construct a concrete function for {} since some " - "arguments do not have unique names. Got two arguments named " - "'{}'. When constructing a concrete TensorFlow function from a " - "Python function which takes nested structures or variadic " - "positional arguments, pass unique names to tf.TensorSpec objects " - "used to identify these Tensor inputs. These names may then be " - "used as keyword arguments to the concrete function.") - .format( - self._python_function, - compat.as_str(arg.op.get_attr("_user_specified_name")))) - seen_names.add(user_arg_name) - graph_function._arg_keywords.append(user_arg_name) # pylint: disable=protected-access + num_positional += 1 + user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name")) + proposal = user_arg_name + while proposal in seen_names: + index = prefix_counts.get(user_arg_name, 1) + proposal = "{}_{}".format(user_arg_name, index) + prefix_counts[user_arg_name] = index + 1 + seen_names.add(proposal) + graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access + # Anything can be a positional argument, in the same order as .inputs + graph_function._num_positional_args = num_positional # pylint: disable=protected-access return graph_function def __get__(self, instance, owner): diff --git a/tensorflow/python/eager/function_argument_naming_test.py b/tensorflow/python/eager/function_argument_naming_test.py index 08a50a8f513..4e6a60e0d27 100644 --- a/tensorflow/python/eager/function_argument_naming_test.py +++ b/tensorflow/python/eager/function_argument_naming_test.py @@ -100,13 +100,6 @@ class ArgumentNamingTests(test.TestCase, parameterized.TestCase): self.assertEqual({'alpha', 'beta'}, set(fn_op.graph.structured_outputs.keys())) - with self.assertRaisesRegexp(ValueError, "two arguments named 'z'"): - fn.get_concrete_function( - z=(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)), - y=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32, - name='custom'), - x=4.) fn_op2 = fn.get_concrete_function( z=(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.float32, name='z_first'), diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 9ef84b86d75..84c6f529069 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -385,10 +385,11 @@ class FunctionTest(test.TestCase, parameterized.TestCase): ((a, b),) = mats return matmul(a, b) - with self.assertRaisesRegexp(ValueError, "two arguments named 'mats'"): - sq.get_concrete_function( - [(tensor_spec.TensorSpec((None, None), dtypes.float32), - tensor_spec.TensorSpec((None, None), dtypes.float32))]) + sq_op_autonamed = sq.get_concrete_function( + [(tensor_spec.TensorSpec((None, None), dtypes.float32), + tensor_spec.TensorSpec((None, None), dtypes.float32))]) + self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list()) + sq_op = sq.get_concrete_function( [(tensor_spec.TensorSpec((None, None), dtypes.float32, name='first_mat'), @@ -398,11 +399,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase): t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]]) - with self.assertRaisesRegexp( - TypeError, 'bound to Tensors within nested structures'): - sq_op(t1, t2) out = sq_op(first_mat=t1, second_mat=t2) self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy()) + self.assertAllEqual(sq_op_autonamed(t1, t2), + math_ops.matmul(t1, t2).numpy()) def testExecutingStatelessDefunConcurrently(self): diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 14cb5abc075..a2ee86f6a9a 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -171,11 +171,6 @@ class SaveTest(test.TestCase): input_signature=([tensor_spec.TensorSpec(None, dtypes.float32), tensor_spec.TensorSpec(None, dtypes.float32)],)) root.f([constant_op.constant(1.), constant_op.constant(1.)]) - # Concrete functions must always have uniquely named Tensor inputs. Save - # relies on this. - with self.assertRaisesRegexp( - ValueError, "two arguments named 'x'"): - root.f.get_concrete_function() def test_nested_outputs(self): root = tracking.AutoTrackable()