213 lines
7.3 KiB
Python
213 lines
7.3 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Type-based dispatch for TensorFlow ops.
|
|
|
|
"Operation dispatchers" can be used to override the behavior for TensorFlow ops
|
|
when they are called with otherwise unsupported argument types. In particular,
|
|
when an operation is called with arguments that would cause it to raise a
|
|
TypeError, it falls back on its registered operation dispatchers. If any
|
|
registered dispatchers can handle the arguments, then its result is returned.
|
|
Otherwise, the original TypeError is raised.
|
|
|
|
By default, dispatch support is added to the generated op wrappers for any
visible ops. Ops that are implemented in Python can opt in to
dispatch support using the `add_dispatch_support` decorator.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import itertools
|
|
|
|
from tensorflow.python.util import tf_decorator
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
# Private function attribute used to store the list of `OpDispatcher`s
# registered for an op (see `OpDispatcher.register`).
DISPATCH_ATTR = "_tf_dispatchers"


# OpDispatchers which should be used for all operations.
_GLOBAL_DISPATCHERS = []
|
|
|
|
|
|
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation: whenever the handler signals that it can process the
  operation's arguments (by returning anything other than
  `OpDispatcher.NOT_SUPPORTED`), its result is used in place of the op's.
  """

  # Sentinel returned by `handle` to indicate that this dispatcher does not
  # support a given set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    Subclasses override this to return an appropriate value (or raise an
    appropriate exception) when they can handle the given arguments.

    Args:
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled.
        Must have a dispatch list (which is added automatically for generated
        ops, and can be added to Python ops using the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` does not have a dispatch list.
    """
    if hasattr(op, DISPATCH_ATTR):
      getattr(op, DISPATCH_ATTR).append(self)
    else:
      raise AssertionError("Dispatching not enabled for %s" % op)
|
|
|
|
|
|
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  Unlike `OpDispatcher`, a global dispatcher is consulted for every
  operation, and its `handle` method also receives the op itself.
  """

  # Shared sentinel: the same object as `OpDispatcher.NOT_SUPPORTED`, so
  # results from both dispatcher kinds can be compared uniformly.
  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)

  def handle(self, op, args, kwargs):
    """Handle the specified operation with the specified arguments."""
|
|
|
|
|
|
def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  Calls the `handle` method of each `OpDispatcher` that has been registered
  to handle `op` (and then each global dispatcher), and returns the value
  from the first handler that reports success.

  Args:
    op: Python function: the operation to dispatch for.
    args: The positional arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.
  """
  # Op-specific dispatchers take precedence over global dispatchers.
  for op_dispatcher in getattr(op, DISPATCH_ATTR):
    outcome = op_dispatcher.handle(args, kwargs)
    if outcome is not OpDispatcher.NOT_SUPPORTED:
      return outcome

  for global_dispatcher in _GLOBAL_DISPATCHERS:
    outcome = global_dispatcher.handle(op, args, kwargs)
    if outcome is not OpDispatcher.NOT_SUPPORTED:
      return outcome

  return OpDispatcher.NOT_SUPPORTED
|
|
|
|
|
|
class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles op if any arguments have a specified type.

  Checks the positional and keyword argument values (including the elements
  of list or tuple arguments); if any value is an instance of the configured
  type(s), the call is delegated to an override function.
  """

  def __init__(self, override_func, types):
    # Type (or tuple of types) whose presence triggers the override.
    self._types = types
    # Function invoked in place of the op when a matching type is found.
    self._override_func = override_func

  def _handles(self, args, kwargs):
    """Returns True if any argument value matches `self._types`."""
    def matches(value):
      if isinstance(value, self._types):
        return True
      return (isinstance(value, (list, tuple)) and
              any(isinstance(item, self._types) for item in value))

    return any(matches(arg) for arg in itertools.chain(args, kwargs.values()))

  def handle(self, args, kwargs):
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)
|
|
|
|
|
|
# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is used to override `op` if any of the arguments or
  keyword arguments (including elements of lists or tuples) have one of the
  specified types.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.
  """

  def decorator(func):
    # The override must be callable with exactly the same signature as the
    # op it replaces, so that dispatched calls behave like direct calls.
    func_spec = tf_inspect.getargspec(func)
    op_spec = tf_inspect.getargspec(op)
    if func_spec != op_spec:
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    _TypeBasedDispatcher(func, types).register(op)
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield
|
|
|
|
|
|
def add_dispatch_list(target):
  """Decorator that adds a dispatch_list attribute to an op.

  Args:
    target: The op (a Python function) to enable dispatch for.

  Returns:
    `target`, with an empty dispatch list stored under `DISPATCH_ATTR`.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  if not hasattr(target, DISPATCH_ATTR):
    setattr(target, DISPATCH_ATTR, [])
    return target
  raise AssertionError("%s already has a dispatch list" % target)
|
|
|
|
|
|
def add_dispatch_support(target):
  """Decorator that adds a dispatch handling wrapper to an op.

  The returned wrapper first calls `target` directly; if that raises a
  `TypeError` or `ValueError`, the registered dispatchers are given a chance
  to handle the call before the original exception is re-raised.
  """

  def op_dispatch_handler(*args, **kwargs):
    """Call target, and fall back on dispatchers if there is a TypeError."""
    try:
      return target(*args, **kwargs)
    except (TypeError, ValueError):
      # Note: convert_to_eager_tensor currently raises a ValueError, not a
      # TypeError, when given unexpected types.  So we need to catch both.
      result = dispatch(op_dispatch_handler, args, kwargs)
      if result is OpDispatcher.NOT_SUPPORTED:
        raise
      return result

  add_dispatch_list(op_dispatch_handler)
  return tf_decorator.make_decorator(target, op_dispatch_handler)
|