diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index ba20b532a11..1c3656edef4 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -123,30 +123,11 @@ Status XlaCompileOnDemandOp::Compile(
 
     if (!constant_arguments.count(i)) {
       if (absl::c_binary_search(constant_input_indices, i)) {
-        if (ctx->input_memory_type(i) != HOST_MEMORY &&
-            ctx->op_device_context()) {
-          // Slow path; the argument is not available as a host constant so we
-          // must fetch it synchronously.
-          Tensor host_tensor;
-          AllocatorAttributes attrs;
-          attrs.set_on_host(true);
-          TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
-                                                device_tensor.shape(),
-                                                &host_tensor, attrs));
-          Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
-              &device_tensor, "ConstantArgument",
-              reinterpret_cast<Device*>(ctx->device()), &host_tensor);
-          if (!status.ok()) {
-            LOG(ERROR) << "Copying tensor of shape "
-                       << device_tensor.shape().DebugString() << " from "
-                       << ctx->device()->name() << "to CPU failed with "
-                       << status.ToString();
-            return status;
-          }
-          constant_arguments[i] = host_tensor;
-        } else {
-          constant_arguments[i] = device_tensor;
+        if (ctx->input_memory_type(i) != HOST_MEMORY) {
+          return errors::Internal(
+              "Expected constant argument not in host memory");
         }
+        constant_arguments[i] = device_tensor;
       }
     }
   }
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index c47c9a29c1a..089d22dca03 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -573,8 +573,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
   // Any op assigned to the device that isn't rewritten by the graph rewriter
   // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
   // it just-in-time.
-  OpKernel* (*factory)(OpKernelConstruction*) =
-      [](OpKernelConstruction* context) -> OpKernel* {
+  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
     return new XlaCompileOnDemandOp(context);
   };
   XlaOpRegistry::RegisterCompilationKernels();
@@ -583,6 +582,13 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
            jit_device,
            /*include_compilation_only_kernels=*/false)) {
     KernelDef* def = new KernelDef(*jit_def);
+    const std::unordered_set<std::string>* constant_inputs =
+        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());
+
+    for (const std::string& arg_name : *constant_inputs) {
+      def->add_host_memory_arg(arg_name);
+    }
+
     def->set_device_type(device);
     registrations->op_kernel_registrars.emplace_back(
         new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index e37f4659185..9948fe6d1b9 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -365,6 +365,19 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
   return ops;
 }
 
+/*static*/ const std::unordered_set<std::string>*
+XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+  auto it = registry.ops_.find(op);
+  static auto empty_set = new std::unordered_set<std::string>;
+  if (it == registry.ops_.end() || it->second.empty()) {
+    return empty_set;
+  } else {
+    return &it->second.front()->compile_time_constant_inputs;
+  }
+}
+
 /* static */ Status XlaOpRegistry::CompileTimeConstantInputs(
     const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def,
     std::vector<int>* result) {
@@ -385,21 +398,10 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
                                compile_time_constant_inputs_from_attr.end()));
     compile_time_constant_inputs = &compile_time_constant_inputs_from_attr;
   } else {
-    const string& op = node_def.op();
-
-    XlaOpRegistry& registry = Instance();
-    mutex_lock lock(registry.mutex_);
-    auto it = registry.ops_.find(op);
-    if (it == registry.ops_.end() || it->second.empty()) {
+    compile_time_constant_inputs =
+        CompileTimeConstantInputArgNames(node_def.op());
+    if (compile_time_constant_inputs->empty()) {
       return Status::OK();
-    } else {
-      // The test in IsCompatible ensures that if there are multiple matching
-      // registrations for this op name, they all have the same value of
-      // compile_time_constant_inputs, so only the first match is returned.
-      //
-      // TODO(sanjoy): This can probably be a std::vector<string>.
-      compile_time_constant_inputs =
-          &it->second.front()->compile_time_constant_inputs;
     }
   }
 
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index af720fb4bb9..9533acb6a0c 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -198,6 +198,11 @@ class XlaOpRegistry {
                                      /*op_def=*/nullptr, result);
   }
 
+  // Return names of arguments for a given op which are supposed to be
+  // constants.
+  static const std::unordered_set<std::string>*
+  CompileTimeConstantInputArgNames(const string& op);
+
   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
   // of its operands and not their values.
   static bool IsMetadataOp(const string& op);