Fix fetching concrete functions from def_function using only an input_signature
Used when exporting functions to SavedModel PiperOrigin-RevId: 221100918
This commit is contained in:
parent
e7d988404a
commit
794788971b
@ -285,7 +285,12 @@ class PolymorphicFunction(object):
|
|||||||
|
|
||||||
self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
|
self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
|
||||||
self._stateless_fn._name = self._name # pylint: disable=protected-access
|
self._stateless_fn._name = self._name # pylint: disable=protected-access
|
||||||
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
|
if self._input_signature is None or args or kwds:
|
||||||
|
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
|
||||||
|
# If an input signature is defined, we may need to fetch a concrete function
|
||||||
|
# without any inputs specified. In this case args and kwds should be ignored
|
||||||
|
# but running _canonicalize_function_inputs would raise an exception.
|
||||||
|
return (), {}
|
||||||
|
|
||||||
def __call__(self, *args, **kwds):
|
def __call__(self, *args, **kwds):
|
||||||
"""Calls the graph function."""
|
"""Calls the graph function."""
|
||||||
|
@ -213,6 +213,19 @@ class DefFunctionTest(test.TestCase):
|
|||||||
model = _ModelWithOptimizer()
|
model = _ModelWithOptimizer()
|
||||||
model(x, y)
|
model(x, y)
|
||||||
|
|
||||||
|
def test_concrete_function_from_signature(self):
|
||||||
|
|
||||||
|
@def_function.function(
|
||||||
|
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||||
|
def compute(x):
|
||||||
|
return 2. * x
|
||||||
|
|
||||||
|
concrete = compute.get_concrete_function()
|
||||||
|
self.assertAllClose(1., concrete(constant_op.constant(0.5)))
|
||||||
|
concrete = compute.get_concrete_function(
|
||||||
|
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||||
|
self.assertAllClose(4., concrete(constant_op.constant(2.)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
Reference in New Issue
Block a user