diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 76a3c276e2d..484bdeee3bd 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -2511,10 +2511,12 @@ tf_cc_test(
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
     ],
 )
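(The two new deps correspond to the test changes below: //tensorflow/core:protos_all_cc backs the new tensorflow/core/framework/types.pb.h include, and @com_google_absl//absl/strings backs the absl::StrContains calls used to locate the LoopControlInputs/LoopExecuted nodes.)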
 
diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc
index a28959703e5..90fdc886c50 100644
--- a/tensorflow/core/common_runtime/lower_while_op.cc
+++ b/tensorflow/core/common_runtime/lower_while_op.cc
@@ -444,15 +444,14 @@ Status LowerWhileHelper::CreateNextIterationNodes() {
     if (IsResource(i)) {
       continue;
     }
-    TF_RETURN_IF_ERROR(
-        NodeBuilder(NewName("next_iteration"), "NextIteration",
-                    graph_->op_registry(), &debug_info_)
-            .Input(NodeOut(body_call_node_, i))
-            .ControlInput(body_call_node_)
-            .Device(while_op_->requested_device())
-            .AssignedDevice(merge_nodes_[op_input_output_to_lowered_node_[i]]
-                                ->assigned_device_name())
-            .Finalize(graph_, &next_iteration));
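+    // Forward both the requested and the assigned device from the matching
+    // Merge node (instead of taking the requested device from the While op)
+    // so the Merge/NextIteration pair ends up with a consistent placement.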
+    Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
+    TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
+                                   graph_->op_registry(), &debug_info_)
+                           .Input(NodeOut(body_call_node_, i))
+                           .ControlInput(body_call_node_)
+                           .Device(merge_node->requested_device())
+                           .AssignedDevice(merge_node->assigned_device_name())
+                           .Finalize(graph_, &next_iteration));
     next_iterations_nodes_.emplace_back(next_iteration);
   }
   return Status::OK();
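For context: before this change, NextIteration took its requested device from the While op while its assigned device came from the matching Merge node; now both fields are forwarded from the Merge node. A minimal sketch of the two tensorflow::Node accessors involved (illustrative only; LogMergePlacement is a hypothetical helper, not part of the change):

    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/platform/logging.h"

    // Sketch: the two placement fields forwarded from the Merge node above.
    void LogMergePlacement(const tensorflow::Node* merge_node) {
      // Device requested in the NodeDef, e.g. "/job:localhost/.../gpu:0".
      LOG(INFO) << merge_node->requested_device();
      // Device the placer assigned; empty until placement has run.
      LOG(INFO) << merge_node->assigned_device_name();
    }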
diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc
index 65b9b523444..9d7870f891d 100644
--- a/tensorflow/core/common_runtime/lower_while_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_while_op_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/strings/match.h"
 #include "tensorflow/cc/client/client_session.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/array_ops.h"
@@ -25,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -262,6 +264,145 @@ TEST(LowerWhileOpTest, ForwardAssignedInputDevice) {
   ASSERT_EQ(exit_consumers, 1);
 }
 
+TEST(LowerWhileOpTest, ForwardRequestedInputDevice) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  // Add test functions for cond and body.
+  FunctionDefLibrary f_lib_proto;
+  *f_lib_proto.add_function() = test::function::XTimesTwo();
+  *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);
+
+  TF_ASSERT_OK(graph->AddFunctionLibrary(f_lib_proto));
+  auto type = DT_FLOAT;
+  // Place the loop variable on gpu:0.
+  const string gpu_0_device = "/job:localhost/replica:0/task:0/gpu:0";
+  // Place the loop's control input on gpu:1.
+  const string gpu_1_device = "/job:localhost/replica:0/task:0/gpu:1";
+  // Place the While op on gpu:2.
+  const string gpu_2_device = "/job:localhost/replica:0/task:0/gpu:2";
+  Node* gpu_0_ph;
+  TF_CHECK_OK(NodeBuilder("placed_node", "Placeholder")
+                  .Attr("dtype", type)
+                  .Device(gpu_0_device)
+                  .Finalize(graph.get(), &gpu_0_ph));
+  Node* control_in;
+  // Add a control input to the While op to trigger the creation of a
+  // LoopControlInputs node.
+  TF_CHECK_OK(NodeBuilder("control_in", "Placeholder")
+                  .Attr("dtype", type)
+                  .Device(gpu_1_device)
+                  .Finalize(graph.get(), &control_in));
+  Node* while_node;
+  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(gpu_0_ph)});
+  AttrValue cond_func;
+  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
+  AttrValue body_func;
+  body_func.mutable_func()->set_name("XTimesTwo");
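+  // Setting kLowerUsingSwitchMergeAttr opts this While node into the
+  // Switch/Merge-based lowering exercised by this test.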
+  TF_ASSERT_OK(
+      NodeBuilder("while", "While", &graph->flib_def())
+          .Input(inputs)
+          .ControlInput(control_in)
+          .Device(gpu_2_device)
+          .Attr("T", {type})
+          .Attr("cond", cond_func)
+          .Attr("body", body_func)
+          .Attr("parallel_iterations", 100)
+          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
+          .Finalize(graph.get(), &while_node));
+
+  // Create an empty Const node with a control dependency on the While op.
+  // This triggers the creation of a LoopExecuted node.
+  Node* control_out;
+  TensorProto proto;
+  proto.set_dtype(DT_FLOAT);
+  TensorShape empty_shape({0});
+  empty_shape.AsProto(proto.mutable_tensor_shape());
+  TF_ASSERT_OK(NodeBuilder("control_out", "Const")
+                   .ControlInput(while_node)
+                   .Attr("dtype", DT_FLOAT)
+                   .Attr("value", proto)
+                   .Finalize(graph.get(), &control_out));
+
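+  // Run the lowering pass via the Rewrite helper; the While node is
+  // replaced by the lowered control-flow nodes checked below.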
+  TF_ASSERT_OK(Rewrite(&graph));
+
+  const Node* placeholder_node = nullptr;
+  for (const auto* op : graph->op_nodes()) {
+    if (op->name() == "placed_node") {
+      placeholder_node = op;
+    }
+  }
+  ASSERT_NE(placeholder_node, nullptr);
+  // Verify the requested device of the Enter node.
+  int enter_consumers = 0;
+  const Node* enter_node = nullptr;
+  for (const Node* consumer : placeholder_node->out_nodes()) {
+    if (consumer->type_string() == "Enter") {
+      enter_consumers += 1;
+      enter_node = consumer;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(enter_consumers, 1);
+  // Verify the requested device of the Merge node.
+  int merge_consumers = 0;
+  const Node* merge_node = nullptr;
+  for (const Node* consumer : enter_node->out_nodes()) {
+    if (consumer->type_string() == "Merge") {
+      merge_consumers += 1;
+      merge_node = consumer;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(merge_consumers, 1);
+  // Verify the requested device of the NextIteration node. NextIteration
+  // feeds the Merge node, so we scan the Merge node's inputs here.
+  int next_iteration_inputs = 0;
+  for (const Node* input : merge_node->in_nodes()) {
+    if (input->type_string() == "NextIteration") {
+      next_iteration_inputs += 1;
+      ASSERT_EQ(input->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(next_iteration_inputs, 1);
+  // Verify the requested device of the Switch node.
+  int switch_consumers = 0;
+  const Node* switch_node = nullptr;
+  for (const Node* consumer : merge_node->out_nodes()) {
+    if (consumer->type_string() == "Switch") {
+      switch_consumers += 1;
+      switch_node = consumer;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(switch_consumers, 1);
+  // Verify the requested device of the Exit node.
+  int exit_consumers = 0;
+  for (const Node* consumer : switch_node->out_nodes()) {
+    if (consumer->type_string() == "Exit") {
+      exit_consumers += 1;
+      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
+    }
+  }
+  ASSERT_EQ(exit_consumers, 1);
+  // Verify the requested device of LoopControlInputs.
+  const Node* loop_control_inputs_node = nullptr;
+  for (const auto* op : graph->op_nodes()) {
+    if (absl::StrContains(op->name(), "LoopControlInputs")) {
+      loop_control_inputs_node = op;
+    }
+  }
+  ASSERT_NE(loop_control_inputs_node, nullptr);
+  ASSERT_EQ(loop_control_inputs_node->requested_device(), gpu_2_device);
+  // Verify the requested device of LoopExecuted.
+  const Node* loop_executed_node = nullptr;
+  for (const auto* op : graph->op_nodes()) {
+    if (absl::StrContains(op->name(), "LoopExecuted")) {
+      loop_executed_node = op;
+    }
+  }
+  ASSERT_NE(loop_executed_node, nullptr);
+  ASSERT_EQ(loop_executed_node->requested_device(), gpu_2_device);
+}
+
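For reference, a sketch of the lowered topology the assertions above walk (illustrative; actual node names are generated by the lowering pass):

    placed_node --> Enter --> Merge --> Switch --> Exit
                               ^          |
                               |          v  (loop-body branch)
                        NextIteration <-- XTimesTwo body call

    control_in --> LoopControlInputs        (requested device: gpu:2)
    LoopExecuted --> control_out            (requested device: gpu:2)

Every node on the data path inherits the loop variable's gpu:0 placement, while the two control NoOps keep the While op's gpu:2 placement.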
 TEST(LowerWhileOpTest, MultipleInputs) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));