Fix bug where dispatch broke for ops that define an argument named 'op'.
PiperOrigin-RevId: 312288165 Change-Id: I714848226466815cb34e8497ebc7df471880783a
This commit is contained in:
parent
e0b19f6ef2
commit
9bdd084064
|
@ -959,7 +959,10 @@ void GenEagerPythonOp::AddDispatch(const string& prefix) {
|
||||||
|
|
||||||
strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
|
strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
|
||||||
strings::StrAppend(&result_, prefix, " result = _dispatch.dispatch(\n");
|
strings::StrAppend(&result_, prefix, " result = _dispatch.dispatch(\n");
|
||||||
AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_, ", "));
|
AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_,
|
||||||
|
", "
|
||||||
|
"(), dict("));
|
||||||
|
strings::StrAppend(&result_, prefix, " )\n");
|
||||||
strings::StrAppend(&result_, prefix,
|
strings::StrAppend(&result_, prefix,
|
||||||
" if result is not "
|
" if result is not "
|
||||||
"_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
|
"_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
|
||||||
|
|
|
@ -99,7 +99,7 @@ class GlobalOpDispatcher(object):
|
||||||
_GLOBAL_DISPATCHERS.append(self)
|
_GLOBAL_DISPATCHERS.append(self)
|
||||||
|
|
||||||
|
|
||||||
def dispatch(op, *args, **kwargs):
|
def dispatch(op, args, kwargs):
|
||||||
"""Returns the result from the first successful dispatcher for a given op.
|
"""Returns the result from the first successful dispatcher for a given op.
|
||||||
|
|
||||||
Calls the `handle` method of each `OpDispatcher` that has been registered
|
Calls the `handle` method of each `OpDispatcher` that has been registered
|
||||||
|
@ -107,8 +107,8 @@ def dispatch(op, *args, **kwargs):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
op: Python function: the operation to dispatch for.
|
op: Python function: the operation to dispatch for.
|
||||||
*args: The arguments to the operation.
|
args: The arguments to the operation.
|
||||||
**kwargs: They keyword arguments to the operation.
|
kwargs: They keyword arguments to the operation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The result of the operation, or `NOT_SUPPORTED` if no registered
|
The result of the operation, or `NOT_SUPPORTED` if no registered
|
||||||
|
@ -202,7 +202,7 @@ def add_dispatch_support(target):
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
# Note: convert_to_eager_tensor currently raises a ValueError, not a
|
# Note: convert_to_eager_tensor currently raises a ValueError, not a
|
||||||
# TypeError, when given unexpected types. So we need to catch both.
|
# TypeError, when given unexpected types. So we need to catch both.
|
||||||
result = dispatch(wrapper, *args, **kwargs)
|
result = dispatch(wrapper, args, kwargs)
|
||||||
if result is not OpDispatcher.NOT_SUPPORTED:
|
if result is not OpDispatcher.NOT_SUPPORTED:
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue