Rolling change forward again: "Add support for global operation dispatchers. (This is intended for use by TF-internal classes only.)"

PiperOrigin-RevId: 311373578
Change-Id: Ib40cee66bbb1395c8997db3c1eb3f5914425a280
Edward Loper 2020-05-13 11:49:10 -07:00 committed by TensorFlower Gardener
parent b69595c6c7
commit 90f3a1eb38
2 changed files with 77 additions and 2 deletions

tensorflow/python/util/dispatch.py

@@ -39,6 +39,10 @@ from tensorflow.python.util import tf_inspect

DISPATCH_ATTR = "_tf_dispatchers"

# OpDispatchers which should be used for all operations.
_GLOBAL_DISPATCHERS = []


class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.
@@ -82,6 +86,19 @@ class OpDispatcher(object):
    getattr(op, DISPATCH_ATTR).append(self)


class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers."""

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):
    """Handle the specified operation with the specified arguments."""

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)


def dispatch(op, *args, **kwargs):
  """Returns the result from the first successful dispatcher for a given op.
@@ -101,6 +118,10 @@ def dispatch(op, *args, **kwargs):
    result = dispatcher.handle(args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  for dispatcher in _GLOBAL_DISPATCHERS:
    result = dispatcher.handle(op, args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  return OpDispatcher.NOT_SUPPORTED
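
For readers unfamiliar with the hook: a global dispatcher is registered process-wide and is consulted by dispatch() for every operation, but only after all of the op's own OpDispatchers have returned NOT_SUPPORTED. Below is a minimal sketch of a subclass, assuming the module above is importable as tensorflow.python.util.dispatch (an internal API, so subject to change); the EchoDispatcher name is hypothetical:

from tensorflow.python.util import dispatch

class EchoDispatcher(dispatch.GlobalOpDispatcher):
  """Hypothetical dispatcher that handles every op by describing it."""

  def handle(self, op, args, kwargs):
    # Unlike the per-op OpDispatcher.handle(args, kwargs), a global
    # dispatcher also receives `op`, since it is consulted for all ops.
    return "{}(args={!r}, kwargs={!r})".format(op.__name__, args, kwargs)

EchoDispatcher().register()  # Appends self to dispatch._GLOBAL_DISPATCHERS.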

tensorflow/python/util/dispatch_test.py

@@ -45,6 +45,47 @@ def test_op(x, y, z):
  return x + (2 * y) + (3 * z)


class TensorTracer(object):
  """An object used to trace TensorFlow graphs.

  This is an example class that is used to test global op dispatchers. The
  global op dispatcher for TensorTracers is defined below.
  """

  def __init__(self, name, args=None, kwargs=None):
    self.name = name
    self.args = args
    self.kwargs = kwargs

  def __repr__(self):
    if self.args is None and self.kwargs is None:
      return self.name
    else:
      args = [str(x) for x in self.args]
      args += sorted(
          ["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
      return "{}({})".format(self.name, ", ".join(args))


class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
  """Global op dispatcher for TensorTracer."""

  def handle(self, op, args, kwargs):
    # Dispatcher only applies if at least one arg is a TensorTracer.
    if not (any(self.is_tensor_tracer_arg(x) for x in args) or
            any(self.is_tensor_tracer_arg(x) for x in kwargs.values())):
      return self.NOT_SUPPORTED
    return TensorTracer(op.__name__, args, kwargs)

  def is_tensor_tracer_arg(self, value):
    if isinstance(value, TensorTracer):
      return True
    if isinstance(value, (list, tuple)):
      if any(isinstance(x, TensorTracer) for x in value):
        return True


@test_util.run_all_in_graph_and_eager_modes
class DispatchTest(test_util.TensorFlowTestCase):
@@ -131,8 +172,21 @@ class DispatchTest(test_util.TensorFlowTestCase):
          r".*some_op \(from __main__\) is deprecated and will be "
          "removed in a future version.*")

  def testGlobalDispatcher(self):
    original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
    try:
      TensorTracerOpDispatcher().register()

      x = TensorTracer("x")
      y = TensorTracer("y")
      trace = math_ops.reduce_sum(math_ops.add(math_ops.abs(x), y), axis=3)
      self.assertEqual(
          str(trace), "reduce_sum(add(name=None, x=abs(x), y=y), axis=3)")
    finally:
      # Clean up.
      dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
if __name__ == "__main__":
  googletest.main()
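
As a sanity check on the expected string in testGlobalDispatcher, the trace can be reproduced with the TensorTracer repr logic alone (no TensorFlow needed, assuming the TensorTracer class above is in scope). How math_ops.add forwards x, y, and name as keyword arguments is an assumption here, inferred from the expected output rather than stated by the diff:

x = TensorTracer("x")
y = TensorTracer("y")
abs_x = TensorTracer("abs", args=(x,), kwargs={})
add = TensorTracer("add", args=(), kwargs={"x": abs_x, "y": y, "name": None})
trace = TensorTracer("reduce_sum", args=(add,), kwargs={"axis": 3})
print(trace)  # reduce_sum(add(name=None, x=abs(x), y=y), axis=3)

Keyword arguments are rendered in sorted order by __repr__, which is why name=None precedes x and y in the expected string.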