From c287ba8d5e699e0ece7845e66ec9c357fb090dbc Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Fri, 28 Aug 2020 12:37:20 -0700
Subject: [PATCH] [TF2XLA] Do not copy in XLA device implementation; instead,
 request correct placement from the start.

PiperOrigin-RevId: 328989681
Change-Id: Ia57d5cd510091e94f58081d58e775d97e8c5ba9e
---
 .../compiler/jit/xla_compile_on_demand_op.cc  | 27 +++--------------
 tensorflow/compiler/jit/xla_device.cc         | 10 +++++--
 tensorflow/compiler/tf2xla/xla_op_registry.cc | 30 ++++++++++---------
 tensorflow/compiler/tf2xla/xla_op_registry.h  |  5 ++++
 4 files changed, 33 insertions(+), 39 deletions(-)

diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index ba20b532a11..1c3656edef4 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -123,30 +123,11 @@ Status XlaCompileOnDemandOp::Compile(
 
     if (!constant_arguments.count(i)) {
       if (absl::c_binary_search(constant_input_indices, i)) {
-        if (ctx->input_memory_type(i) != HOST_MEMORY &&
-            ctx->op_device_context()) {
-          // Slow path; the argument is not available as a host constant so we
-          // must fetch it synchronously.
-          Tensor host_tensor;
-          AllocatorAttributes attrs;
-          attrs.set_on_host(true);
-          TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
-                                                device_tensor.shape(),
-                                                &host_tensor, attrs));
-          Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
-              &device_tensor, "ConstantArgument",
-              reinterpret_cast<Device*>(ctx->device()), &host_tensor);
-          if (!status.ok()) {
-            LOG(ERROR) << "Copying tensor of shape "
-                       << device_tensor.shape().DebugString() << " from "
-                       << ctx->device()->name() << "to CPU failed with "
-                       << status.ToString();
-            return status;
-          }
-          constant_arguments[i] = host_tensor;
-        } else {
-          constant_arguments[i] = device_tensor;
+        if (ctx->input_memory_type(i) != HOST_MEMORY) {
+          return errors::Internal(
+              "Expected constant argument not in host memory");
         }
+        constant_arguments[i] = device_tensor;
       }
     }
   }
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index c47c9a29c1a..089d22dca03 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -573,8 +573,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
   // Any op assigned to the device that isn't rewritten by the graph rewriter
   // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
   // it just-in-time.
-  OpKernel* (*factory)(OpKernelConstruction*) =
-      [](OpKernelConstruction* context) -> OpKernel* {
+  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
     return new XlaCompileOnDemandOp(context);
   };
   XlaOpRegistry::RegisterCompilationKernels();
@@ -583,6 +582,13 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
   for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
            jit_device, /*include_compilation_only_kernels=*/false)) {
     KernelDef* def = new KernelDef(*jit_def);
+    const std::unordered_set<std::string>* constant_inputs =
+        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());
+
+    for (const std::string& arg_name : *constant_inputs) {
+      def->add_host_memory_arg(arg_name);
+    }
+
     def->set_device_type(device);
     registrations->op_kernel_registrars.emplace_back(
         new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index e37f4659185..9948fe6d1b9 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -365,6 +365,19 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
   return ops;
 }
 
+/*static*/ const std::unordered_set<std::string>*
+XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+  auto it = registry.ops_.find(op);
+  static auto empty_set = new std::unordered_set<std::string>;
+  if (it == registry.ops_.end() || it->second.empty()) {
+    return empty_set;
+  } else {
+    return &it->second.front()->compile_time_constant_inputs;
+  }
+}
+
 /* static */ Status XlaOpRegistry::CompileTimeConstantInputs(
     const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def,
     std::vector<int>* result) {
@@ -385,21 +398,10 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
         compile_time_constant_inputs_from_attr.end()));
     compile_time_constant_inputs = &compile_time_constant_inputs_from_attr;
   } else {
-    const string& op = node_def.op();
-
-    XlaOpRegistry& registry = Instance();
-    mutex_lock lock(registry.mutex_);
-    auto it = registry.ops_.find(op);
-    if (it == registry.ops_.end() || it->second.empty()) {
+    compile_time_constant_inputs =
+        CompileTimeConstantInputArgNames(node_def.op());
+    if (compile_time_constant_inputs->empty()) {
       return Status::OK();
-    } else {
-      // The test in IsCompatible ensures that if there are multiple matching
-      // registrations for this op name, they all have the same value of
-      // compile_time_constant_inputs, so only the first match is returned.
-      //
-      // TODO(sanjoy): This can probably be a std::vector<string>.
-      compile_time_constant_inputs =
-          &it->second.front()->compile_time_constant_inputs;
     }
   }
 
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index af720fb4bb9..9533acb6a0c 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -198,6 +198,11 @@ class XlaOpRegistry {
                                     /*op_def=*/nullptr, result);
   }
 
+  // Returns the names of the arguments of a given op that are required to be
+  // compile-time constants.
+  static const std::unordered_set<std::string>*
+  CompileTimeConstantInputArgNames(const string& op);
+
   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
   // of its operands and not their values.
   static bool IsMetadataOp(const string& op);
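
Reviewer note (appended commentary, not part of the patch): the change works by
declaring every compile-time-constant input of an op as a host-memory argument
at kernel-registration time, so the runtime places those inputs on the host
from the start and the old synchronous device-to-CPU copy in
XlaCompileOnDemandOp::Compile can be deleted. The self-contained C++ sketch
below models that registration flow for readers without a TensorFlow checkout;
KernelDef, ConstantInputRegistry, and the "Reshape"/"shape" entry here are
simplified stand-ins, not the real TensorFlow types or API.

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    // Simplified stand-in for tensorflow::KernelDef.
    struct KernelDef {
      std::string op;
      std::vector<std::string> host_memory_args;  // inputs pinned to host memory
    };

    // Toy registry: op name -> names of inputs that must be compile-time
    // constants (stand-in for XlaOpRegistry's op registrations).
    using ConstantInputRegistry =
        std::unordered_map<std::string, std::unordered_set<std::string>>;

    // Models XlaOpRegistry::CompileTimeConstantInputArgNames: returns the
    // (possibly empty) set of constant-input argument names for `op`.
    const std::unordered_set<std::string>& CompileTimeConstantInputArgNames(
        const ConstantInputRegistry& registry, const std::string& op) {
      static const std::unordered_set<std::string> empty_set;
      auto it = registry.find(op);
      return it == registry.end() ? empty_set : it->second;
    }

    int main() {
      // Example entry: Reshape's "shape" input must be a compile-time constant.
      ConstantInputRegistry registry = {{"Reshape", {"shape"}}};

      // Mirrors the loop in RegisterXlaDeviceKernels: declare each
      // compile-time-constant input as a host-memory argument up front, so no
      // device-to-CPU copy is needed when the kernel is later compiled.
      KernelDef def{"Reshape", {}};
      for (const std::string& arg :
           CompileTimeConstantInputArgNames(registry, def.op)) {
        def.host_memory_args.push_back(arg);
      }

      for (const std::string& arg : def.host_memory_args) {
        std::cout << "host-memory arg: " << arg << "\n";
      }
      return 0;
    }

In the real patch the same effect comes from KernelDef::add_host_memory_arg in
the xla_device.cc hunk, paired with the errors::Internal check in
xla_compile_on_demand_op.cc that now enforces the host placement instead of
copying.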