Improve performance of tf.subscribe on large graphs by caching control outputs.

Change: 138473617
This commit is contained in:
A. Unique TensorFlower 2016-11-07 20:25:54 -08:00 committed by TensorFlower Gardener
parent 10d6df2f95
commit e1dddfb401
2 changed files with 73 additions and 13 deletions
tensorflow/python/framework

View File

@ -57,25 +57,51 @@ def _recursive_apply(tensors, apply_fn):
(tensors, tensors_type))
def _control_outputs(op):
  """Returns the control_input consumers for the supplied `Operation`.

  Every call walks all operations of `op.graph`, so the cost of one lookup
  is a full pass over the graph.

  Args:
    op: The `Operation` to find consumers of.

  Yields:
    Each op in `op.graph` that has `op` as a control dependency.
  """
  for o in op.graph.get_operations():
    if op in o.control_inputs:
      yield o


class _ControlOutputCache(object):
  """Helper class to manage calculating and caching control_outputs in graph.

  The control-output map of a graph is computed the first time any op of that
  graph is queried and is reused for every later query on the same instance.
  The map is a snapshot: nothing in this class refreshes it, so ops added to
  the graph after the first query are not reflected.
  """

  def __init__(self):
    # Maps a graph to its {op: set of ops listing `op` as a control input}.
    self.cache = {}

  def calc_control_outputs(self, graph):
    """Returns the map of control_outputs for a given graph.

    Args:
      graph: The graph to parse.

    Returns:
      A dict mapping each op that is used as a control input to the set of
      ops that depend on it. Ops with no control outputs have no entry.
    """
    control_outputs = {}
    for op in graph.get_operations():
      for control_input in op.control_inputs:
        control_outputs.setdefault(control_input, set()).add(op)
    return control_outputs

  def get_control_outputs(self, op):
    """Return the control outputs for a given op.

    Args:
      op: The op to fetch control outputs for.

    Returns:
      Iterable of control output ops: the cached set for `op`, or an empty
      list when no op has `op` as a control input. The set is the cache's own
      object, so callers should not mutate it.
    """
    graph = op.graph
    if graph not in self.cache:
      # First query for this graph: scan it once and remember the result.
      self.cache[graph] = self.calc_control_outputs(graph)
    return self.cache[graph].get(op, [])
def _subscribe(tensor, side_effects):
def _subscribe(tensor, side_effects, control_cache):
"""Helper method that subscribes a single tensor to a list of side_effects.
Args:
tensor: `tf.Tensor`
side_effects: List of side_effect functions see subscribe for details.
control_cache: `_ControlOutputCache` helper to get control_outputs faster.
Returns:
The modified replacement to the passed in tensor which triggers the side
effects.
@ -84,7 +110,7 @@ def _subscribe(tensor, side_effects):
for consumer_op in list(tensor.consumers()): # explicit copy
update_input.append((consumer_op, list(consumer_op.inputs).index(tensor)))
update_control_input = list(_control_outputs(tensor.op))
update_control_input = control_cache.get_control_outputs(tensor.op)
# Trailing slash on name scope to replace the scope.
name_scope = tensor.op.name + '/subscription/'
@ -141,4 +167,8 @@ def subscribe(tensors, side_effects):
"""
if not hasattr(side_effects, '__iter__'):
side_effects = [side_effects]
return _recursive_apply(tensors, lambda t: _subscribe(t, side_effects))
control_outputs = _ControlOutputCache()
result = _recursive_apply(
tensors, lambda t: _subscribe(t, side_effects, control_outputs))
return result

View File

@ -54,6 +54,36 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertEquals(d_out, [42])
self.assertEquals(shared, [2, 2, 2])
def testCaching(self):
  """Confirm caching of control output is recalculated between calls."""
  a = tf.constant(1)
  b = tf.constant(2)
  # `c` takes a control dependency on `a` before `a` is subscribed.
  with tf.control_dependencies([a]):
    c = tf.constant(42)

  # Counts how many times `sub` is invoked, keyed by the value it receives.
  shared = {}
  def sub(t):
    shared[t] = shared.get(t, 0) + 1
    return t

  a = subscribe.subscribe(a, lambda t: tf.py_func(sub, [t], [t.dtype]))

  # `d` takes a control dependency on `b` only after the first subscribe
  # call has already run.
  with tf.control_dependencies([b]):
    d = tf.constant(11)

  # If it was using outdated cached control_outputs then
  # evaling would not trigger the new subscription.
  b = subscribe.subscribe(b, lambda t: tf.py_func(sub, [t], [t.dtype]))

  with self.test_session() as sess:
    c_out = sess.run([c])
    d_out = sess.run([d])

  self.assertEquals(c_out, [42])
  self.assertEquals(d_out, [11])
  # Each subscribed value (1 from `a`, 2 from `b`) was seen exactly once.
  self.assertEquals(shared, {2: 1, 1: 1})
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
  googletest.main()