Fix bug where dispatch broke for ops that define an argument named 'op'.

PiperOrigin-RevId: 312288165
Change-Id: I714848226466815cb34e8497ebc7df471880783a
This commit is contained in:
Edward Loper 2020-05-19 08:50:47 -07:00 committed by TensorFlower Gardener
parent e0b19f6ef2
commit 9bdd084064
2 changed files with 8 additions and 5 deletions

View File

@ -959,7 +959,10 @@ void GenEagerPythonOp::AddDispatch(const string& prefix) {
strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n"); strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
strings::StrAppend(&result_, prefix, " result = _dispatch.dispatch(\n"); strings::StrAppend(&result_, prefix, " result = _dispatch.dispatch(\n");
AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_, ", ")); AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_,
", "
"(), dict("));
strings::StrAppend(&result_, prefix, " )\n");
strings::StrAppend(&result_, prefix, strings::StrAppend(&result_, prefix,
" if result is not " " if result is not "
"_dispatch.OpDispatcher.NOT_SUPPORTED:\n"); "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");

View File

@ -99,7 +99,7 @@ class GlobalOpDispatcher(object):
_GLOBAL_DISPATCHERS.append(self) _GLOBAL_DISPATCHERS.append(self)
def dispatch(op, *args, **kwargs): def dispatch(op, args, kwargs):
"""Returns the result from the first successful dispatcher for a given op. """Returns the result from the first successful dispatcher for a given op.
Calls the `handle` method of each `OpDispatcher` that has been registered Calls the `handle` method of each `OpDispatcher` that has been registered
@ -107,8 +107,8 @@ def dispatch(op, *args, **kwargs):
Args: Args:
op: Python function: the operation to dispatch for. op: Python function: the operation to dispatch for.
*args: The arguments to the operation. args: The arguments to the operation.
**kwargs: They keyword arguments to the operation. kwargs: They keyword arguments to the operation.
Returns: Returns:
The result of the operation, or `NOT_SUPPORTED` if no registered The result of the operation, or `NOT_SUPPORTED` if no registered
@ -202,7 +202,7 @@ def add_dispatch_support(target):
except (TypeError, ValueError): except (TypeError, ValueError):
# Note: convert_to_eager_tensor currently raises a ValueError, not a # Note: convert_to_eager_tensor currently raises a ValueError, not a
# TypeError, when given unexpected types. So we need to catch both. # TypeError, when given unexpected types. So we need to catch both.
result = dispatch(wrapper, *args, **kwargs) result = dispatch(wrapper, args, kwargs)
if result is not OpDispatcher.NOT_SUPPORTED: if result is not OpDispatcher.NOT_SUPPORTED:
return result return result
else: else: