From 794788971b5dd4e792725410cd5cb53c15a5ab41 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 12 Nov 2018 09:03:08 -0800 Subject: [PATCH] Fix fetching concrete functions from def_function using only an input_signature Used when exporting functions to SavedModel PiperOrigin-RevId: 221100918 --- tensorflow/python/eager/def_function.py | 7 ++++++- tensorflow/python/eager/def_function_test.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 0b138acc86a..6f475509d3a 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -285,7 +285,12 @@ class PolymorphicFunction(object): self._stateless_fn = self._defun_with_scope(invalid_creator_scope) self._stateless_fn._name = self._name # pylint: disable=protected-access - return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access + if self._input_signature is None or args or kwds: + return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access + # If an input signature is defined, we may need to fetch a concrete function + # without any inputs specified. In this case args and kwds should be ignored + # but running _canonicalize_function_inputs would raise an exception. + return (), {} def __call__(self, *args, **kwds): """Calls the graph function.""" diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 07e03a07401..f0f71a219e6 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -213,6 +213,19 @@ class DefFunctionTest(test.TestCase): model = _ModelWithOptimizer() model(x, y) + def test_concrete_function_from_signature(self): + + @def_function.function( + input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) + def compute(x): + return 2. * x + + concrete = compute.get_concrete_function() + self.assertAllClose(1., concrete(constant_op.constant(0.5))) + concrete = compute.get_concrete_function( + tensor_spec.TensorSpec(None, dtypes.float32)) + self.assertAllClose(4., concrete(constant_op.constant(2.))) + if __name__ == '__main__': ops.enable_eager_execution()