diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 76a3c276e2d..484bdeee3bd 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -2511,10 +2511,12 @@ tf_cc_test(
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc
index a28959703e5..90fdc886c50 100644
--- a/tensorflow/core/common_runtime/lower_while_op.cc
+++ b/tensorflow/core/common_runtime/lower_while_op.cc
@@ -444,15 +444,14 @@ Status LowerWhileHelper::CreateNextIterationNodes() {
     if (IsResource(i)) {
       continue;
     }
-    TF_RETURN_IF_ERROR(
-        NodeBuilder(NewName("next_iteration"), "NextIteration",
-                    graph_->op_registry(), &debug_info_)
-            .Input(NodeOut(body_call_node_, i))
-            .ControlInput(body_call_node_)
-            .Device(while_op_->requested_device())
-            .AssignedDevice(merge_nodes_[op_input_output_to_lowered_node_[i]]
-                                ->assigned_device_name())
-            .Finalize(graph_, &next_iteration));
+    Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
+    TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
+                                   graph_->op_registry(), &debug_info_)
+                           .Input(NodeOut(body_call_node_, i))
+                           .ControlInput(body_call_node_)
+                           .Device(merge_node->requested_device())
+                           .AssignedDevice(merge_node->assigned_device_name())
+                           .Finalize(graph_, &next_iteration));
     next_iterations_nodes_.emplace_back(next_iteration);
   }
   return Status::OK();
diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc
index 65b9b523444..9d7870f891d 100644
--- a/tensorflow/core/common_runtime/lower_while_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_while_op_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/strings/match.h"
 #include "tensorflow/cc/client/client_session.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/array_ops.h"
@@ -25,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -262,6 +264,145 @@ TEST(LowerWhileOpTest, ForwardAssignedInputDevice) {
   ASSERT_EQ(exit_consumers, 1);
 }
 
+TEST(LowerWhileOpTest, ForwardRequestedInputDevice) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  // Add test functions for cond and body.
+  FunctionDefLibrary f_lib_proto;
+  *f_lib_proto.add_function() = test::function::XTimesTwo();
+  *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);
+
+  TF_ASSERT_OK(graph->AddFunctionLibrary(f_lib_proto));
+  auto type = DT_FLOAT;
+  // We will place the loop var on gpu:0.
+  const string gpu_0_device = "/job:localhost/replica:0/task:0/gpu:0";
+  // We will place the loop's control input on gpu:1.
+  const string gpu_1_device = "/job:localhost/replica:0/task:0/gpu:1";
+  // We will place the While op on gpu:2.
+  const string gpu_2_device = "/job:localhost/replica:0/task:0/gpu:2";
+  Node* gpu_0_ph;
+  TF_CHECK_OK(NodeBuilder("placed_node", "Placeholder")
+                  .Attr("dtype", type)
+                  .Device(gpu_0_device)
+                  .Finalize(graph.get(), &gpu_0_ph));
+  Node* control_in;
+  // Add a control input to the While op to trigger the creation of a
+  // LoopControlInputs node.
+  TF_CHECK_OK(NodeBuilder("control_in", "Placeholder")
+                  .Attr("dtype", type)
+                  .Device(gpu_1_device)
+                  .Finalize(graph.get(), &control_in));
+  Node* while_node;
+  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(gpu_0_ph)});
+  AttrValue cond_func;
+  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
+  AttrValue body_func;
+  body_func.mutable_func()->set_name("XTimesTwo");
+  TF_ASSERT_OK(
+      NodeBuilder("while", "While", &graph->flib_def())
+          .Input(inputs)
+          .ControlInput(control_in)
+          .Device(gpu_2_device)
+          .Attr("T", {type})
+          .Attr("cond", cond_func)
+          .Attr("body", body_func)
+          .Attr("parallel_iterations", 100)
+          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
+          .Finalize(graph.get(), &while_node));
+
+  // Create an empty Const node with a control dep from the While op.
+  // This triggers the creation of a LoopExecuted node.
+  Node* control_out;
+  TensorProto proto;
+  proto.set_dtype(DT_FLOAT);
+  TensorShape empty_shape({0});
+  empty_shape.AsProto(proto.mutable_tensor_shape());
+  TF_ASSERT_OK(NodeBuilder("control_out", "Const")
+                   .ControlInput(while_node)
+                   .Attr("dtype", DT_FLOAT)
+                   .Attr("value", proto)
+                   .Finalize(graph.get(), &control_out));
+
+  TF_ASSERT_OK(Rewrite(&graph));
+
+  const Node* placeholder_node = nullptr;
+  for (const auto* op : graph->op_nodes()) {
+    if (op->name() == "placed_node") {
+      placeholder_node = op;
+    }
+  }
+  ASSERT_NE(placeholder_node, nullptr);
+  // Verify the requested device of the Enter node.
+  int enter_consumers = 0;
+  const Node* enter_node = nullptr;
+  for (const Node* consumer : placeholder_node->out_nodes()) {
+    if (consumer->type_string() == "Enter") {
+      enter_consumers += 1;
+      enter_node = consumer;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(enter_consumers, 1);
+  // Verify the requested device of the Merge node.
+  int merge_consumers = 0;
+  const Node* merge_node = nullptr;
+  for (const Node* consumer : enter_node->out_nodes()) {
+    if (consumer->type_string() == "Merge") {
+      merge_consumers += 1;
+      merge_node = consumer;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(merge_consumers, 1);
+  // Verify the requested device of the NextIteration node.
+  int next_iteration_consumers = 0;
+  for (const Node* consumer : merge_node->in_nodes()) {
+    if (consumer->type_string() == "NextIteration") {
+      next_iteration_consumers += 1;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(next_iteration_consumers, 1);
+  // Verify the requested device of the Switch node.
+  int switch_consumers = 0;
+  const Node* switch_node = nullptr;
+  for (const Node* consumer : merge_node->out_nodes()) {
+    if (consumer->type_string() == "Switch") {
+      switch_consumers += 1;
+      switch_node = consumer;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(switch_consumers, 1);
+  // Verify the requested device of the Exit node.
+  int exit_consumers = 0;
+  for (const Node* consumer : switch_node->out_nodes()) {
+    if (consumer->type_string() == "Exit") {
+      exit_consumers += 1;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(exit_consumers, 1);
+  // Verify the requested device of LoopControlInputs.
+  const Node* loop_control_inputs_node = nullptr;
+  for (const auto* op : graph->op_nodes()) {
+    if (absl::StrContains(op->name(), "LoopControlInputs")) {
+      loop_control_inputs_node = op;
+    }
+  }
+  ASSERT_NE(loop_control_inputs_node, nullptr);
+  ASSERT_EQ(loop_control_inputs_node->requested_device(), gpu_2_device);
+  // Verify the requested device of LoopExecuted.
+  const Node* loop_executed_node = nullptr;
+  for (const auto* op : graph->op_nodes()) {
+    if (absl::StrContains(op->name(), "LoopExecuted")) {
+      loop_executed_node = op;
+    }
+  }
+  ASSERT_NE(loop_executed_node, nullptr);
+  ASSERT_EQ(loop_executed_node->requested_device(), gpu_2_device);
+}
+
 TEST(LowerWhileOpTest, MultipleInputs) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));