From 90f3a1eb381e644ac5d0f3fd126af25f856820a9 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Wed, 13 May 2020 11:49:10 -0700 Subject: [PATCH] Rolling change forward again: "Add support for global operation dispatchers. (This is intended for use by TF-internal classes only.)" PiperOrigin-RevId: 311373578 Change-Id: Ib40cee66bbb1395c8997db3c1eb3f5914425a280 --- tensorflow/python/util/dispatch.py | 21 +++++++++ tensorflow/python/util/dispatch_test.py | 58 ++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/util/dispatch.py b/tensorflow/python/util/dispatch.py index e94e3345348..3868da14b44 100644 --- a/tensorflow/python/util/dispatch.py +++ b/tensorflow/python/util/dispatch.py @@ -39,6 +39,10 @@ from tensorflow.python.util import tf_inspect DISPATCH_ATTR = "_tf_dispatchers" +# OpDispatchers which should be used for all operations. +_GLOBAL_DISPATCHERS = [] + + class OpDispatcher(object): """Abstract base class for TensorFlow operator dispatchers. @@ -82,6 +86,19 @@ class OpDispatcher(object): getattr(op, DISPATCH_ATTR).append(self) +class GlobalOpDispatcher(object): + """Abstract base class for TensorFlow global operator dispatchers.""" + + NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED + + def handle(self, op, args, kwargs): + """Handle the specified operation with the specified arguments.""" + + def register(self): + """Register this dispatcher as a handler for all ops.""" + _GLOBAL_DISPATCHERS.append(self) + + def dispatch(op, *args, **kwargs): """Returns the result from the first successful dispatcher for a given op. @@ -101,6 +118,10 @@ def dispatch(op, *args, **kwargs): result = dispatcher.handle(args, kwargs) if result is not OpDispatcher.NOT_SUPPORTED: return result + for dispatcher in _GLOBAL_DISPATCHERS: + result = dispatcher.handle(op, args, kwargs) + if result is not OpDispatcher.NOT_SUPPORTED: + return result return OpDispatcher.NOT_SUPPORTED diff --git a/tensorflow/python/util/dispatch_test.py b/tensorflow/python/util/dispatch_test.py index 89999fcf843..bd35c391924 100644 --- a/tensorflow/python/util/dispatch_test.py +++ b/tensorflow/python/util/dispatch_test.py @@ -45,6 +45,47 @@ def test_op(x, y, z): return x + (2 * y) + (3 * z) +class TensorTracer(object): + """An object used to trace TensorFlow graphs. + + This is an example class that is used to test global op dispatchers. The + global op dispatcher for TensorTracers is defined below. + """ + + def __init__(self, name, args=None, kwargs=None): + self.name = name + self.args = args + self.kwargs = kwargs + + def __repr__(self): + if self.args is None and self.kwargs is None: + return self.name + else: + args = [str(x) for x in self.args] + args += sorted( + ["{}={}".format(name, x) for (name, x) in self.kwargs.items()]) + return "{}({})".format(self.name, ", ".join(args)) + + +class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher): + """Global op dispatcher for TensorTracer.""" + + def handle(self, op, args, kwargs): + # Dispatcher only applies if at least one arg is a TensorTracer. + if not (any(self.is_tensor_tracer_arg(x) for x in args) or + any(self.is_tensor_tracer_arg(x) for x in kwargs.values())): + return self.NOT_SUPPORTED + + return TensorTracer(op.__name__, args, kwargs) + + def is_tensor_tracer_arg(self, value): + if isinstance(value, TensorTracer): + return True + if isinstance(value, (list, tuple)): + if any(isinstance(x, TensorTracer) for x in value): + return True + + @test_util.run_all_in_graph_and_eager_modes class DispatchTest(test_util.TensorFlowTestCase): @@ -131,8 +172,21 @@ class DispatchTest(test_util.TensorFlowTestCase): r".*some_op \(from __main__\) is deprecated and will be " "removed in a future version.*") + def testGlobalDispatcher(self): + original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS + try: + TensorTracerOpDispatcher().register() + + x = TensorTracer("x") + y = TensorTracer("y") + trace = math_ops.reduce_sum(math_ops.add(math_ops.abs(x), y), axis=3) + self.assertEqual( + str(trace), "reduce_sum(add(name=None, x=abs(x), y=y), axis=3)") + + finally: + # Clean up. + dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers + if __name__ == "__main__": googletest.main() - -