Rolling change forward again: "Add support for global operation dispatchers. (This is intended for use by TF-internal classes only.)"
PiperOrigin-RevId: 311373578 Change-Id: Ib40cee66bbb1395c8997db3c1eb3f5914425a280
This commit is contained in:
parent
b69595c6c7
commit
90f3a1eb38
|
@ -39,6 +39,10 @@ from tensorflow.python.util import tf_inspect
|
||||||
DISPATCH_ATTR = "_tf_dispatchers"
|
DISPATCH_ATTR = "_tf_dispatchers"
|
||||||
|
|
||||||
|
|
||||||
|
# OpDispatchers which should be used for all operations.
|
||||||
|
_GLOBAL_DISPATCHERS = []
|
||||||
|
|
||||||
|
|
||||||
class OpDispatcher(object):
|
class OpDispatcher(object):
|
||||||
"""Abstract base class for TensorFlow operator dispatchers.
|
"""Abstract base class for TensorFlow operator dispatchers.
|
||||||
|
|
||||||
|
@ -82,6 +86,19 @@ class OpDispatcher(object):
|
||||||
getattr(op, DISPATCH_ATTR).append(self)
|
getattr(op, DISPATCH_ATTR).append(self)
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalOpDispatcher(object):
|
||||||
|
"""Abstract base class for TensorFlow global operator dispatchers."""
|
||||||
|
|
||||||
|
NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED
|
||||||
|
|
||||||
|
def handle(self, op, args, kwargs):
|
||||||
|
"""Handle the specified operation with the specified arguments."""
|
||||||
|
|
||||||
|
def register(self):
|
||||||
|
"""Register this dispatcher as a handler for all ops."""
|
||||||
|
_GLOBAL_DISPATCHERS.append(self)
|
||||||
|
|
||||||
|
|
||||||
def dispatch(op, *args, **kwargs):
|
def dispatch(op, *args, **kwargs):
|
||||||
"""Returns the result from the first successful dispatcher for a given op.
|
"""Returns the result from the first successful dispatcher for a given op.
|
||||||
|
|
||||||
|
@ -101,6 +118,10 @@ def dispatch(op, *args, **kwargs):
|
||||||
result = dispatcher.handle(args, kwargs)
|
result = dispatcher.handle(args, kwargs)
|
||||||
if result is not OpDispatcher.NOT_SUPPORTED:
|
if result is not OpDispatcher.NOT_SUPPORTED:
|
||||||
return result
|
return result
|
||||||
|
for dispatcher in _GLOBAL_DISPATCHERS:
|
||||||
|
result = dispatcher.handle(op, args, kwargs)
|
||||||
|
if result is not OpDispatcher.NOT_SUPPORTED:
|
||||||
|
return result
|
||||||
return OpDispatcher.NOT_SUPPORTED
|
return OpDispatcher.NOT_SUPPORTED
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,47 @@ def test_op(x, y, z):
|
||||||
return x + (2 * y) + (3 * z)
|
return x + (2 * y) + (3 * z)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorTracer(object):
|
||||||
|
"""An object used to trace TensorFlow graphs.
|
||||||
|
|
||||||
|
This is an example class that is used to test global op dispatchers. The
|
||||||
|
global op dispatcher for TensorTracers is defined below.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name, args=None, kwargs=None):
|
||||||
|
self.name = name
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self.args is None and self.kwargs is None:
|
||||||
|
return self.name
|
||||||
|
else:
|
||||||
|
args = [str(x) for x in self.args]
|
||||||
|
args += sorted(
|
||||||
|
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
|
||||||
|
return "{}({})".format(self.name, ", ".join(args))
|
||||||
|
|
||||||
|
|
||||||
|
class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
|
||||||
|
"""Global op dispatcher for TensorTracer."""
|
||||||
|
|
||||||
|
def handle(self, op, args, kwargs):
|
||||||
|
# Dispatcher only applies if at least one arg is a TensorTracer.
|
||||||
|
if not (any(self.is_tensor_tracer_arg(x) for x in args) or
|
||||||
|
any(self.is_tensor_tracer_arg(x) for x in kwargs.values())):
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
|
||||||
|
return TensorTracer(op.__name__, args, kwargs)
|
||||||
|
|
||||||
|
def is_tensor_tracer_arg(self, value):
|
||||||
|
if isinstance(value, TensorTracer):
|
||||||
|
return True
|
||||||
|
if isinstance(value, (list, tuple)):
|
||||||
|
if any(isinstance(x, TensorTracer) for x in value):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class DispatchTest(test_util.TensorFlowTestCase):
|
class DispatchTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@ -131,8 +172,21 @@ class DispatchTest(test_util.TensorFlowTestCase):
|
||||||
r".*some_op \(from __main__\) is deprecated and will be "
|
r".*some_op \(from __main__\) is deprecated and will be "
|
||||||
"removed in a future version.*")
|
"removed in a future version.*")
|
||||||
|
|
||||||
|
def testGlobalDispatcher(self):
|
||||||
|
original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
|
||||||
|
try:
|
||||||
|
TensorTracerOpDispatcher().register()
|
||||||
|
|
||||||
|
x = TensorTracer("x")
|
||||||
|
y = TensorTracer("y")
|
||||||
|
trace = math_ops.reduce_sum(math_ops.add(math_ops.abs(x), y), axis=3)
|
||||||
|
self.assertEqual(
|
||||||
|
str(trace), "reduce_sum(add(name=None, x=abs(x), y=y), axis=3)")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up.
|
||||||
|
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue