Fix fetching concrete functions from def_function using only an input_signature

Used when exporting functions to SavedModel

PiperOrigin-RevId: 221100918
This commit is contained in:
Allen Lavoie 2018-11-12 09:03:08 -08:00 committed by TensorFlower Gardener
parent e7d988404a
commit 794788971b
2 changed files with 19 additions and 1 deletions

View File

@ -285,7 +285,12 @@ class PolymorphicFunction(object):
self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
self._stateless_fn._name = self._name # pylint: disable=protected-access
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
if self._input_signature is None or args or kwds:
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
# If an input signature is defined, we may need to fetch a concrete function
# without any inputs specified. In this case args and kwds should be ignored
# but running _canonicalize_function_inputs would raise an exception.
return (), {}
def __call__(self, *args, **kwds):
"""Calls the graph function."""

View File

@ -213,6 +213,19 @@ class DefFunctionTest(test.TestCase):
model = _ModelWithOptimizer()
model(x, y)
def test_concrete_function_from_signature(self):
@def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
def compute(x):
return 2. * x
concrete = compute.get_concrete_function()
self.assertAllClose(1., concrete(constant_op.constant(0.5)))
concrete = compute.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32))
self.assertAllClose(4., concrete(constant_op.constant(2.)))
if __name__ == '__main__':
ops.enable_eager_execution()