Fix colocation attr on Switch in lower_if_op.cc

PiperOrigin-RevId: 331886863
Change-Id: I9cd889b2c380b31ae025f4050eece83f7851a4aa
This commit is contained in:
Zheng Xu 2020-09-15 16:55:18 -07:00 committed by TensorFlower Gardener
parent b4f7c37ccb
commit a873a3e075
2 changed files with 17 additions and 58 deletions

View File

@@ -166,15 +166,13 @@ Status CondBuilder::AddInput(Node* src, int src_output) {
// the same device as the input node (if set) and sets the colocation _class
// attr. It also ignores the existing colocation constraints on the input node
// using colocate_with(ignore_existing=True).
TF_RETURN_IF_ERROR(
NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry(),
&debug_info)
.Input(src, src_output)
.Input(pred_)
.Device(src->requested_device())
.Attr(kColocationAttrName,
{absl::StrCat(kColocationGroupPrefix, src->name())})
.Finalize(graph_, &input));
TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "Switch",
graph_->op_registry(), &debug_info)
.Input(src, src_output)
.Input(pred_)
.Device(src->requested_device())
.Attr("_class", {src->name()})
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);
return Status::OK();

View File

@@ -34,8 +34,6 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
@@ -722,9 +720,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
# We expect that everything runs on CPU, even if GPU is available.
self.assertEqual(len(run_metadata.partition_graphs), 1)
def _count_matching_switch_nodes_on_device(self, run_metadata, device_str,
dtype):
# Returns the number of Switch nodes with type dtype placed on
def _count_matching_switch_nodes_on_device(self, run_metadata, device_str):
# Returns the number of Switch nodes with type float32 placed on
# `device_str`.
device_graphs = [
g for g in run_metadata.partition_graphs
@@ -732,14 +729,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
]
self.assertLen(device_graphs, 1)
switch_nodes = [
n for n in device_graphs[0].node
if n.op == "Switch" and n.attr["T"].type == dtype.as_datatype_enum
n for n in device_graphs[0].node if n.op == "Switch" and
n.attr["T"].type == dtypes.float32.as_datatype_enum
]
return len(switch_nodes)
@test_util.run_gpu_only
@test_util.run_deprecated_v1
def testCondSwitchColocatedWithInputWhenInputExplicitlyPlacedOnCPU(self):
def testCondSwitchColocatedWithInputWhenInputOnCPU(self):
x = array_ops.placeholder(dtypes.float32)
# `arg` is used in the cond then branch so a Switch node is created for it.
@@ -759,46 +756,12 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
options = config_pb2.RunOptions(output_partition_graphs=True)
sess.run(
r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
self.assertLen(run_metadata.partition_graphs, 2)
self.assertEqual(len(run_metadata.partition_graphs), 2)
# Check that the Switch for `arg` gets placed on CPU.
self.assertEqual(
self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
dtypes.float32), 1)
self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 1)
self.assertEqual(
self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
dtypes.float32), 0)
@test_util.run_gpu_only
@test_util.run_deprecated_v1
def testCondSwitchColocatedWithInputWhenInputPlacedOnCPU(self):
x = array_ops.placeholder(dtypes.float32)
# `arg` is used in the cond then branch so a Switch node is created for it.
# We test that the Switch node gets placed on the same device as `arg`.
# Since arg is a dataset (and only has a CPU kernel), it gets placed on CPU
# by placer.
arg = dataset_ops.Dataset.range(8)
def true_fn():
return cardinality.cardinality(arg)
r = control_flow_ops.cond(
constant_op.constant(True), true_fn,
lambda: constant_op.constant(0, dtypes.int64))
with session.Session() as sess:
run_metadata = config_pb2.RunMetadata()
options = config_pb2.RunOptions(output_partition_graphs=True)
sess.run(
r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
self.assertLen(run_metadata.partition_graphs, 2)
# Check that the Switch for `arg` gets placed on CPU.
self.assertEqual(
self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
dtypes.variant), 1)
self.assertEqual(
self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
dtypes.variant), 0)
self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 0)
@test_util.run_gpu_only
@test_util.run_deprecated_v1
@@ -824,11 +787,9 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
self.assertEqual(len(run_metadata.partition_graphs), 2)
# Check that the Switch for `arg` gets placed on GPU.
self.assertEqual(
self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
dtypes.float32), 0)
self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 0)
self.assertEqual(
self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
dtypes.float32), 1)
self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 1)
def testCondAccessTrueBranchTensorInFalseBranchRaises(self):