Fix fetching concrete functions from def_function using only an input_signature
Used when exporting functions to SavedModel PiperOrigin-RevId: 221100918
This commit is contained in:
parent
e7d988404a
commit
794788971b
@ -285,7 +285,12 @@ class PolymorphicFunction(object):
|
||||
|
||||
self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
|
||||
self._stateless_fn._name = self._name # pylint: disable=protected-access
|
||||
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
|
||||
if self._input_signature is None or args or kwds:
|
||||
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
|
||||
# If an input signature is defined, we may need to fetch a concrete function
|
||||
# without any inputs specified. In this case args and kwds should be ignored
|
||||
# but running _canonicalize_function_inputs would raise an exception.
|
||||
return (), {}
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
"""Calls the graph function."""
|
||||
|
@ -213,6 +213,19 @@ class DefFunctionTest(test.TestCase):
|
||||
model = _ModelWithOptimizer()
|
||||
model(x, y)
|
||||
|
||||
def test_concrete_function_from_signature(self):
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
def compute(x):
|
||||
return 2. * x
|
||||
|
||||
concrete = compute.get_concrete_function()
|
||||
self.assertAllClose(1., concrete(constant_op.constant(0.5)))
|
||||
concrete = compute.get_concrete_function(
|
||||
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||
self.assertAllClose(4., concrete(constant_op.constant(2.)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user