Inherit device scopes when building control flow operations

tf.function allows calling a single function body placed on different devices, but there's no need for that with control flow.

PiperOrigin-RevId: 338367553
Change-Id: I9659769737d40e6c9f5ffdad551dd46b6592c5fd
This commit is contained in:
Allen Lavoie 2020-10-21 16:51:17 -07:00 committed by TensorFlower Gardener
parent 6cc0cf5e30
commit d3698cdfd9
3 changed files with 104 additions and 45 deletions

View File

@ -1441,6 +1441,15 @@ class CondV2ContainerTest(test.TestCase):
class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def setUp(self):
super(CondV2ColocationGroupAndDeviceTest, self).setUp()
cpus = context.context().list_physical_devices("CPU")
context.context().set_logical_device_configuration(
cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
def testColocateWithBeforeCond(self):
with ops.Graph().as_default() as g:
with self.session(graph=g):
@ -1516,31 +1525,64 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
def testDeviceBeforeCond(self):
with ops.Graph().as_default() as g:
with self.session(graph=g):
def fn():
self.assertEqual("", constant_op.constant(3.0).op.device)
return test_ops.device_placement_op()
with context.eager_mode():
def fn():
cpu_zero_op = test_ops.device_placement_op()
self.assertEqual("/device:CPU:0", cpu_zero_op.device)
with ops.device("CPU:1"):
cpu_one_op = test_ops.device_placement_op()
self.assertEqual("/device:CPU:1", cpu_one_op.device)
return cpu_zero_op, cpu_one_op
@def_function.function
def _cond_wrapper():
with ops.device("/device:CPU:0"):
self.assertIn(
compat.as_bytes("CPU:0"),
self.evaluate(cond_v2.cond_v2(constant_op.constant(True),
fn, fn)))
return cond_v2.cond_v2(constant_op.constant(True), fn, fn)
def fn2():
self.assertEqual("", constant_op.constant(3.0).op.device)
return test_ops.device_placement_op()
zero_expected, one_expected = self.evaluate(_cond_wrapper())
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
if test_util.is_gpu_available():
with ops.device("/device:GPU:0"):
self.assertIn(
compat.as_bytes("GPU:0"),
self.evaluate(cond_v2.cond_v2(constant_op.constant(True),
fn2, fn2)))
else:
self.skipTest("Test requires a GPU to check GPU device placement.")
def fn2():
self.assertEqual("/device:GPU:0", constant_op.constant(3.0).op.device)
return test_ops.device_placement_op()
@def_function.function
def _cond_wrapper2():
with ops.device("/device:GPU:0"):
return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2)
if test_util.is_gpu_available():
self.assertIn(compat.as_bytes("GPU:0"),
self.evaluate(_cond_wrapper2()))
else:
self.skipTest("Test requires a GPU to check GPU device placement.")
def testColocationBeforeCond(self):
with context.eager_mode():
def _fn():
result = test_ops.device_placement_op()
self.assertIn("colocation_test_op",
result.op.colocation_groups()[0].decode())
return result
@def_function.function(autograph=False)
def _cond_wrapper():
with ops.device("/device:CPU:0"):
op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op")
with ops.device("/device:CPU:1"):
op_on_cpu_1 = test_ops.device_placement_op(
name="colocation_test_op_1")
condition = constant_op.constant(True)
with ops.colocate_with(op_on_cpu_0.op):
zero_expected = cond_v2.cond_v2(condition, _fn, _fn)
with ops.colocate_with(op_on_cpu_1.op):
one_expected = cond_v2.cond_v2(condition, _fn, _fn)
return zero_expected, one_expected
zero_expected, one_expected = self.evaluate(_cond_wrapper())
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
def testDeviceInAndOutOfCond(self):
with ops.Graph().as_default() as g:

View File

@ -223,6 +223,27 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
self.checkIteratedGradients(_Func)
def testDeviceLabelsInherited(self):
def _LoopBody(i, y):
result = math_ops.cos(y)
self.assertIn("CPU:10", result.device)
with ops.device("CPU:11"):
result = array_ops.identity(result)
self.assertIn("CPU:11", result.device)
return i + 1, result
@def_function.function
def _FunctionWithWhileLoop():
x = constant_op.constant(1.)
with ops.device("CPU:10"):
_, z = while_loop_v2(
lambda i, _: i < 2,
_LoopBody,
[0, x])
return z
# The test assertion runs at trace time.
_FunctionWithWhileLoop.get_concrete_function()
def testExternalControlDependencies(self):
with ops.Graph().as_default(), self.test_session():
v = variables.Variable(1.)

View File

@ -22,43 +22,39 @@ from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
class CondBranchFuncGraph(func_graph.FuncGraph):
class ControlFlowFuncGraph(func_graph.FuncGraph):
"""Contains control flow-specific FuncGraph logic."""
def __init__(self, *args, **kwargs):
super(ControlFlowFuncGraph, self).__init__(*args, **kwargs)
outer_graph = self.outer_graph
# Unlike tf.function, control flow FuncGraphs are generally created one per
# op. This means hard-coding any outer device scopes in the body (rather
# than inspecting the call-time placement of the control flow op) makes
# sense.
self._device_function_stack = outer_graph._device_function_stack.copy() # pylint: disable=protected-access
self.is_control_flow_graph = True
if ops.executing_eagerly_outside_functions():
func_graph.override_func_graph_name_scope(
self, self.outer_graph.get_name_scope())
class CondBranchFuncGraph(ControlFlowFuncGraph):
"""FuncGraph for branches of tf.cond().
This is used to distinguish cond branches from other functions.
"""
def __init__(self, *args, **kwargs):
super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
self.is_control_flow_graph = True
if ops.executing_eagerly_outside_functions():
func_graph.override_func_graph_name_scope(
self, self.outer_graph.get_name_scope())
class WhileCondFuncGraph(func_graph.FuncGraph):
class WhileCondFuncGraph(ControlFlowFuncGraph):
"""FuncGraph for the condition of tf.while_loop().
This is used to distinguish while conditions from other functions.
"""
def __init__(self, *args, **kwargs):
super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
self.is_control_flow_graph = True
if ops.executing_eagerly_outside_functions():
func_graph.override_func_graph_name_scope(
self, self.outer_graph.get_name_scope())
class WhileBodyFuncGraph(func_graph.FuncGraph):
class WhileBodyFuncGraph(ControlFlowFuncGraph):
"""FuncGraph for the body of tf.while_loop().
This is used to distinguish while bodies from other functions.
"""
def __init__(self, *args, **kwargs):
super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
self.is_control_flow_graph = True
if ops.executing_eagerly_outside_functions():
func_graph.override_func_graph_name_scope(
self, self.outer_graph.get_name_scope())