[TF/XLA] [NFC] Simplify XlaComputationLaunchContext::PopulateOutputs

Reduce the nesting level and extract a function for gathering VariableInfo.

PiperOrigin-RevId: 316720004
Change-Id: I49982058d9f7efbc2dcbb2b180c1fc95193cfa39
Authored by George Karpenkov on 2020-06-16 11:19:06 -07:00; committed by TensorFlower Gardener
parent 421e64c0c6
commit 2d50164bdb
2 changed files with 133 additions and 110 deletions
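
For context, the change follows a common pattern in this codebase: a fallible loop is moved into a helper that returns xla::StatusOr<std::vector<VariableInfo>>, and the caller unwraps the result in one statement with TF_ASSIGN_OR_RETURN. The sketch below illustrates that pattern in isolation; it uses absl::StatusOr and a made-up helper (CollectPositives) rather than the TensorFlow-internal types, so treat it as a simplified analogy, not code from this commit.

// Simplified illustration of the "extract a StatusOr-returning helper" pattern.
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical helper: gathers values, reporting failure via Status instead
// of an out-parameter.
absl::StatusOr<std::vector<int>> CollectPositives(const std::vector<int>& in) {
  std::vector<int> out;
  out.reserve(in.size());
  for (int v : in) {
    if (v < 0) {
      return absl::InternalError("negative value encountered");
    }
    out.push_back(v);
  }
  return out;
}

absl::Status Consume(const std::vector<int>& in) {
  // TF_ASSIGN_OR_RETURN expands to roughly this: call the helper, return its
  // Status on failure, otherwise move the value into a local variable.
  absl::StatusOr<std::vector<int>> positives = CollectPositives(in);
  if (!positives.ok()) return positives.status();
  std::vector<int> values = *std::move(positives);
  // ... use `values` ...
  return absl::OkStatus();
}

Returning StatusOr lets the extracted helper propagate errors without out-parameters, which is what allows PopulateOutputs below to drop a level of nesting.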


@@ -195,18 +195,20 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
 }
 
 void XlaComputationLaunchContext::PopulateInputs(
-    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
     const std::map<int, OptionalTensor>& variables,
     int missing_ctx_input_prefix) {
   // Build ShapedBuffers that point directly to the Tensor buffers.
-  arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
+  arg_ptrs_ =
+      std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
 
   xla::TransferManager* transfer_manager =
       client_->backend().transfer_manager();
-  for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
-    int arg_num = kernel->input_mapping[i];
+  for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
+    int arg_num = compilation_result->input_mapping[i];
     CHECK_GE(arg_num, missing_ctx_input_prefix);
-    const xla::Shape& shape = kernel->xla_input_shapes[i];
+    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
     const Tensor* t = variables.count(arg_num)
                           ? &(variables.at(arg_num).value)
                           : &(ctx->input(arg_num - missing_ctx_input_prefix));
@@ -361,13 +363,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
   return Status::OK();
 }
 
+// Sets output `output_num` for `ctx` provided it is known at a compile time.
+static Status SetOutputForConstant(
+    OpKernelContext* ctx, se::Stream* stream,
+    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
+  CHECK(compilation_result->outputs[output_num].is_constant);
+  // Output is a constant.
+  const Tensor& const_tensor =
+      compilation_result->outputs[output_num].constant_value;
+  Tensor* output_tensor;
+  const size_t total_bytes = const_tensor.TotalBytes();
+  if (stream && total_bytes > 0) {
+    // Copy host -> device. (Empty tensors don't have backing buffers.)
+    // Manually allocate memory using an XlaTensorBuffer so we can allocate
+    // as much memory as the device requires (as given by
+    // GetByteSizeRequirement). This avoids XlaTransferManager having to
+    // reallocate the device buffer later.
+    VLOG(1) << "Constant output tensor on device";
+
+    TF_RETURN_IF_ERROR(
+        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
+
+    Device* device = dynamic_cast<Device*>(ctx->device());
+    if (device == nullptr) {
+      return errors::Internal("DeviceBase was not a Device.");
+    }
+    ctx->op_device_context()->CopyCPUTensorToDevice(
+        &const_tensor, device, output_tensor,
+        [&](Status status) { TF_CHECK_OK(status); });
+
+    if (device->device_type() == DEVICE_GPU) {
+      // The GPUDeviceContext enqueues the host->device transfer in a
+      // separate stream from the main compute stream. We must ensure the
+      // compute stream is synchronized with the host->device transfer
+      // stream now otherwise we will create a race condition.
+      auto* gpu_device_context =
+          static_cast<GPUDeviceContext*>(ctx->op_device_context());
+      gpu_device_context->stream()->ThenWaitFor(
+          gpu_device_context->host_to_device_stream());
+    }
+  } else {
+    // No copy required.
+    ctx->set_output(output_num, const_tensor);
+    output_tensor = ctx->mutable_output(output_num);
+  }
+  if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
+    xla_tensor->set_host_tensor(const_tensor);
+  }
+  return Status::OK();
+}
+
+// Creates a list of updates resource variables.
+static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
+    int missing_ctx_input_prefix) {
+  std::vector<VariableInfo> variable_infos;
+  variable_infos.reserve(compilation_result->resource_updates.size());
+
+  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
+    const XlaCompiler::ResourceUpdate& write =
+        compilation_result->resource_updates[i];
+    int actual_input_index = write.input_index - missing_ctx_input_prefix;
+    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
+      return errors::Internal("Invalid input index for variable write.");
+    }
+
+    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
+    // not a Tensor.
+    Var* variable = nullptr;
+    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
+        ctx, HandleFromInput(ctx, actual_input_index), &variable,
+        [&write](Var** ptr) {
+          *ptr = new Var(write.type);
+          return Status::OK();
+        }));
+    variable_infos.emplace_back(actual_input_index, variable);
+  }
+  return variable_infos;
+}
+
 Status XlaComputationLaunchContext::PopulateOutputs(
-    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
     ScopedShapedBuffer output, int missing_ctx_input_prefix,
     const xla::HloInputOutputAliasConfig& input_output_alias,
     const std::map<int, OptionalTensor>& resource_var_snapshots) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+  Allocator* allocator = ctx->device()->GetAllocator({});
 
   // Computation output should always be a tuple.
   if (VLOG_IS_ON(2)) {
@@ -375,7 +458,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     VLOG(2) << "Result tuple shape (on device): "
             << output.on_device_shape().DebugString();
   }
-  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
+  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
 
   // If the on-host-shape isn't a tuple, create a new single-element tuple
   // buffer with a nullptr root index table. This allows the code below to treat
@@ -404,82 +487,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
   // Copy XLA results to the OpOutputList.
   int output_num = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
-    Allocator* allocator = ctx->device()->GetAllocator({});
-    if (kernel->outputs[i].is_constant) {
-      // Output is a constant.
-      const Tensor& const_tensor = kernel->outputs[i].constant_value;
-      Tensor* output_tensor;
-      const size_t total_bytes = const_tensor.TotalBytes();
-      if (stream && total_bytes > 0) {
-        // Copy host -> device. (Empty tensors don't have backing buffers.)
-        // Manually allocate memory using an XlaTensorBuffer so we can allocate
-        // as much memory as the device requires (as given by
-        // GetByteSizeRequirement). This avoids XlaTransferManager having to
-        // reallocate the device buffer later.
-        VLOG(1) << "Constant output tensor on device";
-
-        TF_RETURN_IF_ERROR(
-            ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
-
-        Device* device = dynamic_cast<Device*>(ctx->device());
-        if (device == nullptr) {
-          return errors::Internal("DeviceBase was not a Device.");
-        }
-        ctx->op_device_context()->CopyCPUTensorToDevice(
-            &const_tensor, device, output_tensor,
-            [&](Status status) { TF_CHECK_OK(status); });
-
-        if (device->device_type() == DEVICE_GPU) {
-          // The GPUDeviceContext enqueues the host->device transfer in a
-          // separate stream from the main compute stream. We must ensure the
-          // compute stream is synchronized with the host->device transfer
-          // stream now otherwise we will create a race condition.
-          auto* gpu_device_context =
-              static_cast<GPUDeviceContext*>(ctx->op_device_context());
-          gpu_device_context->stream()->ThenWaitFor(
-              gpu_device_context->host_to_device_stream());
-        }
-      } else {
-        // No copy required.
-        ctx->set_output(i, const_tensor);
-        output_tensor = ctx->mutable_output(i);
-      }
-      if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
-        xla_tensor->set_host_tensor(const_tensor);
-      }
-    } else {
-      const TensorShape& shape = kernel->outputs[i].shape;
-      const DataType& type = kernel->outputs[i].type;
-      VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
-              << DataTypeString(type);
-      if (type == DT_RESOURCE) {
-        int input_index =
-            kernel->outputs[i].input_index - missing_ctx_input_prefix;
-        TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
-            << "Invalid input for outputs " << i << ": " << input_index;
-        ctx->set_output(i, ctx->input(input_index));
-      } else {
-        if (allocate_xla_tensors_) {
-          TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
-              input_output_alias, output_num, ctx, i, shape, &output,
-              definition_event, stream, use_multiple_streams_));
-        } else {
-          if (type == DT_VARIANT) {
-            return errors::Unimplemented(
-                "Support for TensorList crossing the XLA/TF boundary "
-                "is not implemented");
-          }
-
-          se::DeviceMemoryBase buffer = output.buffer({output_num});
-          Tensor output_tensor = GetOrCreateTensorForOutput(
-              output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-              kernel->input_mapping, resource_var_snapshots,
-              ctx->expected_output_dtype(i), shape, buffer, allocator);
-          output.set_buffer(se::OwningDeviceMemory(), {output_num});
-          ctx->set_output(i, output_tensor);
-        }
-        ++output_num;
-      }
-    }
+    const TensorShape& shape = compilation_result->outputs[i].shape;
+    const DataType& type = compilation_result->outputs[i].type;
+    VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
+            << DataTypeString(type);
+    if (type == DT_VARIANT) {
+      return errors::Unimplemented(
+          "Support for TensorList crossing the XLA/TF boundary "
+          "is not implemented");
+    }
+
+    if (compilation_result->outputs[i].is_constant) {
+      TF_RETURN_IF_ERROR(
+          SetOutputForConstant(ctx, stream, compilation_result, i));
+    } else if (type == DT_RESOURCE) {
+      int input_index =
+          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
+      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
+          << "Invalid input for outputs " << i << ": " << input_index;
+      ctx->set_output(i, ctx->input(input_index));
+    } else {
+      if (allocate_xla_tensors_) {
+        TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
+            input_output_alias, output_num, ctx, i, shape, &output,
+            definition_event, stream, use_multiple_streams_));
+      } else {
+        se::DeviceMemoryBase buffer = output.buffer({output_num});
+        Tensor output_tensor = GetOrCreateTensorForOutput(
+            output_num, ctx, missing_ctx_input_prefix, input_output_alias,
+            compilation_result->input_mapping, resource_var_snapshots,
+            ctx->expected_output_dtype(i), shape, buffer, allocator);
+        output.set_buffer(se::OwningDeviceMemory(), {output_num});
+        ctx->set_output(i, output_tensor);
+      }
+      ++output_num;
+    }
   }
 
   if (VLOG_IS_ON(3)) {
@@ -489,34 +531,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
 
   // Apply variable updates, if any.
   VLOG(2) << "Applying variable updates";
-  std::vector<VariableInfo> variable_infos;
-  variable_infos.reserve(kernel->resource_updates.size());
-
-  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
-    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
-    int actual_input_index = write.input_index - missing_ctx_input_prefix;
-    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
-      return errors::Internal("Invalid input index for variable write.");
-    }
-
-    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
-    // not a Tensor.
-    Var* variable = nullptr;
-    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
-        ctx, HandleFromInput(ctx, actual_input_index), &variable,
-        [&write](Var** ptr) {
-          *ptr = new Var(write.type);
-          return Status::OK();
-        }));
-    variable_infos.emplace_back(actual_input_index, variable);
-  }
-
+  TF_ASSIGN_OR_RETURN(
+      std::vector<VariableInfo> variable_infos,
+      GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
 
-  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
-    Allocator* allocator = ctx->device()->GetAllocator({});
-    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
+  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
+    const XlaCompiler::ResourceUpdate& write =
+        compilation_result->resource_updates[i];
     if (variable_infos[i].var()->tensor()->dtype() != write.type) {
       return errors::Internal("Mismatched type in variable write");
     }
@@ -530,7 +552,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     output.set_buffer(se::OwningDeviceMemory(), {output_num});
     Tensor output_tensor = GetOrCreateTensorForOutput(
         output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-        kernel->input_mapping, resource_var_snapshots, write.type,
+        compilation_result->input_mapping, resource_var_snapshots, write.type,
         write.shape, buffer, allocator);
     *variable_infos[i].var()->tensor() = output_tensor;
     variable_infos[i].var()->is_initialized |= write.modified;


@@ -136,7 +136,7 @@ class XlaComputationLaunchContext {
   // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
   // (in other words, no inputs actually required by the kernel can be missing).
   void PopulateInputs(OpKernelContext* ctx,
-                      const XlaCompiler::CompilationResult* kernel,
+                      const XlaCompiler::CompilationResult* compilation_result,
                       const std::map<int, OptionalTensor>& variables,
                       int missing_ctx_input_prefix);
@@ -148,10 +148,11 @@ class XlaComputationLaunchContext {
   // See jit/resource_operation_safety_analysis for details.
   //
   //
-  // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
-  // missing and adjusts input indices accordingly.
+  // Assumes that the first `missing_ctx_input_prefix` inputs to the
+  // compilation_result are missing and adjusts input indices accordingly.
   Status PopulateOutputs(
-      OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+      OpKernelContext* ctx,
+      const XlaCompiler::CompilationResult* compilation_result,
       xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
       const xla::HloInputOutputAliasConfig& input_output_alias,
       const std::map<int, OptionalTensor>& resource_var_snapshots);