[TF/XLA] [NFC] Simplify XlaComputationLaunchContext::PopulateOutputs
Reduce the nesting level, extract a function for gathering VariableInfo. PiperOrigin-RevId: 316720004 Change-Id: I49982058d9f7efbc2dcbb2b180c1fc95193cfa39
This commit is contained in:
parent
421e64c0c6
commit
2d50164bdb
@ -195,18 +195,20 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
}
|
||||
|
||||
void XlaComputationLaunchContext::PopulateInputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix) {
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
|
||||
arg_ptrs_ =
|
||||
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
|
||||
|
||||
xla::TransferManager* transfer_manager =
|
||||
client_->backend().transfer_manager();
|
||||
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
|
||||
int arg_num = kernel->input_mapping[i];
|
||||
for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
|
||||
int arg_num = compilation_result->input_mapping[i];
|
||||
CHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = kernel->xla_input_shapes[i];
|
||||
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
|
||||
const Tensor* t = variables.count(arg_num)
|
||||
? &(variables.at(arg_num).value)
|
||||
: &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
@ -361,13 +363,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Sets output `output_num` for `ctx` provided it is known at a compile time.
|
||||
static Status SetOutputForConstant(
|
||||
OpKernelContext* ctx, se::Stream* stream,
|
||||
const XlaCompiler::CompilationResult* compilation_result, int output_num) {
|
||||
CHECK(compilation_result->outputs[output_num].is_constant);
|
||||
// Output is a constant.
|
||||
const Tensor& const_tensor =
|
||||
compilation_result->outputs[output_num].constant_value;
|
||||
Tensor* output_tensor;
|
||||
const size_t total_bytes = const_tensor.TotalBytes();
|
||||
if (stream && total_bytes > 0) {
|
||||
// Copy host -> device. (Empty tensors don't have backing buffers.)
|
||||
// Manually allocate memory using an XlaTensorBuffer so we can allocate
|
||||
// as much memory as the device requires (as given by
|
||||
// GetByteSizeRequirement). This avoids XlaTransferManager having to
|
||||
// reallocate the device buffer later.
|
||||
VLOG(1) << "Constant output tensor on device";
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
|
||||
Device* device = dynamic_cast<Device*>(ctx->device());
|
||||
if (device == nullptr) {
|
||||
return errors::Internal("DeviceBase was not a Device.");
|
||||
}
|
||||
ctx->op_device_context()->CopyCPUTensorToDevice(
|
||||
&const_tensor, device, output_tensor,
|
||||
[&](Status status) { TF_CHECK_OK(status); });
|
||||
|
||||
if (device->device_type() == DEVICE_GPU) {
|
||||
// The GPUDeviceContext enqueues the host->device transfer in a
|
||||
// separate stream from the main compute stream. We must ensure the
|
||||
// compute stream is synchronized with the host->device transfer
|
||||
// stream now otherwise we will create a race condition.
|
||||
auto* gpu_device_context =
|
||||
static_cast<GPUDeviceContext*>(ctx->op_device_context());
|
||||
gpu_device_context->stream()->ThenWaitFor(
|
||||
gpu_device_context->host_to_device_stream());
|
||||
}
|
||||
} else {
|
||||
// No copy required.
|
||||
ctx->set_output(output_num, const_tensor);
|
||||
output_tensor = ctx->mutable_output(output_num);
|
||||
}
|
||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
||||
xla_tensor->set_host_tensor(const_tensor);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a list of updates resource variables.
|
||||
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
int missing_ctx_input_prefix) {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
variable_infos.reserve(compilation_result->resource_updates.size());
|
||||
|
||||
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write =
|
||||
compilation_result->resource_updates[i];
|
||||
int actual_input_index = write.input_index - missing_ctx_input_prefix;
|
||||
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
|
||||
return errors::Internal("Invalid input index for variable write.");
|
||||
}
|
||||
|
||||
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
|
||||
// not a Tensor.
|
||||
Var* variable = nullptr;
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
|
||||
ctx, HandleFromInput(ctx, actual_input_index), &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, variable);
|
||||
}
|
||||
return variable_infos;
|
||||
}
|
||||
|
||||
Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
|
||||
// Computation output should always be a tuple.
|
||||
if (VLOG_IS_ON(2)) {
|
||||
@ -375,7 +458,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
VLOG(2) << "Result tuple shape (on device): "
|
||||
<< output.on_device_shape().DebugString();
|
||||
}
|
||||
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
|
||||
CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
|
||||
|
||||
// If the on-host-shape isn't a tuple, create a new single-element tuple
|
||||
// buffer with a nullptr root index table. This allows the code below to treat
|
||||
@ -404,82 +487,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
// Copy XLA results to the OpOutputList.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
if (kernel->outputs[i].is_constant) {
|
||||
// Output is a constant.
|
||||
const Tensor& const_tensor = kernel->outputs[i].constant_value;
|
||||
Tensor* output_tensor;
|
||||
const size_t total_bytes = const_tensor.TotalBytes();
|
||||
if (stream && total_bytes > 0) {
|
||||
// Copy host -> device. (Empty tensors don't have backing buffers.)
|
||||
// Manually allocate memory using an XlaTensorBuffer so we can allocate
|
||||
// as much memory as the device requires (as given by
|
||||
// GetByteSizeRequirement). This avoids XlaTransferManager having to
|
||||
// reallocate the device buffer later.
|
||||
VLOG(1) << "Constant output tensor on device";
|
||||
const TensorShape& shape = compilation_result->outputs[i].shape;
|
||||
const DataType& type = compilation_result->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
if (type == DT_VARIANT) {
|
||||
return errors::Unimplemented(
|
||||
"Support for TensorList crossing the XLA/TF boundary "
|
||||
"is not implemented");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
|
||||
|
||||
Device* device = dynamic_cast<Device*>(ctx->device());
|
||||
if (device == nullptr) {
|
||||
return errors::Internal("DeviceBase was not a Device.");
|
||||
}
|
||||
ctx->op_device_context()->CopyCPUTensorToDevice(
|
||||
&const_tensor, device, output_tensor,
|
||||
[&](Status status) { TF_CHECK_OK(status); });
|
||||
|
||||
if (device->device_type() == DEVICE_GPU) {
|
||||
// The GPUDeviceContext enqueues the host->device transfer in a
|
||||
// separate stream from the main compute stream. We must ensure the
|
||||
// compute stream is synchronized with the host->device transfer
|
||||
// stream now otherwise we will create a race condition.
|
||||
auto* gpu_device_context =
|
||||
static_cast<GPUDeviceContext*>(ctx->op_device_context());
|
||||
gpu_device_context->stream()->ThenWaitFor(
|
||||
gpu_device_context->host_to_device_stream());
|
||||
}
|
||||
} else {
|
||||
// No copy required.
|
||||
ctx->set_output(i, const_tensor);
|
||||
output_tensor = ctx->mutable_output(i);
|
||||
}
|
||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
||||
xla_tensor->set_host_tensor(const_tensor);
|
||||
}
|
||||
if (compilation_result->outputs[i].is_constant) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetOutputForConstant(ctx, stream, compilation_result, i));
|
||||
} else if (type == DT_RESOURCE) {
|
||||
int input_index =
|
||||
compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
|
||||
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
const TensorShape& shape = kernel->outputs[i].shape;
|
||||
const DataType& type = kernel->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
if (type == DT_RESOURCE) {
|
||||
int input_index =
|
||||
kernel->outputs[i].input_index - missing_ctx_input_prefix;
|
||||
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
if (allocate_xla_tensors_) {
|
||||
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
|
||||
input_output_alias, output_num, ctx, i, shape, &output,
|
||||
definition_event, stream, use_multiple_streams_));
|
||||
} else {
|
||||
if (type == DT_VARIANT) {
|
||||
return errors::Unimplemented(
|
||||
"Support for TensorList crossing the XLA/TF boundary "
|
||||
"is not implemented");
|
||||
}
|
||||
if (allocate_xla_tensors_) {
|
||||
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
|
||||
input_output_alias, output_num, ctx, i, shape, &output,
|
||||
definition_event, stream, use_multiple_streams_));
|
||||
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots,
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
compilation_result->input_mapping, resource_var_snapshots,
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
@ -489,34 +531,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
|
||||
// Apply variable updates, if any.
|
||||
VLOG(2) << "Applying variable updates";
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
variable_infos.reserve(kernel->resource_updates.size());
|
||||
|
||||
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
|
||||
int actual_input_index = write.input_index - missing_ctx_input_prefix;
|
||||
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
|
||||
return errors::Internal("Invalid input index for variable write.");
|
||||
}
|
||||
|
||||
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
|
||||
// not a Tensor.
|
||||
Var* variable = nullptr;
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
|
||||
ctx, HandleFromInput(ctx, actual_input_index), &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, variable);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<VariableInfo> variable_infos,
|
||||
GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
|
||||
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
|
||||
|
||||
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write =
|
||||
compilation_result->resource_updates[i];
|
||||
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
|
||||
return errors::Internal("Mismatched type in variable write");
|
||||
}
|
||||
@ -530,7 +552,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots, write.type,
|
||||
compilation_result->input_mapping, resource_var_snapshots, write.type,
|
||||
write.shape, buffer, allocator);
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
variable_infos[i].var()->is_initialized |= write.modified;
|
||||
|
@ -136,7 +136,7 @@ class XlaComputationLaunchContext {
|
||||
// input_mapping must be greater than or equal to `missing_ctx_input_prefix`
|
||||
// (in other words, no inputs actually required by the kernel can be missing).
|
||||
void PopulateInputs(OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* kernel,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix);
|
||||
|
||||
@ -148,10 +148,11 @@ class XlaComputationLaunchContext {
|
||||
// See jit/resource_operation_safety_analysis for details.
|
||||
//
|
||||
//
|
||||
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
|
||||
// missing and adjusts input indices accordingly.
|
||||
// Assumes that the first `missing_ctx_input_prefix` inputs to the
|
||||
// compilation_result are missing and adjusts input indices accordingly.
|
||||
Status PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots);
|
||||
|
Loading…
Reference in New Issue
Block a user