[TF2XLA] Do not copy in XLA device implementation; instead, request correct placement from the start.

PiperOrigin-RevId: 328989681
Change-Id: Ia57d5cd510091e94f58081d58e775d97e8c5ba9e
This commit is contained in:
George Karpenkov 2020-08-28 12:37:20 -07:00 committed by TensorFlower Gardener
parent 16895e59b8
commit c287ba8d5e
4 changed files with 33 additions and 39 deletions

View File

@ -123,30 +123,11 @@ Status XlaCompileOnDemandOp::Compile(
if (!constant_arguments.count(i)) {
if (absl::c_binary_search(constant_input_indices, i)) {
if (ctx->input_memory_type(i) != HOST_MEMORY &&
ctx->op_device_context()) {
// Slow path; the argument is not available as a host constant so we
// must fetch it synchronously.
Tensor host_tensor;
AllocatorAttributes attrs;
attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
device_tensor.shape(),
&host_tensor, attrs));
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
&device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from "
<< ctx->device()->name() << "to CPU failed with "
<< status.ToString();
return status;
}
constant_arguments[i] = host_tensor;
} else {
constant_arguments[i] = device_tensor;
if (ctx->input_memory_type(i) != HOST_MEMORY) {
return errors::Internal(
"Expected constant argument not in host memory");
}
constant_arguments[i] = device_tensor;
}
}
}

View File

@ -573,8 +573,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
// Any op assigned to the device that isn't rewritten by the graph rewriter
// gets executed by an XlaCompileOnDemandOp, which compiles it and executes
// it just-in-time.
OpKernel* (*factory)(OpKernelConstruction*) =
[](OpKernelConstruction* context) -> OpKernel* {
auto factory = [](OpKernelConstruction* context) -> OpKernel* {
return new XlaCompileOnDemandOp(context);
};
XlaOpRegistry::RegisterCompilationKernels();
@ -583,6 +582,13 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
jit_device,
/*include_compilation_only_kernels=*/false)) {
KernelDef* def = new KernelDef(*jit_def);
const std::unordered_set<std::string>* constant_inputs =
XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());
for (const std::string& arg_name : *constant_inputs) {
def->add_host_memory_arg(arg_name);
}
def->set_device_type(device);
registrations->op_kernel_registrars.emplace_back(
new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",

View File

@ -365,6 +365,19 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
return ops;
}
/*static*/ const std::unordered_set<std::string>*
XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
XlaOpRegistry& registry = Instance();
mutex_lock lock(registry.mutex_);
auto it = registry.ops_.find(op);
static auto empty_set = new std::unordered_set<std::string>;
if (it == registry.ops_.end() || it->second.empty()) {
return empty_set;
} else {
return &it->second.front()->compile_time_constant_inputs;
}
}
/* static */ Status XlaOpRegistry::CompileTimeConstantInputs(
const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def,
std::vector<int>* result) {
@ -385,21 +398,10 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
compile_time_constant_inputs_from_attr.end()));
compile_time_constant_inputs = &compile_time_constant_inputs_from_attr;
} else {
const string& op = node_def.op();
XlaOpRegistry& registry = Instance();
mutex_lock lock(registry.mutex_);
auto it = registry.ops_.find(op);
if (it == registry.ops_.end() || it->second.empty()) {
compile_time_constant_inputs =
CompileTimeConstantInputArgNames(node_def.op());
if (compile_time_constant_inputs->empty()) {
return Status::OK();
} else {
// The test in IsCompatible ensures that if there are multiple matching
// registrations for this op name, they all have the same value of
// compile_time_constant_inputs, so only the first match is returned.
//
// TODO(sanjoy): This can probably be a std::vector<string>.
compile_time_constant_inputs =
&it->second.front()->compile_time_constant_inputs;
}
}

View File

@ -198,6 +198,11 @@ class XlaOpRegistry {
/*op_def=*/nullptr, result);
}
// Return names of arguments for a given op which are supposed to be
// constants.
static const std::unordered_set<std::string>*
CompileTimeConstantInputArgNames(const string& op);
// Returns true if `op` is a "metadata" op, one that only looks at the shapes
// of its operands and not their values.
static bool IsMetadataOp(const string& op);