Add support for global operation dispatchers. (This is intended for use by TF-internal classes only.)
PiperOrigin-RevId: 311350209 Change-Id: Ib095f019fc6825409b490d7dec7e86116955b746
This commit is contained in:
parent
902ffede1f
commit
21b04b6fe0
|
@ -39,10 +39,6 @@ from tensorflow.python.util import tf_inspect
|
|||
DISPATCH_ATTR = "_tf_dispatchers"
|
||||
|
||||
|
||||
# OpDispatchers which should be used for all operations.
|
||||
_GLOBAL_DISPATCHERS = []
|
||||
|
||||
|
||||
class OpDispatcher(object):
|
||||
"""Abstract base class for TensorFlow operator dispatchers.
|
||||
|
||||
|
@ -86,19 +82,6 @@ class OpDispatcher(object):
|
|||
getattr(op, DISPATCH_ATTR).append(self)
|
||||
|
||||
|
||||
class GlobalOpDispatcher(object):
|
||||
"""Abstract base class for TensorFlow global operator dispatchers."""
|
||||
|
||||
NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED
|
||||
|
||||
def handle(self, op, args, kwargs):
|
||||
"""Handle the specified operation with the specified arguments."""
|
||||
|
||||
def register(self):
|
||||
"""Register this dispatcher as a handler for all ops."""
|
||||
_GLOBAL_DISPATCHERS.append(self)
|
||||
|
||||
|
||||
def dispatch(op, *args, **kwargs):
|
||||
"""Returns the result from the first successful dispatcher for a given op.
|
||||
|
||||
|
@ -118,10 +101,6 @@ def dispatch(op, *args, **kwargs):
|
|||
result = dispatcher.handle(args, kwargs)
|
||||
if result is not OpDispatcher.NOT_SUPPORTED:
|
||||
return result
|
||||
for dispatcher in _GLOBAL_DISPATCHERS:
|
||||
result = dispatcher.handle(op, args, kwargs)
|
||||
if result is not OpDispatcher.NOT_SUPPORTED:
|
||||
return result
|
||||
return OpDispatcher.NOT_SUPPORTED
|
||||
|
||||
|
||||
|
|
|
@ -45,47 +45,6 @@ def test_op(x, y, z):
|
|||
return x + (2 * y) + (3 * z)
|
||||
|
||||
|
||||
class TensorTracer(object):
|
||||
"""An object used to trace TensorFlow graphs.
|
||||
|
||||
This is an example class that is used to test global op dispatchers. The
|
||||
global op dispatcher for TensorTracers is defined below.
|
||||
"""
|
||||
|
||||
def __init__(self, name, args=None, kwargs=None):
|
||||
self.name = name
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __repr__(self):
|
||||
if self.args is None and self.kwargs is None:
|
||||
return self.name
|
||||
else:
|
||||
args = [str(x) for x in self.args]
|
||||
args += sorted(
|
||||
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
|
||||
return "{}({})".format(self.name, ", ".join(args))
|
||||
|
||||
|
||||
class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
|
||||
"""Global op dispatcher for TensorTracer."""
|
||||
|
||||
def handle(self, op, args, kwargs):
|
||||
# Dispatcher only applies if at least one arg is a TensorTracer.
|
||||
if not (any(self.is_tensor_tracer_arg(x) for x in args) or
|
||||
any(self.is_tensor_tracer_arg(x) for x in kwargs.values())):
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
return TensorTracer(op.__name__, args, kwargs)
|
||||
|
||||
def is_tensor_tracer_arg(self, value):
|
||||
if isinstance(value, TensorTracer):
|
||||
return True
|
||||
if isinstance(value, (list, tuple)):
|
||||
if any(isinstance(x, TensorTracer) for x in value):
|
||||
return True
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DispatchTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -172,21 +131,8 @@ class DispatchTest(test_util.TensorFlowTestCase):
|
|||
r".*some_op \(from __main__\) is deprecated and will be "
|
||||
"removed in a future version.*")
|
||||
|
||||
def testGlobalDispatcher(self):
|
||||
original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
|
||||
try:
|
||||
TensorTracerOpDispatcher().register()
|
||||
|
||||
x = TensorTracer("x")
|
||||
y = TensorTracer("y")
|
||||
trace = math_ops.reduce_sum(math_ops.add(math_ops.abs(x), y), axis=3)
|
||||
self.assertEqual(
|
||||
str(trace), "reduce_sum(add(name=None, x=abs(x), y=y), axis=3)")
|
||||
|
||||
finally:
|
||||
# Clean up.
|
||||
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue