diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index ec5a372875c..25eed134e35 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -195,18 +195,20 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
 }
 
 void XlaComputationLaunchContext::PopulateInputs(
-    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
     const std::map<int, OptionalTensor>& variables,
     int missing_ctx_input_prefix) {
   // Build ShapedBuffers that point directly to the Tensor buffers.
-  arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
+  arg_ptrs_ =
+      std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
 
   xla::TransferManager* transfer_manager =
       client_->backend().transfer_manager();
-  for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
-    int arg_num = kernel->input_mapping[i];
+  for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
+    int arg_num = compilation_result->input_mapping[i];
     CHECK_GE(arg_num, missing_ctx_input_prefix);
-    const xla::Shape& shape = kernel->xla_input_shapes[i];
+    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
     const Tensor* t = variables.count(arg_num)
                           ? &(variables.at(arg_num).value)
                           : &(ctx->input(arg_num - missing_ctx_input_prefix));
@@ -361,13 +363,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
   return Status::OK();
 }
 
+// Sets output `output_num` for `ctx` provided it is known at compile time.
+static Status SetOutputForConstant(
+    OpKernelContext* ctx, se::Stream* stream,
+    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
+  CHECK(compilation_result->outputs[output_num].is_constant);
+  // Output is a constant.
+  const Tensor& const_tensor =
+      compilation_result->outputs[output_num].constant_value;
+  Tensor* output_tensor;
+  const size_t total_bytes = const_tensor.TotalBytes();
+  if (stream && total_bytes > 0) {
+    // Copy host -> device. (Empty tensors don't have backing buffers.)
+    // Manually allocate memory using an XlaTensorBuffer so we can allocate
+    // as much memory as the device requires (as given by
+    // GetByteSizeRequirement). This avoids XlaTransferManager having to
+    // reallocate the device buffer later.
+    VLOG(1) << "Constant output tensor on device";
+
+    TF_RETURN_IF_ERROR(
+        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
+    Device* device = dynamic_cast<Device*>(ctx->device());
+    if (device == nullptr) {
+      return errors::Internal("DeviceBase was not a Device.");
+    }
+    ctx->op_device_context()->CopyCPUTensorToDevice(
+        &const_tensor, device, output_tensor,
+        [&](Status status) { TF_CHECK_OK(status); });
+
+    if (device->device_type() == DEVICE_GPU) {
+      // The GPUDeviceContext enqueues the host->device transfer in a
+      // separate stream from the main compute stream. We must ensure the
+      // compute stream is synchronized with the host->device transfer
+      // stream now otherwise we will create a race condition.
+      auto* gpu_device_context =
+          static_cast<GPUDeviceContext*>(ctx->op_device_context());
+      gpu_device_context->stream()->ThenWaitFor(
+          gpu_device_context->host_to_device_stream());
+    }
+  } else {
+    // No copy required.
+    ctx->set_output(output_num, const_tensor);
+    output_tensor = ctx->mutable_output(output_num);
+  }
+  if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
+    xla_tensor->set_host_tensor(const_tensor);
+  }
+  return Status::OK();
+}
+
+// Creates a list of updated resource variables.
+static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
+    int missing_ctx_input_prefix) {
+  std::vector<VariableInfo> variable_infos;
+  variable_infos.reserve(compilation_result->resource_updates.size());
+
+  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
+    const XlaCompiler::ResourceUpdate& write =
+        compilation_result->resource_updates[i];
+    int actual_input_index = write.input_index - missing_ctx_input_prefix;
+    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
+      return errors::Internal("Invalid input index for variable write.");
+    }
+
+    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
+    // not a Tensor.
+    Var* variable = nullptr;
+    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
+        ctx, HandleFromInput(ctx, actual_input_index), &variable,
+        [&write](Var** ptr) {
+          *ptr = new Var(write.type);
+          return Status::OK();
+        }));
+    variable_infos.emplace_back(actual_input_index, variable);
+  }
+  return variable_infos;
+}
+
 Status XlaComputationLaunchContext::PopulateOutputs(
-    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
     ScopedShapedBuffer output, int missing_ctx_input_prefix,
     const xla::HloInputOutputAliasConfig& input_output_alias,
     const std::map<int, OptionalTensor>& resource_var_snapshots) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+  Allocator* allocator = ctx->device()->GetAllocator({});
 
   // Computation output should always be a tuple.
   if (VLOG_IS_ON(2)) {
@@ -375,7 +458,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     VLOG(2) << "Result tuple shape (on device): "
             << output.on_device_shape().DebugString();
   }
-  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
+  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
 
   // If the on-host-shape isn't a tuple, create a new single-element tuple
   // buffer with a nullptr root index table. This allows the code below to treat
@@ -404,82 +487,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
   // Copy XLA results to the OpOutputList.
   int output_num = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
-    Allocator* allocator = ctx->device()->GetAllocator({});
-    if (kernel->outputs[i].is_constant) {
-      // Output is a constant.
-      const Tensor& const_tensor = kernel->outputs[i].constant_value;
-      Tensor* output_tensor;
-      const size_t total_bytes = const_tensor.TotalBytes();
-      if (stream && total_bytes > 0) {
-        // Copy host -> device. (Empty tensors don't have backing buffers.)
-        // Manually allocate memory using an XlaTensorBuffer so we can allocate
-        // as much memory as the device requires (as given by
-        // GetByteSizeRequirement). This avoids XlaTransferManager having to
-        // reallocate the device buffer later.
- VLOG(1) << "Constant output tensor on device"; + const TensorShape& shape = compilation_result->outputs[i].shape; + const DataType& type = compilation_result->outputs[i].type; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " + << DataTypeString(type); + if (type == DT_VARIANT) { + return errors::Unimplemented( + "Support for TensorList crossing the XLA/TF boundary " + "is not implemented"); + } - TF_RETURN_IF_ERROR( - ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - - Device* device = dynamic_cast(ctx->device()); - if (device == nullptr) { - return errors::Internal("DeviceBase was not a Device."); - } - ctx->op_device_context()->CopyCPUTensorToDevice( - &const_tensor, device, output_tensor, - [&](Status status) { TF_CHECK_OK(status); }); - - if (device->device_type() == DEVICE_GPU) { - // The GPUDeviceContext enqueues the host->device transfer in a - // separate stream from the main compute stream. We must ensure the - // compute stream is synchronized with the host->device transfer - // stream now otherwise we will create a race condition. - auto* gpu_device_context = - static_cast(ctx->op_device_context()); - gpu_device_context->stream()->ThenWaitFor( - gpu_device_context->host_to_device_stream()); - } - } else { - // No copy required. - ctx->set_output(i, const_tensor); - output_tensor = ctx->mutable_output(i); - } - if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { - xla_tensor->set_host_tensor(const_tensor); - } + if (compilation_result->outputs[i].is_constant) { + TF_RETURN_IF_ERROR( + SetOutputForConstant(ctx, stream, compilation_result, i)); + } else if (type == DT_RESOURCE) { + int input_index = + compilation_result->outputs[i].input_index - missing_ctx_input_prefix; + TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs()) + << "Invalid input for outputs " << i << ": " << input_index; + ctx->set_output(i, ctx->input(input_index)); } else { - const TensorShape& shape = kernel->outputs[i].shape; - const DataType& type = kernel->outputs[i].type; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " - << DataTypeString(type); - if (type == DT_RESOURCE) { - int input_index = - kernel->outputs[i].input_index - missing_ctx_input_prefix; - TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs()) - << "Invalid input for outputs " << i << ": " << input_index; - ctx->set_output(i, ctx->input(input_index)); - } else { - if (allocate_xla_tensors_) { - TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors( - input_output_alias, output_num, ctx, i, shape, &output, - definition_event, stream, use_multiple_streams_)); - } else { - if (type == DT_VARIANT) { - return errors::Unimplemented( - "Support for TensorList crossing the XLA/TF boundary " - "is not implemented"); - } + if (allocate_xla_tensors_) { + TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors( + input_output_alias, output_num, ctx, i, shape, &output, + definition_event, stream, use_multiple_streams_)); - se::DeviceMemoryBase buffer = output.buffer({output_num}); - Tensor output_tensor = GetOrCreateTensorForOutput( - output_num, ctx, missing_ctx_input_prefix, input_output_alias, - kernel->input_mapping, resource_var_snapshots, - ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(se::OwningDeviceMemory(), {output_num}); - ctx->set_output(i, output_tensor); - } - ++output_num; + } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + Tensor output_tensor = 
GetOrCreateTensorForOutput( + output_num, ctx, missing_ctx_input_prefix, input_output_alias, + compilation_result->input_mapping, resource_var_snapshots, + ctx->expected_output_dtype(i), shape, buffer, allocator); + output.set_buffer(se::OwningDeviceMemory(), {output_num}); + ctx->set_output(i, output_tensor); } + ++output_num; } if (VLOG_IS_ON(3)) { @@ -489,34 +531,14 @@ Status XlaComputationLaunchContext::PopulateOutputs( // Apply variable updates, if any. VLOG(2) << "Applying variable updates"; - std::vector variable_infos; - variable_infos.reserve(kernel->resource_updates.size()); - - for (int i = 0; i < kernel->resource_updates.size(); ++i) { - const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - int actual_input_index = write.input_index - missing_ctx_input_prefix; - if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) { - return errors::Internal("Invalid input index for variable write."); - } - - // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, - // not a Tensor. - Var* variable = nullptr; - TF_RETURN_IF_ERROR(LookupOrCreateResource( - ctx, HandleFromInput(ctx, actual_input_index), &variable, - [&write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); - variable_infos.emplace_back(actual_input_index, variable); - } - + TF_ASSIGN_OR_RETURN( + std::vector variable_infos, + GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix)); TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); - for (int i = 0; i < kernel->resource_updates.size(); ++i) { - Allocator* allocator = ctx->device()->GetAllocator({}); - const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - + for (int i = 0; i < compilation_result->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& write = + compilation_result->resource_updates[i]; if (variable_infos[i].var()->tensor()->dtype() != write.type) { return errors::Internal("Mismatched type in variable write"); } @@ -530,7 +552,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( output.set_buffer(se::OwningDeviceMemory(), {output_num}); Tensor output_tensor = GetOrCreateTensorForOutput( output_num, ctx, missing_ctx_input_prefix, input_output_alias, - kernel->input_mapping, resource_var_snapshots, write.type, + compilation_result->input_mapping, resource_var_snapshots, write.type, write.shape, buffer, allocator); *variable_infos[i].var()->tensor() = output_tensor; variable_infos[i].var()->is_initialized |= write.modified; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index cf68dcb7dd6..9a7f20cb310 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -136,7 +136,7 @@ class XlaComputationLaunchContext { // input_mapping must be greater than or equal to `missing_ctx_input_prefix` // (in other words, no inputs actually required by the kernel can be missing). void PopulateInputs(OpKernelContext* ctx, - const XlaCompiler::CompilationResult* kernel, + const XlaCompiler::CompilationResult* compilation_result, const std::map& variables, int missing_ctx_input_prefix); @@ -148,10 +148,11 @@ class XlaComputationLaunchContext { // See jit/resource_operation_safety_analysis for details. // // - // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are - // missing and adjusts input indices accordingly. 
+  // Assumes that the first `missing_ctx_input_prefix` inputs to the
+  // compilation_result are missing and adjusts input indices accordingly.
   Status PopulateOutputs(
-      OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+      OpKernelContext* ctx,
+      const XlaCompiler::CompilationResult* compilation_result,
       xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
       const xla::HloInputOutputAliasConfig& input_output_alias,
       const std::map<int, OptionalTensor>& resource_var_snapshots);
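Review note, not part of the patch: a condensed sketch of the per-output dispatch that the refactored PopulateOutputs performs, using only names introduced in this change. Two details worth flagging: the Allocator lookup is hoisted out of both per-output loops, and the DT_VARIANT rejection now applies to every output rather than only the non-XlaTensor allocation path. Buffer handling and the allocate_xla_tensors_ branch are elided here.

// Review-only sketch of the loop body in the new PopulateOutputs.
for (int i = 0; i < ctx->num_outputs(); ++i) {
  const DataType& type = compilation_result->outputs[i].type;
  if (type == DT_VARIANT) {
    // Rejected up front for all output kinds after this change.
    return errors::Unimplemented(
        "Support for TensorList crossing the XLA/TF boundary "
        "is not implemented");
  }
  if (compilation_result->outputs[i].is_constant) {
    // Compile-time constants go through the extracted helper.
    TF_RETURN_IF_ERROR(
        SetOutputForConstant(ctx, stream, compilation_result, i));
  } else if (type == DT_RESOURCE) {
    // Resource outputs forward the corresponding input handle.
    ctx->set_output(i, ctx->input(compilation_result->outputs[i].input_index -
                                  missing_ctx_input_prefix));
  } else {
    // Device-buffer outputs consume the next XLA result tuple element;
    // output_num advances only on this path.
    ++output_num;
  }
}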