diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 646896ddbb7..b4244c95395 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -551,7 +551,8 @@ Status CapturedFunction::Instantiate(
   if (!metadata_->use_inter_op_parallelism()) {
     inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
   }
-  bool is_multi_device = metadata_->use_multi_device_function();
+  bool is_multi_device = false;
+  TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
   inst_opts.is_multi_device_function = is_multi_device;
 
   // We infer the target device from the function library runtime.
@@ -670,20 +671,13 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
 
   OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
                            ret_types_);
-  Notification n;
-  Status s;
   profiler::TraceMe activity(
       [&] {
         return absl::StrCat(
             "InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
       },
       profiler::TraceMeLevel::kInfo);
-  lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
-    s.Update(func_status);
-    n.Notify();
-  });
-  n.WaitForNotification();
-  TF_RETURN_IF_ERROR(s);
+  TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
   return frame.ConsumeRetvals(rets);
 }
 
@@ -708,9 +702,6 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
 
   BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                               ret_types_);
-  Notification n;
-  Status s;
-
   profiler::TraceMe activity(
       [&] {
         return absl::StrCat(
@@ -718,12 +709,7 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
             f_opts.step_id, "#");
       },
       profiler::TraceMeLevel::kInfo);
-  lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
-    s.Update(func_status);
-    n.Notify();
-  });
-  n.WaitForNotification();
-  TF_RETURN_IF_ERROR(s);
+  TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
   return frame.ConsumeRetvals(rets);
 }
 
@@ -747,21 +733,13 @@ Status InstantiatedCapturedFunction::RunInstantiated(
 
   BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                               ret_types_);
-  Notification n;
-  Status s;
-
   profiler::TraceMe activity(
       [&] {
         return absl::StrCat("InstantiatedCapturedFunction::RunInstantiated#id=",
                             f_opts.step_id, "#");
       },
       profiler::TraceMeLevel::kInfo);
-  lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
-    s.Update(func_status);
-    n.Notify();
-  });
-  n.WaitForNotification();
-  TF_RETURN_IF_ERROR(s);
+  TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
   return frame.ConsumeRetvals(rets);
 }
 
@@ -877,5 +855,81 @@ CapturedFunction::CapturedFunction(
     : metadata_(std::move(metadata)),
       captured_inputs_(std::move(captured_inputs)) {}
 
+Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
+                                       bool* is_multi_device) {
+  if (!metadata_->use_multi_device_function()) {
+    *is_multi_device = false;
+    return Status::OK();
+  }
+
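+  // Look up the function to verify that it exists in the library.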
+  const FunctionDef* fdef;
+  TF_RETURN_IF_ERROR(
+      LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
+
+  Device* current_device = ctx->flr()->device();
+  DeviceType current_device_type(current_device->device_type());
+  DeviceNameUtils::ParsedName current_device_name;
+  if (!DeviceNameUtils::ParseFullName(current_device->name(),
+                                      &current_device_name)) {
+    return errors::InvalidArgument("Failed to parse device name: ",
+                                   current_device->name());
+  }
+
+  // Check if any of the captured resource inputs are placed on a device that
+  // is not compatible with the current device. All other inputs are assumed
+  // to be placed on the current device.
+  for (const auto& input : captured_inputs_) {
+    DataType dtype = input.dtype();
+    if (dtype == DT_RESOURCE) {
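+      // A resource handle records the device on which the resource lives.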
+      const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
+      DeviceNameUtils::ParsedName resource_device_name;
+      if (!DeviceNameUtils::ParseFullName(handle.device(),
+                                          &resource_device_name)) {
+        return errors::InvalidArgument("Failed to parse device name: ",
+                                       handle.device());
+      }
+      if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
+                                                  resource_device_name)) {
+        *is_multi_device = true;
+        return Status::OK();
+      }
+    }
+  }
+
+  // Check if all ops could be placed on the current device.
+  for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
+    const FunctionDef* fdef;
+    TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
+    for (const auto& node : fdef->node_def()) {
+      // Check if the op has a kernel available for the current device.
+      if (!KernelDefAvailable(current_device_type, node)) {
+        *is_multi_device = true;
+        return Status::OK();
+      }
+      // If the op has a requested device, check if the requested device is
+      // compatible with the current device.
+      if (!node.device().empty()) {
+        DeviceNameUtils::ParsedName node_device_name;
+        if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
+          return errors::InvalidArgument("Failed to parse device name: ",
+                                         node.device());
+        }
+        if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
+                                                    node_device_name)) {
+          *is_multi_device = true;
+          return Status::OK();
+        }
+      }
+    }
+  }
+
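+  // All ops have kernels for the current device and no incompatible device
+  // placements were found, so a single-device function suffices.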
+  *is_multi_device = false;
+  return Status::OK();
+}
+
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index de424fc547c..284a02091dd 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -256,6 +256,10 @@ class CapturedFunction {
   CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
                    std::vector<Tensor> captured_inputs);
 
+  // Determines whether the captured function requires the use of the
+  // multi-device function backend.
+  Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device);
+
   const std::shared_ptr<const FunctionMetadata> metadata_;
   const std::vector<Tensor> captured_inputs_;
 
diff --git a/tensorflow/python/data/benchmarks/map_benchmark.py b/tensorflow/python/data/benchmarks/map_benchmark.py
index 39c0c544e78..aea0fe9847e 100644
--- a/tensorflow/python/data/benchmarks/map_benchmark.py
+++ b/tensorflow/python/data/benchmarks/map_benchmark.py
@@ -115,15 +115,6 @@ class MapBenchmark(benchmark_base.DatasetBenchmarkBase):
         name="parallel_control_flow",
         apply_default_optimizations=True)
 
-  def benchmark_execution_overhead(self):
-    dataset = dataset_ops.Dataset.range(100000)
-    dataset = dataset_ops.MapDataset(
-        dataset, lambda x: x + 1, use_inter_op_parallelism=False)
-    self.run_and_report_benchmark(
-        dataset,
-        num_elements=100000,
-        name="execution_overhead")
-
 
 if __name__ == "__main__":
   benchmark_base.test.main()