diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index aec0e169584..248b30c4c24 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -2513,13 +2513,10 @@ tf_cc_test(
     ],
 )
 
-tf_cuda_cc_test(
+tf_cc_test(
     name = "lower_if_op_test",
     size = "small",
     srcs = ["lower_if_op_test.cc"],
-    tags = tf_cuda_tests_tags() + [
-        "no_cuda_asan",  # TODO(b/171575050): re-enable once fixed.
-    ],
     deps = [
         ":core_cpu",
         ":core_cpu_internal",
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index 2a0e5d35de5..ff010ad8a63 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -148,22 +148,13 @@ Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
 Status CondBuilder::CreatePivotNodes() {
   // Construct the basic cond body (consisting of feeding in the predicate to
   // create pivot nodes).
-
-  // This is a special pivot switch node for lowering. We mark this with a
-  // special _PivotSwitch attr on it as later on in the graph partitioner we
-  // do some special placement for Switch nodes and its necessary to distinguish
-  // between a "normal" Switch node and one of these pivot switches. We would
-  // like to place this node on the CPU always as the pred_ will be on the CPU
-  // as well (either a CPU op output or a GPU op with HostMemory annotation).
-  // TODO(b/171321391): Fix this for NUMA cases.
   Node* switch_pred;
   TF_RETURN_IF_ERROR(
       SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
                                            graph_->op_registry(), &debug_info_)
                                    .Input(NodeOut(pred_))
                                    .Input(NodeOut(pred_))
-                                   .Attr("_PivotSwitch", true)
-                                   .Device("/CPU:0"),
+                                   .Device(if_op_->requested_device()),
                                graph_, &switch_pred));
   control_predecessor_ = switch_pred;
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index b0304cfe293..cf7d35409bb 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -147,115 +147,6 @@ TEST(LowerIfOpTest, Simple) {
   }
 }
 
-TEST(LowerIfOpTest, GPUPlacement) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-
-  // Add test functions for then and else branch.
-  FunctionDefLibrary f_lib_proto;
-  *(f_lib_proto.add_function()) = test::function::XTimesTwo();
-  *(f_lib_proto.add_function()) = test::function::XTimesFour();
-
-  // Construct simple conditional that switches on `pred` and operates only on
-  // single input `A`.
-  Scope root = Scope::NewRootScope().ExitOnError();
-  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
-  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
-  auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32);
-  auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32);
-  Node* pred;
-  TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def())
-                   .Input(x.node())
-                   .Input(y.node())
-                   .Device("/GPU:0")
-                   .Finalize(root.graph(), &pred));
-  Node* written_if;
-  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
-  TF_ASSERT_OK(
-      NodeBuilder("if", "If", &root.graph()->flib_def())
-          .Input(pred)
-          .Input(inputs)
-          .Attr("then_branch", FuncAttr("XTimesTwo"))
-          .Attr("else_branch", FuncAttr("XTimesFour"))
-          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
-          .Attr("Tout", {DT_INT32})
-          .Device("/GPU:0")
-          .Finalize(root.graph(), &written_if));
-  TF_ASSERT_OK(root.DoShapeInference(written_if));
-  TF_ASSERT_OK(root.ToGraph(graph.get()));
-
-  // The input graph has no switch or merge nodes.
-  int node_called_if_count = 0;
-  for (const auto* op : graph->op_nodes()) {
-    ASSERT_FALSE(op->IsSwitch());
-    ASSERT_FALSE(op->IsMerge());
-    if (op->name() == "if") {
-      ++node_called_if_count;
-    }
-  }
-  ASSERT_EQ(node_called_if_count, 1);
-
-  TF_ASSERT_OK(Rewrite(&graph));
-
-  // Verify the resultant graph has switch and merge nodes, and a node called
-  // `if` (but not If nodes).
-  int switch_count = 0;
-  int merge_count = 0;
-  node_called_if_count = 0;
-  for (const auto* op : graph->op_nodes()) {
-    if (op->IsSwitch()) {
-      ++switch_count;
-    }
-    if (op->IsMerge()) {
-      ++merge_count;
-    }
-    ASSERT_NE(op->type_string(), "If");
-    if (op->name() == "if") {
-      ++node_called_if_count;
-    }
-  }
-  // One switch for predicate and one for input (A).
-  ASSERT_EQ(switch_count, 2);
-  // One merge for the single output value of then and else, and one more merge
-  // to enforce then and else function call execution (`branch_executed` node).
-  ASSERT_EQ(merge_count, 2);
-  ASSERT_EQ(node_called_if_count, 1);
-
-  // Verify execution.
-  ClientSession session(root, SessionOptionsWithInlining());
-  {
-    RunMetadata metadata;
-    RunOptions options;
-    options.set_output_partition_graphs(true);
-    ClientSession::FeedType feeds;
-    feeds.emplace(Output(x.node()), Input::Initializer(5));
-    feeds.emplace(Output(y.node()), Input::Initializer(10));
-    feeds.emplace(Output(a.node()), Input::Initializer(10));
-    std::vector<Tensor> out_tensors;
-    TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {},
-                             &out_tensors, &metadata));
-    GraphDef cpu_graph = metadata.partition_graphs(1);
-    int num_cpu_switch = 0;
-    for (const auto& node : cpu_graph.node()) {
-      if (node.op() == "Switch") {
-        ++num_cpu_switch;
-      }
-    }
-    EXPECT_EQ(num_cpu_switch, 2);
-    EXPECT_EQ(out_tensors.size(), 1);
-    EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
-  }
-  {
-    ClientSession::FeedType feeds;
-    feeds.emplace(Output(x.node()), Input::Initializer(10));
-    feeds.emplace(Output(y.node()), Input::Initializer(5));
-    feeds.emplace(Output(a.node()), Input::Initializer(10));
-    std::vector<Tensor> out_tensors;
-    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
-    EXPECT_EQ(out_tensors.size(), 1);
-    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
-  }
-}
-
 TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
   using ::tensorflow::test::function::GDef;
   using ::tensorflow::test::function::NDef;
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 7680bcacba5..bf57e263441 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -371,13 +371,6 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
 void OptimizeControlFlowColocation(Graph* graph) {
   auto visit = [](Node* node) {
     if (IsSwitch(node)) {
-      // Pivot Switch nodes (which are also of type Switch) are already placed
-      // on the CPU and colocated with its inputs that are also already on the
-      // CPU (or might be placed on GPU but in host memory).
-      if (HasNodeAttr(node->def(), "_PivotSwitch")) {
-        DCHECK(node->requested_device().find("CPU") != string::npos);
-        return;
-      }
       for (const Edge* in_edge : node->in_edges()) {
         if (in_edge->dst_input() == 0) {
           // Colocate with the data input.
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 532dac1d85a..54bbd2b2e9e 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -730,8 +730,6 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
         g for g in run_metadata.partition_graphs
         if device_str in g.node[0].device
     ]
-    if not device_graphs:
-      return 0
     self.assertLen(device_graphs, 1)
     switch_nodes = [
         n for n in device_graphs[0].node
@@ -761,6 +759,7 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
       options = config_pb2.RunOptions(output_partition_graphs=True)
      sess.run(
           r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
+      self.assertLen(run_metadata.partition_graphs, 2)
       # Check that the Switch for `arg` gets placed on CPU.
       self.assertEqual(
           self._count_matching_switch_nodes_on_device(run_metadata, "CPU",