Inherit device scopes when building control flow operations
tf.function allows calling a single function body placed on different devices, but there's no need for that with control flow.

PiperOrigin-RevId: 338367553
Change-Id: I9659769737d40e6c9f5ffdad551dd46b6592c5fd
parent 6cc0cf5e30
commit d3698cdfd9
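Hedged illustration (not part of this commit; the function names below are made up) of the behavior the new tests exercise, written against the public TF API: ops traced inside a while_loop body are expected to inherit the device scope that is active where the loop op itself is built.

# Hedged sketch: device scope inheritance by a while_loop body at trace time.
import tensorflow as tf

def _loop_body(i, y):
  y = tf.cos(y)
  print(y.device)  # Runs at trace time; expected to mention the outer CPU:0 scope.
  return i + 1, y

@tf.function
def loop_under_device_scope():
  x = tf.constant(1.0)
  with tf.device("CPU:0"):
    _, z = tf.while_loop(lambda i, _: i < 2, _loop_body, [0, x])
  return z

# Tracing the function builds the loop body graph under the CPU:0 scope.
loop_under_device_scope.get_concrete_function()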
@@ -1441,6 +1441,15 @@ class CondV2ContainerTest(test.TestCase):
 
 class CondV2ColocationGroupAndDeviceTest(test.TestCase):
 
+  def setUp(self):
+    super(CondV2ColocationGroupAndDeviceTest, self).setUp()
+    cpus = context.context().list_physical_devices("CPU")
+    context.context().set_logical_device_configuration(
+        cpus[0], [
+            context.LogicalDeviceConfiguration(),
+            context.LogicalDeviceConfiguration()
+        ])
+
   def testColocateWithBeforeCond(self):
     with ops.Graph().as_default() as g:
       with self.session(graph=g):
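Hedged aside (not part of the diff): the setUp above splits the physical CPU into two logical devices using internal test helpers; a rough public-API equivalent, useful for reproducing these tests interactively, is sketched below.

# Hedged sketch using the public tf.config API to create two logical CPU
# devices, mirroring what the setUp above does with internal helpers.
# Note: must run before TensorFlow initializes its devices.
import tensorflow as tf

cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])

# Afterwards, both "CPU:0" and "CPU:1" are valid targets for tf.device().
print(tf.config.list_logical_devices("CPU"))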
@@ -1516,31 +1525,64 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
         self.assertTrue(len(run_metadata.partition_graphs) >= 2)
 
   def testDeviceBeforeCond(self):
-    with ops.Graph().as_default() as g:
-      with self.session(graph=g):
-
-        def fn():
-          self.assertEqual("", constant_op.constant(3.0).op.device)
-          return test_ops.device_placement_op()
-
+    with context.eager_mode():
+      def fn():
+        cpu_zero_op = test_ops.device_placement_op()
+        self.assertEqual("/device:CPU:0", cpu_zero_op.device)
+        with ops.device("CPU:1"):
+          cpu_one_op = test_ops.device_placement_op()
+        self.assertEqual("/device:CPU:1", cpu_one_op.device)
+        return cpu_zero_op, cpu_one_op
+
+      @def_function.function
+      def _cond_wrapper():
         with ops.device("/device:CPU:0"):
-          self.assertIn(
-              compat.as_bytes("CPU:0"),
-              self.evaluate(cond_v2.cond_v2(constant_op.constant(True),
                                             fn, fn)))
+          return cond_v2.cond_v2(constant_op.constant(True), fn, fn)
 
-        def fn2():
-          self.assertEqual("", constant_op.constant(3.0).op.device)
-          return test_ops.device_placement_op()
-
+      zero_expected, one_expected = self.evaluate(_cond_wrapper())
+      self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
+      self.assertIn(compat.as_bytes("CPU:1"), one_expected)
+
-        if test_util.is_gpu_available():
-          with ops.device("/device:GPU:0"):
-            self.assertIn(
-                compat.as_bytes("GPU:0"),
-                self.evaluate(cond_v2.cond_v2(constant_op.constant(True),
                                               fn2, fn2)))
-        else:
-          self.skipTest("Test requires a GPU to check GPU device placement.")
+      def fn2():
+        self.assertEqual("/device:GPU:0", constant_op.constant(3.0).op.device)
+        return test_ops.device_placement_op()
+
+      @def_function.function
+      def _cond_wrapper2():
+        with ops.device("/device:GPU:0"):
+          return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2)
+
+      if test_util.is_gpu_available():
+        self.assertIn(compat.as_bytes("GPU:0"),
+                      self.evaluate(_cond_wrapper2()))
+      else:
+        self.skipTest("Test requires a GPU to check GPU device placement.")
+
+  def testColocationBeforeCond(self):
+    with context.eager_mode():
+
+      def _fn():
+        result = test_ops.device_placement_op()
+        self.assertIn("colocation_test_op",
+                      result.op.colocation_groups()[0].decode())
+        return result
+
+      @def_function.function(autograph=False)
+      def _cond_wrapper():
+        with ops.device("/device:CPU:0"):
+          op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op")
+        with ops.device("/device:CPU:1"):
+          op_on_cpu_1 = test_ops.device_placement_op(
+              name="colocation_test_op_1")
+        condition = constant_op.constant(True)
+        with ops.colocate_with(op_on_cpu_0.op):
+          zero_expected = cond_v2.cond_v2(condition, _fn, _fn)
+        with ops.colocate_with(op_on_cpu_1.op):
+          one_expected = cond_v2.cond_v2(condition, _fn, _fn)
+        return zero_expected, one_expected
+
+      zero_expected, one_expected = self.evaluate(_cond_wrapper())
+      self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
+      self.assertIn(compat.as_bytes("CPU:1"), one_expected)
 
   def testDeviceInAndOutOfCond(self):
     with ops.Graph().as_default() as g:
@@ -223,6 +223,27 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 
     self.checkIteratedGradients(_Func)
 
+  def testDeviceLabelsInherited(self):
+    def _LoopBody(i, y):
+      result = math_ops.cos(y)
+      self.assertIn("CPU:10", result.device)
+      with ops.device("CPU:11"):
+        result = array_ops.identity(result)
+      self.assertIn("CPU:11", result.device)
+      return i + 1, result
+
+    @def_function.function
+    def _FunctionWithWhileLoop():
+      x = constant_op.constant(1.)
+      with ops.device("CPU:10"):
+        _, z = while_loop_v2(
+            lambda i, _: i < 2,
+            _LoopBody,
+            [0, x])
+      return z
+    # The test assertion runs at trace time.
+    _FunctionWithWhileLoop.get_concrete_function()
+
   def testExternalControlDependencies(self):
     with ops.Graph().as_default(), self.test_session():
       v = variables.Variable(1.)
@@ -22,43 +22,39 @@ from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
 
 
-class CondBranchFuncGraph(func_graph.FuncGraph):
+class ControlFlowFuncGraph(func_graph.FuncGraph):
+  """Contains control flow-specific FuncGraph logic."""
+
+  def __init__(self, *args, **kwargs):
+    super(ControlFlowFuncGraph, self).__init__(*args, **kwargs)
+    outer_graph = self.outer_graph
+    # Unlike tf.function, control flow FuncGraphs are generally created one per
+    # op. This means hard-coding any outer device scopes in the body (rather
+    # than inspecting the call-time placement of the control flow op) makes
+    # sense.
+    self._device_function_stack = outer_graph._device_function_stack.copy()  # pylint: disable=protected-access
+    self.is_control_flow_graph = True
+    if ops.executing_eagerly_outside_functions():
+      func_graph.override_func_graph_name_scope(
+          self, self.outer_graph.get_name_scope())
+
+
+class CondBranchFuncGraph(ControlFlowFuncGraph):
   """FuncGraph for branches of tf.cond().
 
   This is used to distinguish cond branches from other functions.
   """
 
-  def __init__(self, *args, **kwargs):
-    super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
-    self.is_control_flow_graph = True
-    if ops.executing_eagerly_outside_functions():
-      func_graph.override_func_graph_name_scope(
-          self, self.outer_graph.get_name_scope())
-
 
-class WhileCondFuncGraph(func_graph.FuncGraph):
+class WhileCondFuncGraph(ControlFlowFuncGraph):
   """FuncGraph for the condition of tf.while_loop().
 
   This is used to distinguish while conditions from other functions.
   """
 
-  def __init__(self, *args, **kwargs):
-    super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
-    self.is_control_flow_graph = True
-    if ops.executing_eagerly_outside_functions():
-      func_graph.override_func_graph_name_scope(
-          self, self.outer_graph.get_name_scope())
-
 
-class WhileBodyFuncGraph(func_graph.FuncGraph):
+class WhileBodyFuncGraph(ControlFlowFuncGraph):
   """FuncGraph for the body of tf.while_loop().
 
   This is used to distinguish while bodies from other functions.
   """
-
-  def __init__(self, *args, **kwargs):
-    super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
-    self.is_control_flow_graph = True
-    if ops.executing_eagerly_outside_functions():
-      func_graph.override_func_graph_name_scope(
-          self, self.outer_graph.get_name_scope())
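Hedged sketch (not from this commit; build_cond and branch are illustrative names) of the distinction the ControlFlowFuncGraph comment above draws: a tf.function body is not pinned to the device it was traced under, while the branch graphs of a cond are built once, for a single op, and so after this change they copy the outer device scope.

# Hedged sketch: a cond branch traced under an outer device scope.
import tensorflow as tf

def branch():
  t = tf.constant(3.0)
  print(t.device)  # Runs at trace time; expected to show the outer CPU:0 scope.
  return t

@tf.function
def build_cond():
  with tf.device("CPU:0"):
    return tf.cond(tf.constant(True), branch, branch)

# Tracing builds the branch FuncGraphs with the CPU:0 scope baked in.
build_cond.get_concrete_function()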