[TF/XLA] [NFC] Simplify XlaComputationLaunchContext::PopulateOutputs
Reduce the nesting level, extract a function for gathering VariableInfo.

PiperOrigin-RevId: 316720004
Change-Id: I49982058d9f7efbc2dcbb2b180c1fc95193cfa39
commit 2d50164bdb
parent 421e64c0c6
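As a minimal, self-contained sketch of the refactoring pattern this change applies (the types and names below are simplified stand-ins, not the real TensorFlow/XLA classes or macros): the inline loop that validated indices and collected VariableInfo objects is pulled out into a helper that returns either an error status or the collected vector, so the caller unpacks the result in one statement instead of carrying the loop's nesting.

#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Simplified stand-ins for Status / StatusOr / ResourceUpdate / VariableInfo;
// illustrative only, not the actual TensorFlow types.
struct Status {
  std::string message;  // empty means OK
};

template <typename T>
using StatusOr = std::variant<Status, T>;

struct ResourceUpdate {
  int input_index;
};

struct VariableInfo {
  int input_index;
};

// The extracted helper: gather and validate in one place and return the whole
// result (or an error) instead of mutating state inside the caller's loop.
StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    const std::vector<ResourceUpdate>& resource_updates,
    int missing_ctx_input_prefix, int num_inputs) {
  std::vector<VariableInfo> variable_infos;
  variable_infos.reserve(resource_updates.size());
  for (const ResourceUpdate& write : resource_updates) {
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= num_inputs) {
      return Status{"Invalid input index for variable write."};
    }
    variable_infos.push_back(VariableInfo{actual_input_index});
  }
  return variable_infos;
}

int main() {
  // The caller unpacks the result in one step, mirroring TF_ASSIGN_OR_RETURN.
  auto result = GatherVariableInfo({{2}, {3}}, /*missing_ctx_input_prefix=*/1,
                                   /*num_inputs=*/4);
  if (auto* infos = std::get_if<std::vector<VariableInfo>>(&result)) {
    std::cout << "Gathered " << infos->size() << " variables\n";
  } else {
    std::cout << "Error: " << std::get<Status>(result).message << "\n";
  }
}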
@@ -195,18 +195,20 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
 }
 
 void XlaComputationLaunchContext::PopulateInputs(
-    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
     const std::map<int, OptionalTensor>& variables,
     int missing_ctx_input_prefix) {
   // Build ShapedBuffers that point directly to the Tensor buffers.
-  arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
+  arg_ptrs_ =
+      std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
 
   xla::TransferManager* transfer_manager =
       client_->backend().transfer_manager();
-  for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
-    int arg_num = kernel->input_mapping[i];
+  for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
+    int arg_num = compilation_result->input_mapping[i];
     CHECK_GE(arg_num, missing_ctx_input_prefix);
-    const xla::Shape& shape = kernel->xla_input_shapes[i];
+    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
     const Tensor* t = variables.count(arg_num)
                           ? &(variables.at(arg_num).value)
                           : &(ctx->input(arg_num - missing_ctx_input_prefix));
@@ -361,13 +363,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
   return Status::OK();
 }
 
+// Sets output `output_num` for `ctx` provided it is known at a compile time.
+static Status SetOutputForConstant(
+    OpKernelContext* ctx, se::Stream* stream,
+    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
+  CHECK(compilation_result->outputs[output_num].is_constant);
+  // Output is a constant.
+  const Tensor& const_tensor =
+      compilation_result->outputs[output_num].constant_value;
+  Tensor* output_tensor;
+  const size_t total_bytes = const_tensor.TotalBytes();
+  if (stream && total_bytes > 0) {
+    // Copy host -> device. (Empty tensors don't have backing buffers.)
+    // Manually allocate memory using an XlaTensorBuffer so we can allocate
+    // as much memory as the device requires (as given by
+    // GetByteSizeRequirement). This avoids XlaTransferManager having to
+    // reallocate the device buffer later.
+    VLOG(1) << "Constant output tensor on device";
+
+    TF_RETURN_IF_ERROR(
+        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
+    Device* device = dynamic_cast<Device*>(ctx->device());
+    if (device == nullptr) {
+      return errors::Internal("DeviceBase was not a Device.");
+    }
+    ctx->op_device_context()->CopyCPUTensorToDevice(
+        &const_tensor, device, output_tensor,
+        [&](Status status) { TF_CHECK_OK(status); });
+
+    if (device->device_type() == DEVICE_GPU) {
+      // The GPUDeviceContext enqueues the host->device transfer in a
+      // separate stream from the main compute stream. We must ensure the
+      // compute stream is synchronized with the host->device transfer
+      // stream now otherwise we will create a race condition.
+      auto* gpu_device_context =
+          static_cast<GPUDeviceContext*>(ctx->op_device_context());
+      gpu_device_context->stream()->ThenWaitFor(
+          gpu_device_context->host_to_device_stream());
+    }
+  } else {
+    // No copy required.
+    ctx->set_output(output_num, const_tensor);
+    output_tensor = ctx->mutable_output(output_num);
+  }
+  if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
+    xla_tensor->set_host_tensor(const_tensor);
+  }
+  return Status::OK();
+}
+
+// Creates a list of updates resource variables.
+static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
+    int missing_ctx_input_prefix) {
+  std::vector<VariableInfo> variable_infos;
+  variable_infos.reserve(compilation_result->resource_updates.size());
+
+  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
+    const XlaCompiler::ResourceUpdate& write =
+        compilation_result->resource_updates[i];
+    int actual_input_index = write.input_index - missing_ctx_input_prefix;
+    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
+      return errors::Internal("Invalid input index for variable write.");
+    }
+
+    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
+    // not a Tensor.
+    Var* variable = nullptr;
+    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
+        ctx, HandleFromInput(ctx, actual_input_index), &variable,
+        [&write](Var** ptr) {
+          *ptr = new Var(write.type);
+          return Status::OK();
+        }));
+    variable_infos.emplace_back(actual_input_index, variable);
+  }
+  return variable_infos;
+}
+
 Status XlaComputationLaunchContext::PopulateOutputs(
-    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult* compilation_result,
     ScopedShapedBuffer output, int missing_ctx_input_prefix,
     const xla::HloInputOutputAliasConfig& input_output_alias,
     const std::map<int, OptionalTensor>& resource_var_snapshots) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+  Allocator* allocator = ctx->device()->GetAllocator({});
 
   // Computation output should always be a tuple.
   if (VLOG_IS_ON(2)) {
@@ -375,7 +458,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     VLOG(2) << "Result tuple shape (on device): "
             << output.on_device_shape().DebugString();
   }
-  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
+  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
 
   // If the on-host-shape isn't a tuple, create a new single-element tuple
   // buffer with a nullptr root index table. This allows the code below to treat
@@ -404,82 +487,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
   // Copy XLA results to the OpOutputList.
   int output_num = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
-    Allocator* allocator = ctx->device()->GetAllocator({});
-    if (kernel->outputs[i].is_constant) {
-      // Output is a constant.
-      const Tensor& const_tensor = kernel->outputs[i].constant_value;
-      Tensor* output_tensor;
-      const size_t total_bytes = const_tensor.TotalBytes();
-      if (stream && total_bytes > 0) {
-        // Copy host -> device. (Empty tensors don't have backing buffers.)
-        // Manually allocate memory using an XlaTensorBuffer so we can allocate
-        // as much memory as the device requires (as given by
-        // GetByteSizeRequirement). This avoids XlaTransferManager having to
-        // reallocate the device buffer later.
-        VLOG(1) << "Constant output tensor on device";
-
-        TF_RETURN_IF_ERROR(
-            ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
-        Device* device = dynamic_cast<Device*>(ctx->device());
-        if (device == nullptr) {
-          return errors::Internal("DeviceBase was not a Device.");
-        }
-        ctx->op_device_context()->CopyCPUTensorToDevice(
-            &const_tensor, device, output_tensor,
-            [&](Status status) { TF_CHECK_OK(status); });
-
-        if (device->device_type() == DEVICE_GPU) {
-          // The GPUDeviceContext enqueues the host->device transfer in a
-          // separate stream from the main compute stream. We must ensure the
-          // compute stream is synchronized with the host->device transfer
-          // stream now otherwise we will create a race condition.
-          auto* gpu_device_context =
-              static_cast<GPUDeviceContext*>(ctx->op_device_context());
-          gpu_device_context->stream()->ThenWaitFor(
-              gpu_device_context->host_to_device_stream());
-        }
-      } else {
-        // No copy required.
-        ctx->set_output(i, const_tensor);
-        output_tensor = ctx->mutable_output(i);
-      }
-      if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
-        xla_tensor->set_host_tensor(const_tensor);
-      }
-    } else {
-      const TensorShape& shape = kernel->outputs[i].shape;
-      const DataType& type = kernel->outputs[i].type;
-      VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
-              << DataTypeString(type);
-      if (type == DT_RESOURCE) {
-        int input_index =
-            kernel->outputs[i].input_index - missing_ctx_input_prefix;
-        TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
-            << "Invalid input for outputs " << i << ": " << input_index;
-        ctx->set_output(i, ctx->input(input_index));
-      } else {
-        if (allocate_xla_tensors_) {
-          TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
-              input_output_alias, output_num, ctx, i, shape, &output,
-              definition_event, stream, use_multiple_streams_));
-        } else {
-          if (type == DT_VARIANT) {
-            return errors::Unimplemented(
-                "Support for TensorList crossing the XLA/TF boundary "
-                "is not implemented");
-          }
-
-          se::DeviceMemoryBase buffer = output.buffer({output_num});
-          Tensor output_tensor = GetOrCreateTensorForOutput(
-              output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-              kernel->input_mapping, resource_var_snapshots,
-              ctx->expected_output_dtype(i), shape, buffer, allocator);
-          output.set_buffer(se::OwningDeviceMemory(), {output_num});
-          ctx->set_output(i, output_tensor);
-        }
-        ++output_num;
-      }
-    }
-  }
+    const TensorShape& shape = compilation_result->outputs[i].shape;
+    const DataType& type = compilation_result->outputs[i].type;
+    VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
+            << DataTypeString(type);
+    if (type == DT_VARIANT) {
+      return errors::Unimplemented(
+          "Support for TensorList crossing the XLA/TF boundary "
+          "is not implemented");
+    }
+
+    if (compilation_result->outputs[i].is_constant) {
+      TF_RETURN_IF_ERROR(
+          SetOutputForConstant(ctx, stream, compilation_result, i));
+    } else if (type == DT_RESOURCE) {
+      int input_index =
+          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
+      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
+          << "Invalid input for outputs " << i << ": " << input_index;
+      ctx->set_output(i, ctx->input(input_index));
+    } else {
+      if (allocate_xla_tensors_) {
+        TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
+            input_output_alias, output_num, ctx, i, shape, &output,
+            definition_event, stream, use_multiple_streams_));
+      } else {
+        se::DeviceMemoryBase buffer = output.buffer({output_num});
+        Tensor output_tensor = GetOrCreateTensorForOutput(
+            output_num, ctx, missing_ctx_input_prefix, input_output_alias,
+            compilation_result->input_mapping, resource_var_snapshots,
+            ctx->expected_output_dtype(i), shape, buffer, allocator);
+        output.set_buffer(se::OwningDeviceMemory(), {output_num});
+        ctx->set_output(i, output_tensor);
+      }
+      ++output_num;
+    }
+  }
 
   if (VLOG_IS_ON(3)) {
@@ -489,34 +531,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
 
   // Apply variable updates, if any.
   VLOG(2) << "Applying variable updates";
-  std::vector<VariableInfo> variable_infos;
-  variable_infos.reserve(kernel->resource_updates.size());
-
-  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
-    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
-    int actual_input_index = write.input_index - missing_ctx_input_prefix;
-    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
-      return errors::Internal("Invalid input index for variable write.");
-    }
-
-    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
-    // not a Tensor.
-    Var* variable = nullptr;
-    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
-        ctx, HandleFromInput(ctx, actual_input_index), &variable,
-        [&write](Var** ptr) {
-          *ptr = new Var(write.type);
-          return Status::OK();
-        }));
-    variable_infos.emplace_back(actual_input_index, variable);
-  }
-
+  TF_ASSIGN_OR_RETURN(
+      std::vector<VariableInfo> variable_infos,
+      GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
 
-  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
-    Allocator* allocator = ctx->device()->GetAllocator({});
-    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
-
+  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
+    const XlaCompiler::ResourceUpdate& write =
+        compilation_result->resource_updates[i];
     if (variable_infos[i].var()->tensor()->dtype() != write.type) {
       return errors::Internal("Mismatched type in variable write");
     }
@@ -530,7 +552,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     output.set_buffer(se::OwningDeviceMemory(), {output_num});
     Tensor output_tensor = GetOrCreateTensorForOutput(
         output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-        kernel->input_mapping, resource_var_snapshots, write.type,
+        compilation_result->input_mapping, resource_var_snapshots, write.type,
        write.shape, buffer, allocator);
     *variable_infos[i].var()->tensor() = output_tensor;
     variable_infos[i].var()->is_initialized |= write.modified;
@@ -136,7 +136,7 @@ class XlaComputationLaunchContext {
   // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
   // (in other words, no inputs actually required by the kernel can be missing).
   void PopulateInputs(OpKernelContext* ctx,
-                      const XlaCompiler::CompilationResult* kernel,
+                      const XlaCompiler::CompilationResult* compilation_result,
                       const std::map<int, OptionalTensor>& variables,
                       int missing_ctx_input_prefix);
 
@@ -148,10 +148,11 @@ class XlaComputationLaunchContext {
   // See jit/resource_operation_safety_analysis for details.
   //
   //
-  // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
-  // missing and adjusts input indices accordingly.
+  // Assumes that the first `missing_ctx_input_prefix` inputs to the
+  // compilation_result are missing and adjusts input indices accordingly.
   Status PopulateOutputs(
-      OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+      OpKernelContext* ctx,
+      const XlaCompiler::CompilationResult* compilation_result,
       xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
       const xla::HloInputOutputAliasConfig& input_output_alias,
       const std::map<int, OptionalTensor>& resource_var_snapshots);