From 80aaa339819bda86471784c2636a4d82a36907ea Mon Sep 17 00:00:00 2001 From: Vojtech Bardiovsky Date: Tue, 5 Mar 2019 16:03:10 -0800 Subject: [PATCH] Remove special condition on canonicalization. The code path where this was necessary no longer exists. PiperOrigin-RevId: 236939023 --- tensorflow/python/eager/def_function.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index d2f8fb2eb54..a39fb82b34b 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -326,15 +326,6 @@ class Function(object): autograph=self._autograph, experimental_autograph_options=self._experimental_autograph_options) - def _canonicalize_function_inputs(self, args, kwds): - """Canonicalize the inputs to the Python function.""" - if self.input_signature is None or args or kwds: - return self._function_spec.canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access - # If an input signature is defined, we may need to fetch a concrete function - # without any inputs specified. In this case args and kwds should be ignored - # but running _canonicalize_function_inputs would raise an exception. - return (), {} - def _initialize(self, args, kwds, add_initializers_to=None): """Initializes, on the first call. @@ -439,7 +430,8 @@ class Function(object): # stateless function. return self._stateless_fn(*args, **kwds) else: - canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds) + canon_args, canon_kwds = self._function_spec.canonicalize_function_inputs( + *args, **kwds) # If we did not create any variables the trace we have is good enough. return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access @@ -496,7 +488,8 @@ class Function(object): # We've created variables and are unable to lift the initialization graphs, # so we fall back to initializing with conds while running the function. - canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds) + canon_args, canon_kwds = self._function_spec.canonicalize_function_inputs( + *args, **kwds) return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds) @property