diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 1c3656edef4..ba20b532a11 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -123,11 +123,30 @@ Status XlaCompileOnDemandOp::Compile(
     if (!constant_arguments.count(i)) {
       if (absl::c_binary_search(constant_input_indices, i)) {
-        if (ctx->input_memory_type(i) != HOST_MEMORY) {
-          return errors::Internal(
-              "Expected constant argument not in host memory");
+        if (ctx->input_memory_type(i) != HOST_MEMORY &&
+            ctx->op_device_context()) {
+          // Slow path; the argument is not available as a host constant so we
+          // must fetch it synchronously.
+          Tensor host_tensor;
+          AllocatorAttributes attrs;
+          attrs.set_on_host(true);
+          TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
+                                                device_tensor.shape(),
+                                                &host_tensor, attrs));
+          Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
+              &device_tensor, "ConstantArgument",
+              reinterpret_cast<Device*>(ctx->device()), &host_tensor);
+          if (!status.ok()) {
+            LOG(ERROR) << "Copying tensor of shape "
+                       << device_tensor.shape().DebugString() << " from "
+                       << ctx->device()->name() << " to CPU failed with "
+                       << status.ToString();
+            return status;
+          }
+          constant_arguments[i] = host_tensor;
+        } else {
+          constant_arguments[i] = device_tensor;
         }
-        constant_arguments[i] = device_tensor;
       }
     }
   }
 
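Note: the hunk above replaces a hard Internal error with a slow path. When a
compile-time-constant argument lives in device memory rather than host memory,
it is now copied back to the host synchronously so its value can be read at
compile time. Below is a minimal sketch of that copy pattern, factored into a
hypothetical helper; FetchToHost and the exact include set are assumptions for
illustration, not part of the patch, and op_device_context() is assumed
non-null as the guard in the hunk ensures.

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Copies `device_tensor` into freshly allocated host memory and blocks until
// the copy finishes, so the caller can inspect the value immediately.
Status FetchToHost(OpKernelContext* ctx, const Tensor& device_tensor,
                   Tensor* host_tensor) {
  AllocatorAttributes attrs;
  attrs.set_on_host(true);  // Request CPU-accessible memory for the temp.
  TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
                                        device_tensor.shape(), host_tensor,
                                        attrs));
  // Mirrors the cast the patch uses to get a Device* from the DeviceBase*.
  return ctx->op_device_context()->CopyDeviceTensorToCPUSync(
      &device_tensor, "ConstantArgument",
      reinterpret_cast<Device*>(ctx->device()), host_tensor);
}

}  // namespace tensorflow

The guard on ctx->op_device_context() matters: when no device context is
available there is no copy engine, so the else branch keeps the old behavior
and passes the tensor through unchanged.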
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 089d22dca03..c47c9a29c1a 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -573,7 +573,8 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
   // Any op assigned to the device that isn't rewritten by the graph rewriter
   // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
   // it just-in-time.
-  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
+  OpKernel* (*factory)(OpKernelConstruction*) =
+      [](OpKernelConstruction* context) -> OpKernel* {
     return new XlaCompileOnDemandOp(context);
   };
   XlaOpRegistry::RegisterCompilationKernels();
@@ -582,13 +583,6 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
            jit_device,
            /*include_compilation_only_kernels=*/false)) {
     KernelDef* def = new KernelDef(*jit_def);
-    const std::unordered_set<std::string>* constant_inputs =
-        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());
-
-    for (const std::string& arg_name : *constant_inputs) {
-      def->add_host_memory_arg(arg_name);
-    }
-
     def->set_device_type(device);
     registrations->op_kernel_registrars.emplace_back(
         new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 9948fe6d1b9..e37f4659185 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -365,19 +365,6 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
   return ops;
 }
 
-/*static*/ const std::unordered_set<std::string>*
-XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
-  XlaOpRegistry& registry = Instance();
-  mutex_lock lock(registry.mutex_);
-  auto it = registry.ops_.find(op);
-  static auto empty_set = new std::unordered_set<std::string>;
-  if (it == registry.ops_.end() || it->second.empty()) {
-    return empty_set;
-  } else {
-    return &it->second.front()->compile_time_constant_inputs;
-  }
-}
-
 /* static */ Status XlaOpRegistry::CompileTimeConstantInputs(
     const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def,
     std::vector<int>* result) {
@@ -398,10 +385,21 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
         compile_time_constant_inputs_from_attr.end()));
     compile_time_constant_inputs = &compile_time_constant_inputs_from_attr;
   } else {
-    compile_time_constant_inputs =
-        CompileTimeConstantInputArgNames(node_def.op());
-    if (compile_time_constant_inputs->empty()) {
+    const string& op = node_def.op();
+
+    XlaOpRegistry& registry = Instance();
+    mutex_lock lock(registry.mutex_);
+    auto it = registry.ops_.find(op);
+    if (it == registry.ops_.end() || it->second.empty()) {
       return Status::OK();
+    } else {
+      // The test in IsCompatible ensures that if there are multiple matching
+      // registrations for this op name, they all have the same value of
+      // compile_time_constant_inputs, so only the first match is returned.
+      //
+      // TODO(sanjoy): This can probably be a std::vector<string>.
+      compile_time_constant_inputs =
+          &it->second.front()->compile_time_constant_inputs;
     }
   }
 
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 9533acb6a0c..af720fb4bb9 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -198,11 +198,6 @@ class XlaOpRegistry {
                                      /*op_def=*/nullptr, result);
   }
 
-  // Return names of arguments for a given op which are supposed to be
-  // constants.
-  static const std::unordered_set<std::string>*
-  CompileTimeConstantInputArgNames(const string& op);
-
   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
   // of its operands and not their values.
   static bool IsMetadataOp(const string& op);
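Note: the registration-side changes fit together. XLA device kernels no longer
mark constant inputs as host-memory args, since the on-demand op now copies
device-resident constants itself (first file); that removes the last external
caller of CompileTimeConstantInputArgNames, whose body is then inlined into
CompileTimeConstantInputs under the registry mutex, and its declaration is
dropped from the header. The explicit OpKernel* (*factory)(OpKernelConstruction*)
type in xla_device.cc simply forces the capture-less lambda to decay to the
plain function-pointer type the registrar stores, rather than leaving `factory`
with the lambda's unique type. Callers that need to know which inputs must be
constant now go through CompileTimeConstantInputs; a hedged sketch of a call
site follows, assuming the (const OpKernel&, std::vector<int>*) overload
suggested by the header context above, and that the returned indices are
sorted, as the binary search in the first file relies on.

#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical call site (not part of the patch): decide whether input `i`
// of `kernel` must be a compile-time constant.
Status InputMustBeConstant(const OpKernel& kernel, int i, bool* out) {
  std::vector<int> constant_input_indices;
  TF_RETURN_IF_ERROR(XlaOpRegistry::CompileTimeConstantInputs(
      kernel, &constant_input_indices));
  // The indices are sorted, so membership tests can use binary search, the
  // same pattern xla_compile_on_demand_op.cc uses.
  *out = absl::c_binary_search(constant_input_indices, i);
  return Status::OK();
}

}  // namespace tensorflow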