[TF/XLA] [NFC] Simplify XlaComputationLaunchContext::PopulateOutputs

Reduce the nesting level, extract a function for gathering VariableInfo.

PiperOrigin-RevId: 316720004
Change-Id: I49982058d9f7efbc2dcbb2b180c1fc95193cfa39
This commit is contained in:
George Karpenkov 2020-06-16 11:19:06 -07:00 committed by TensorFlower Gardener
parent 421e64c0c6
commit 2d50164bdb
2 changed files with 133 additions and 110 deletions

View File

@ -195,18 +195,20 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, OptionalTensor>& variables,
int missing_ctx_input_prefix) {
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
arg_ptrs_ =
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
xla::TransferManager* transfer_manager =
client_->backend().transfer_manager();
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->input_mapping[i];
for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
int arg_num = compilation_result->input_mapping[i];
CHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = kernel->xla_input_shapes[i];
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
const Tensor* t = variables.count(arg_num)
? &(variables.at(arg_num).value)
: &(ctx->input(arg_num - missing_ctx_input_prefix));
@ -361,13 +363,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
return Status::OK();
}
// Sets output `output_num` for `ctx` provided it is known at a compile time.
static Status SetOutputForConstant(
OpKernelContext* ctx, se::Stream* stream,
const XlaCompiler::CompilationResult* compilation_result, int output_num) {
CHECK(compilation_result->outputs[output_num].is_constant);
// Output is a constant.
const Tensor& const_tensor =
compilation_result->outputs[output_num].constant_value;
Tensor* output_tensor;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
// Manually allocate memory using an XlaTensorBuffer so we can allocate
// as much memory as the device requires (as given by
// GetByteSizeRequirement). This avoids XlaTransferManager having to
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
TF_RETURN_IF_ERROR(
ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
if (device == nullptr) {
return errors::Internal("DeviceBase was not a Device.");
}
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
if (device->device_type() == DEVICE_GPU) {
// The GPUDeviceContext enqueues the host->device transfer in a
// separate stream from the main compute stream. We must ensure the
// compute stream is synchronized with the host->device transfer
// stream now otherwise we will create a race condition.
auto* gpu_device_context =
static_cast<GPUDeviceContext*>(ctx->op_device_context());
gpu_device_context->stream()->ThenWaitFor(
gpu_device_context->host_to_device_stream());
}
} else {
// No copy required.
ctx->set_output(output_num, const_tensor);
output_tensor = ctx->mutable_output(output_num);
}
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
xla_tensor->set_host_tensor(const_tensor);
}
return Status::OK();
}
// Creates a list of updates resource variables.
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
int missing_ctx_input_prefix) {
std::vector<VariableInfo> variable_infos;
variable_infos.reserve(compilation_result->resource_updates.size());
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
int actual_input_index = write.input_index - missing_ctx_input_prefix;
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, variable);
}
return variable_infos;
}
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
Allocator* allocator = ctx->device()->GetAllocator({});
// Computation output should always be a tuple.
if (VLOG_IS_ON(2)) {
@ -375,7 +458,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
VLOG(2) << "Result tuple shape (on device): "
<< output.on_device_shape().DebugString();
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
// If the on-host-shape isn't a tuple, create a new single-element tuple
// buffer with a nullptr root index table. This allows the code below to treat
@ -404,57 +487,22 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
if (kernel->outputs[i].is_constant) {
// Output is a constant.
const Tensor& const_tensor = kernel->outputs[i].constant_value;
Tensor* output_tensor;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
// Manually allocate memory using an XlaTensorBuffer so we can allocate
// as much memory as the device requires (as given by
// GetByteSizeRequirement). This avoids XlaTransferManager having to
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
TF_RETURN_IF_ERROR(
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
if (device == nullptr) {
return errors::Internal("DeviceBase was not a Device.");
}
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
if (device->device_type() == DEVICE_GPU) {
// The GPUDeviceContext enqueues the host->device transfer in a
// separate stream from the main compute stream. We must ensure the
// compute stream is synchronized with the host->device transfer
// stream now otherwise we will create a race condition.
auto* gpu_device_context =
static_cast<GPUDeviceContext*>(ctx->op_device_context());
gpu_device_context->stream()->ThenWaitFor(
gpu_device_context->host_to_device_stream());
}
} else {
// No copy required.
ctx->set_output(i, const_tensor);
output_tensor = ctx->mutable_output(i);
}
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
xla_tensor->set_host_tensor(const_tensor);
}
} else {
const TensorShape& shape = kernel->outputs[i].shape;
const DataType& type = kernel->outputs[i].type;
const TensorShape& shape = compilation_result->outputs[i].shape;
const DataType& type = compilation_result->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
if (type == DT_VARIANT) {
return errors::Unimplemented(
"Support for TensorList crossing the XLA/TF boundary "
"is not implemented");
}
if (compilation_result->outputs[i].is_constant) {
TF_RETURN_IF_ERROR(
SetOutputForConstant(ctx, stream, compilation_result, i));
} else if (type == DT_RESOURCE) {
int input_index =
kernel->outputs[i].input_index - missing_ctx_input_prefix;
compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
@ -463,24 +511,18 @@ Status XlaComputationLaunchContext::PopulateOutputs(
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output,
definition_event, stream, use_multiple_streams_));
} else {
if (type == DT_VARIANT) {
return errors::Unimplemented(
"Support for TensorList crossing the XLA/TF boundary "
"is not implemented");
}
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots,
compilation_result->input_mapping, resource_var_snapshots,
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
}
}
if (VLOG_IS_ON(3)) {
VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
@ -489,34 +531,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
std::vector<VariableInfo> variable_infos;
variable_infos.reserve(kernel->resource_updates.size());
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
int actual_input_index = write.input_index - missing_ctx_input_prefix;
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, variable);
}
TF_ASSIGN_OR_RETURN(
std::vector<VariableInfo> variable_infos,
GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
return errors::Internal("Mismatched type in variable write");
}
@ -530,7 +552,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
output.set_buffer(se::OwningDeviceMemory(), {output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots, write.type,
compilation_result->input_mapping, resource_var_snapshots, write.type,
write.shape, buffer, allocator);
*variable_infos[i].var()->tensor() = output_tensor;
variable_infos[i].var()->is_initialized |= write.modified;

View File

@ -136,7 +136,7 @@ class XlaComputationLaunchContext {
// input_mapping must be greater than or equal to `missing_ctx_input_prefix`
// (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, OptionalTensor>& variables,
int missing_ctx_input_prefix);
@ -148,10 +148,11 @@ class XlaComputationLaunchContext {
// See jit/resource_operation_safety_analysis for details.
//
//
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
// missing and adjusts input indices accordingly.
// Assumes that the first `missing_ctx_input_prefix` inputs to the
// compilation_result are missing and adjusts input indices accordingly.
Status PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots);