diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 1c3656edef4..ba20b532a11 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -123,11 +123,30 @@ Status XlaCompileOnDemandOp::Compile(
 
     if (!constant_arguments.count(i)) {
       if (absl::c_binary_search(constant_input_indices, i)) {
-        if (ctx->input_memory_type(i) != HOST_MEMORY) {
-          return errors::Internal(
-              "Expected constant argument not in host memory");
+        if (ctx->input_memory_type(i) != HOST_MEMORY &&
+            ctx->op_device_context()) {
+          // Slow path; the argument is not available as a host constant so we
+          // must fetch it synchronously.
+          Tensor host_tensor;
+          AllocatorAttributes attrs;
+          attrs.set_on_host(true);
+          TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
+                                                device_tensor.shape(),
+                                                &host_tensor, attrs));
+          Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
+              &device_tensor, "ConstantArgument",
+              reinterpret_cast<Device*>(ctx->device()), &host_tensor);
+          if (!status.ok()) {
+            LOG(ERROR) << "Copying tensor of shape "
+                       << device_tensor.shape().DebugString() << " from "
+                       << ctx->device()->name() << " to CPU failed with "
+                       << status.ToString();
+            return status;
+          }
+          constant_arguments[i] = host_tensor;
+        } else {
+          constant_arguments[i] = device_tensor;
         }
-        constant_arguments[i] = device_tensor;
       }
     }
   }
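
Note: the slow path added above distills to the helper-style sketch below. FetchConstantToHost is a hypothetical name used only for illustration; the calls it makes (allocate_temp with an on-host AllocatorAttributes, then CopyDeviceTensorToCPUSync) are exactly the ones in the hunk.

Status FetchConstantToHost(OpKernelContext* ctx, const Tensor& device_tensor,
                           Tensor* host_tensor) {
  AllocatorAttributes attrs;
  attrs.set_on_host(true);  // stage the copy in CPU-accessible memory
  TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
                                        device_tensor.shape(), host_tensor,
                                        attrs));
  // Synchronous copy; acceptable because on-demand compilation is already a
  // slow, one-off path.
  return ctx->op_device_context()->CopyDeviceTensorToCPUSync(
      &device_tensor, "ConstantArgument",
      reinterpret_cast<Device*>(ctx->device()), host_tensor);
}
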
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 089d22dca03..c47c9a29c1a 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -573,7 +573,8 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
   // Any op assigned to the device that isn't rewritten by the graph rewriter
   // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
   // it just-in-time.
-  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
+  OpKernel* (*factory)(OpKernelConstruction*) =
+      [](OpKernelConstruction* context) -> OpKernel* {
     return new XlaCompileOnDemandOp(context);
   };
   XlaOpRegistry::RegisterCompilationKernels();
@@ -582,13 +583,6 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
            jit_device,
            /*include_compilation_only_kernels=*/false)) {
     KernelDef* def = new KernelDef(*jit_def);
-    const std::unordered_set<std::string>* constant_inputs =
-        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());
-
-    for (const std::string& arg_name : *constant_inputs) {
-      def->add_host_memory_arg(arg_name);
-    }
-
     def->set_device_type(device);
     registrations->op_kernel_registrars.emplace_back(
         new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
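
Note: the explicit OpKernel* (*)(OpKernelConstruction*) type makes the captureless lambda decay to a plain function pointer before it is handed to OpKernelRegistrar. A minimal standalone illustration, where the Factory alias is illustrative and not part of the patch:

using Factory = OpKernel* (*)(OpKernelConstruction*);
Factory factory = [](OpKernelConstruction* context) -> OpKernel* {
  return new XlaCompileOnDemandOp(context);
};

Dropping the add_host_memory_arg loop (removed above) is what makes the new slow path in xla_compile_on_demand_op.cc necessary: compile-time-constant inputs are no longer pinned to host memory when these kernels are registered, so at runtime they can arrive in device memory.
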
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 9948fe6d1b9..e37f4659185 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -365,19 +365,6 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
   return ops;
 }
 
-/*static*/ const std::unordered_set<std::string>*
-XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
-  XlaOpRegistry& registry = Instance();
-  mutex_lock lock(registry.mutex_);
-  auto it = registry.ops_.find(op);
-  static auto empty_set = new std::unordered_set<std::string>;
-  if (it == registry.ops_.end() || it->second.empty()) {
-    return empty_set;
-  } else {
-    return &it->second.front()->compile_time_constant_inputs;
-  }
-}
-
 /* static */ Status XlaOpRegistry::CompileTimeConstantInputs(
     const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def,
     std::vector<int>* result) {
@@ -398,10 +385,21 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
                                compile_time_constant_inputs_from_attr.end()));
     compile_time_constant_inputs = &compile_time_constant_inputs_from_attr;
   } else {
-    compile_time_constant_inputs =
-        CompileTimeConstantInputArgNames(node_def.op());
-    if (compile_time_constant_inputs->empty()) {
+    const string& op = node_def.op();
+
+    XlaOpRegistry& registry = Instance();
+    mutex_lock lock(registry.mutex_);
+    auto it = registry.ops_.find(op);
+    if (it == registry.ops_.end() || it->second.empty()) {
       return Status::OK();
+    } else {
+      // The test in IsCompatible ensures that if there are multiple matching
+      // registrations for this op name, they all have the same value of
+      // compile_time_constant_inputs, so it is safe to use the first match.
+      //
+      // TODO(sanjoy): This can probably be a std::vector<string>.
+      compile_time_constant_inputs =
+          &it->second.front()->compile_time_constant_inputs;
     }
   }
 
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 9533acb6a0c..af720fb4bb9 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -198,11 +198,6 @@ class XlaOpRegistry {
                                      /*op_def=*/nullptr, result);
   }
 
-  // Return names of arguments for a given op which are supposed to be
-  // constants.
-  static const std::unordered_set<std::string>*
-  CompileTimeConstantInputArgNames(const string& op);
-
   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
   // of its operands and not their values.
   static bool IsMetadataOp(const string& op);
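
Note: with CompileTimeConstantInputArgNames gone, callers query constant inputs by position through the remaining CompileTimeConstantInputs overloads. A usage sketch, assuming the OpKernel-taking wrapper whose tail is visible at the top of this hunk; GetConstantInputIndices is an illustrative name only:

Status GetConstantInputIndices(const OpKernel& op_kernel,
                               std::vector<int>* constant_input_indices) {
  return XlaOpRegistry::CompileTimeConstantInputs(op_kernel,
                                                  constant_input_indices);
}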