From 9bdd08406461f6988cffb48100ab79994b50ee64 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Tue, 19 May 2020 08:50:47 -0700 Subject: [PATCH] Fix bug where dispatch broke for ops that define an argument named 'op'. PiperOrigin-RevId: 312288165 Change-Id: I714848226466815cb34e8497ebc7df471880783a --- tensorflow/python/framework/python_op_gen.cc | 5 ++++- tensorflow/python/util/dispatch.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 857cc7b6638..ca0c5d9ef1a 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -959,7 +959,10 @@ void GenEagerPythonOp::AddDispatch(const string& prefix) { strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n"); strings::StrAppend(&result_, prefix, " result = _dispatch.dispatch(\n"); - AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_, ", ")); + AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_, + ", " + "(), dict(")); + strings::StrAppend(&result_, prefix, " )\n"); strings::StrAppend(&result_, prefix, " if result is not " "_dispatch.OpDispatcher.NOT_SUPPORTED:\n"); diff --git a/tensorflow/python/util/dispatch.py b/tensorflow/python/util/dispatch.py index 3868da14b44..51dfe3793ae 100644 --- a/tensorflow/python/util/dispatch.py +++ b/tensorflow/python/util/dispatch.py @@ -99,7 +99,7 @@ class GlobalOpDispatcher(object): _GLOBAL_DISPATCHERS.append(self) -def dispatch(op, *args, **kwargs): +def dispatch(op, args, kwargs): """Returns the result from the first successful dispatcher for a given op. Calls the `handle` method of each `OpDispatcher` that has been registered @@ -107,8 +107,8 @@ def dispatch(op, *args, **kwargs): Args: op: Python function: the operation to dispatch for. - *args: The arguments to the operation. - **kwargs: They keyword arguments to the operation. + args: The arguments to the operation. + kwargs: They keyword arguments to the operation. Returns: The result of the operation, or `NOT_SUPPORTED` if no registered @@ -202,7 +202,7 @@ def add_dispatch_support(target): except (TypeError, ValueError): # Note: convert_to_eager_tensor currently raises a ValueError, not a # TypeError, when given unexpected types. So we need to catch both. - result = dispatch(wrapper, *args, **kwargs) + result = dispatch(wrapper, args, kwargs) if result is not OpDispatcher.NOT_SUPPORTED: return result else: