[XLA:GPU] Support aliasing params in XlaComputationLaunchContext::PopulateOutputs.

This is in preparation for eliding copies of "pass-through" params.

PiperOrigin-RevId: 265884708
This commit is contained in:
Thomas Joerg 2019-08-28 04:54:39 -07:00 committed by TensorFlower Gardener
parent 4710759932
commit ed1fb5717f
6 changed files with 71 additions and 12 deletions

View File

@@ -397,9 +397,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
ctx, kernel, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0));
/*missing_ctx_input_prefix=*/0, input_output_alias));
VLOG(1) << "Done";
}
@@ -595,6 +597,9 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
const xla::HloInputOutputAliasConfig& input_output_alias =
closure.executable()->executable()->module().input_output_alias_config();
tensorflow::profiler::TraceMe hlo_module_activity(
[&] {
return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
@@ -605,7 +610,8 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
ctx,
launch_context.PopulateOutputs(
ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/closure.num_constant_args()));
/*missing_ctx_input_prefix=*/closure.num_constant_args(),
input_output_alias));
}
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

View File

@@ -83,9 +83,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
ctx, result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0));
/*missing_ctx_input_prefix=*/0, input_output_alias));
return Status::OK();
}

View File

@@ -247,9 +247,32 @@ void XlaComputationLaunchContext::PopulateInputs(
}
}
namespace {

// Returns true iff the XLA output at position `output_num` carries a
// user-specified input/output alias, i.e. the runtime is required to alias
// that output buffer to one of the input parameters.
bool MustAliasOutput(const xla::HloInputOutputAliasConfig& input_output_alias,
                     int output_num) {
  const bool output_is_tuple = input_output_alias.shape().IsTuple();
  if (!output_is_tuple) {
    // A non-tuple output shape describes exactly one output.
    DCHECK_EQ(output_num, 0)
        << "output_num must be 0 for non-tuple shapes but is " << output_num;
  }
  const xla::ShapeIndex output_index =
      output_is_tuple ? xla::ShapeIndex{output_num} : xla::ShapeIndex{};
  // An empty tuple (or a non-tuple) shape carries no aliasing entries.
  if (input_output_alias.shape().tuple_shapes_size() == 0) {
    return false;
  }
  if (!input_output_alias.OutputHasAlias(output_index)) {
    return false;
  }
  return input_output_alias.GetAliasedParameter(output_index).value().kind ==
         xla::HloInputOutputAliasConfig::kUserAlias;
}

}  // namespace
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output, int missing_ctx_input_prefix) {
ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -343,8 +366,16 @@ Status XlaComputationLaunchContext::PopulateOutputs(
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
if (MustAliasOutput(input_output_alias, output_num)) {
DCHECK(output.buffer({output_num}).is_null())
<< "Expected output buffer to be aliased, but it is not nil.";
}
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
if (MustAliasOutput(input_output_alias, output_num)) {
return errors::Unimplemented(
"Aliasing is not yet supported for allocate_xla_tensors_.");
}
Tensor* output_tensor;
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
@@ -359,8 +390,18 @@ Status XlaComputationLaunchContext::PopulateOutputs(
CHECK_EQ(output_tensor->TotalBytes(), 0);
}
} else {
bool is_aliased = false;
if (MustAliasOutput(input_output_alias, output_num)) {
int xla_param = input_output_alias.GetAliasedParameter({output_num})
.value()
.parameter_number;
DCHECK(arg_ptrs_[xla_param] != nullptr);
buffer = arg_ptrs_[xla_param]->buffer({});
is_aliased = true;
}
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
ctx->expected_output_dtype(i), shape,
/*unref_buffer=*/!is_aliased, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
@@ -424,7 +465,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
output.set_buffer(se::OwningDeviceMemory(), {output_num});
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
write.type, write.shape, buffer, allocator);
write.type, write.shape, /*unref_buffer=*/true, buffer, allocator);
*variable_infos[i].var()->tensor() = output_tensor;
}
++output_num;

View File

@@ -149,10 +149,10 @@ class XlaComputationLaunchContext {
//
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
// missing and adjusts input indices accordingly.
Status PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
xla::ScopedShapedBuffer output,
int missing_ctx_input_prefix);
Status PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
@@ -193,12 +193,15 @@ class XlaTensorBuffer : public TensorBuffer {
}
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
se::DeviceMemoryBase buffer, Allocator* allocator) {
bool unref_buffer, se::DeviceMemoryBase buffer,
Allocator* allocator) {
size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
buffer.size(), allocator);
Tensor t(dtype, shape, tensor_buffer);
tensor_buffer->Unref();
if (unref_buffer) {
tensor_buffer->Unref();
}
return t;
}

View File

@@ -103,9 +103,13 @@ StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
return result;
}
const Shape& HloInputOutputAliasConfig::shape() const { return alias_.shape(); }
string HloInputOutputAliasConfig::ToString() const {
std::vector<string> pieces;
pieces.push_back("HloInputOutputAliasConfig");
pieces.push_back(
absl::StrFormat(" Output shape: %s", alias_.shape().ToString()));
ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM";

View File

@@ -117,6 +117,9 @@ class HloInputOutputAliasConfig {
Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;
// Returns the shape of the output of the alias config.
const Shape& shape() const;
string ToString() const;
private: