[TF2XLA] Do not copy in XLA device implementation; instead, request correct placement from the start.
PiperOrigin-RevId: 328989681
Change-Id: Ia57d5cd510091e94f58081d58e775d97e8c5ba9e

parent 16895e59b8
commit c287ba8d5e
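At a high level, the commit moves the placement decision for compile-time-constant inputs from compile time to kernel-registration time: each kernel registered on the XLA device now declares its op's constant inputs as host-memory args, so those tensors already live on the host when XlaCompileOnDemandOp::Compile runs and the old synchronous device-to-host copy can be dropped. The sketch below is only a toy model of that idea; MiniKernelDef, RegisterKernel, and CompileConstants are invented names for illustration and are not TensorFlow API.

#include <iostream>
#include <set>
#include <string>

// Toy stand-ins for KernelDef / the compile step; these names are invented
// for this sketch and are not TensorFlow API.
struct MiniKernelDef {
  std::string op;
  std::set<std::string> host_memory_args;  // inputs pinned to host memory
};

// Registration-time step (the role played by RegisterXlaDeviceKernels in the
// diff): every compile-time-constant input is declared a host-memory arg.
MiniKernelDef RegisterKernel(const std::string& op,
                             const std::set<std::string>& constant_inputs) {
  MiniKernelDef def{op, {}};
  for (const std::string& arg : constant_inputs) {
    def.host_memory_args.insert(arg);  // placer would allocate these on host
  }
  return def;
}

// Compile-time step (the role played by the new XlaCompileOnDemandOp::Compile
// body): constant inputs are expected to already be host-resident, so there
// is no device-to-host copy, only a placement check.
bool CompileConstants(const MiniKernelDef& def,
                      const std::set<std::string>& host_resident_inputs) {
  for (const std::string& arg : def.host_memory_args) {
    if (host_resident_inputs.count(arg) == 0) {
      std::cerr << "Expected constant argument " << arg
                << " not in host memory\n";  // analogue of errors::Internal
      return false;
    }
  }
  return true;
}

int main() {
  MiniKernelDef def = RegisterKernel("Reshape", {"shape"});
  // Because "shape" was registered as a host-memory arg, it arrives host-side.
  std::cout << CompileConstants(def, {"shape"}) << "\n";  // prints 1
}

The actual change does the equivalent of RegisterKernel via KernelDef::add_host_memory_arg in the RegisterXlaDeviceKernels hunk below, and the equivalent of the placement check via errors::Internal in the Compile hunk.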
@@ -123,30 +123,11 @@ Status XlaCompileOnDemandOp::Compile(
-    if (!constant_arguments.count(i)) {
-      if (ctx->input_memory_type(i) != HOST_MEMORY &&
-          ctx->op_device_context()) {
-        // Slow path; the argument is not available as a host constant so we
-        // must fetch it synchronously.
-        Tensor host_tensor;
-        AllocatorAttributes attrs;
-        attrs.set_on_host(true);
-        TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
-                                              device_tensor.shape(),
-                                              &host_tensor, attrs));
-        Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
-            &device_tensor, "ConstantArgument",
-            reinterpret_cast<Device*>(ctx->device()), &host_tensor);
-        if (!status.ok()) {
-          LOG(ERROR) << "Copying tensor of shape "
-                     << device_tensor.shape().DebugString() << " from "
-                     << ctx->device()->name() << "to CPU failed with "
-                     << status.ToString();
-          return status;
-        }
-        constant_arguments[i] = host_tensor;
-      } else {
-        constant_arguments[i] = device_tensor;
-      }
-    }
+    if (absl::c_binary_search(constant_input_indices, i)) {
+      if (ctx->input_memory_type(i) != HOST_MEMORY) {
+        return errors::Internal(
+            "Expected constant argument not in host memory");
+      }
+      constant_arguments[i] = device_tensor;
+    }
   }
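The replacement fast path leans on absl::c_binary_search, which is only well-defined when constant_input_indices is sorted; presumably the caller supplies a sorted vector. A standalone illustration of that check, with invented data, using absl/algorithm/container.h:

#include <iostream>
#include <vector>
#include "absl/algorithm/container.h"

int main() {
  // binary_search is only valid on a sorted range.
  std::vector<int> constant_input_indices = {0, 2, 5};
  for (int i = 0; i < 6; ++i) {
    // Mirrors the per-input membership test in the new Compile() body.
    std::cout << i << " constant? "
              << absl::c_binary_search(constant_input_indices, i) << "\n";
  }
  return 0;
}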
@@ -573,8 +573,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
   // Any op assigned to the device that isn't rewritten by the graph rewriter
   // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
   // it just-in-time.
-  OpKernel* (*factory)(OpKernelConstruction*) =
-      [](OpKernelConstruction* context) -> OpKernel* {
+  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
     return new XlaCompileOnDemandOp(context);
   };
   XlaOpRegistry::RegisterCompilationKernels();
@@ -583,6 +582,13 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
            jit_device,
            /*include_compilation_only_kernels=*/false)) {
     KernelDef* def = new KernelDef(*jit_def);
+    const std::unordered_set<std::string>* constant_inputs =
+        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());
+
+    for (const std::string& arg_name : *constant_inputs) {
+      def->add_host_memory_arg(arg_name);
+    }
+
     def->set_device_type(device);
     registrations->op_kernel_registrars.emplace_back(
         new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
@@ -365,6 +365,19 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
   return ops;
 }
 
+/*static*/ const std::unordered_set<std::string>*
+XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+  auto it = registry.ops_.find(op);
+  static auto empty_set = new std::unordered_set<std::string>;
+  if (it == registry.ops_.end() || it->second.empty()) {
+    return empty_set;
+  } else {
+    return &it->second.front()->compile_time_constant_inputs;
+  }
+}
+
 /* static */ Status XlaOpRegistry::CompileTimeConstantInputs(
     const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def,
     std::vector<int>* result) {
@@ -385,21 +398,10 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
         compile_time_constant_inputs_from_attr.end()));
     compile_time_constant_inputs = &compile_time_constant_inputs_from_attr;
   } else {
-    const string& op = node_def.op();
-
-    XlaOpRegistry& registry = Instance();
-    mutex_lock lock(registry.mutex_);
-    auto it = registry.ops_.find(op);
-    if (it == registry.ops_.end() || it->second.empty()) {
+    compile_time_constant_inputs =
+        CompileTimeConstantInputArgNames(node_def.op());
+    if (compile_time_constant_inputs->empty()) {
       return Status::OK();
-    } else {
-      // The test in IsCompatible ensures that if there are multiple matching
-      // registrations for this op name, they all have the same value of
-      // compile_time_constant_inputs, so only the first match is returned.
-      //
-      // TODO(sanjoy): This can probably be a std::vector<string>.
-      compile_time_constant_inputs =
-          &it->second.front()->compile_time_constant_inputs;
     }
   }
 
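One detail worth noting in the new CompileTimeConstantInputArgNames: on a lookup miss it returns a pointer to a leaked, function-local static empty set rather than nullptr, so callers (the registration loop above and the rewritten CompileTimeConstantInputs) can always dereference and iterate the result. A self-contained sketch of the same pattern, with names invented for illustration:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Hypothetical registry used only for this sketch.
const std::unordered_set<std::string>* ConstantInputsFor(
    const std::string& op,
    const std::unordered_map<std::string, std::unordered_set<std::string>>&
        registry) {
  // A leaked function-local static gives misses a stable, never-destroyed
  // empty set to point at, so callers can iterate without null checks.
  static const auto* empty_set = new std::unordered_set<std::string>;
  auto it = registry.find(op);
  return it == registry.end() ? empty_set : &it->second;
}

int main() {
  std::unordered_map<std::string, std::unordered_set<std::string>> registry = {
      {"Reshape", {"shape"}}};
  for (const std::string& op : {"Reshape", "Add"}) {
    std::cout << op << ":";
    for (const std::string& arg : *ConstantInputsFor(op, registry)) {
      std::cout << " " << arg;
    }
    std::cout << "\n";  // "Add" prints nothing extra; no nullptr handling.
  }
  return 0;
}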
@@ -198,6 +198,11 @@ class XlaOpRegistry {
                                      /*op_def=*/nullptr, result);
   }
 
+  // Return names of arguments for a given op which are supposed to be
+  // constants.
+  static const std::unordered_set<std::string>*
+  CompileTimeConstantInputArgNames(const string& op);
+
   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
   // of its operands and not their values.
   static bool IsMetadataOp(const string& op);