[XLA:GPU] Support aliasing params in XlaComputationLaunchContext::PopulateOutputs.
This is in preparation for eliding copies of "pass-through" params.
PiperOrigin-RevId: 265884708
This commit is contained in:
parent
4710759932
commit
ed1fb5717f
@ -397,9 +397,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
auto elapsed = env->NowMicros() - start_time;
|
||||
VLOG(2) << "Elapsed time: " << elapsed << "us";
|
||||
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias =
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
|
||||
ctx, kernel, run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/0));
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias));
|
||||
VLOG(1) << "Done";
|
||||
}
|
||||
|
||||
@ -595,6 +597,9 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
auto elapsed = env->NowMicros() - start_time;
|
||||
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
|
||||
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias =
|
||||
closure.executable()->executable()->module().input_output_alias_config();
|
||||
|
||||
tensorflow::profiler::TraceMe hlo_module_activity(
|
||||
[&] {
|
||||
return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
|
||||
@ -605,7 +610,8 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
ctx,
|
||||
launch_context.PopulateOutputs(
|
||||
ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/closure.num_constant_args()));
|
||||
/*missing_ctx_input_prefix=*/closure.num_constant_args(),
|
||||
input_output_alias));
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
|
||||
|
||||
@ -83,9 +83,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
executable->Run(launch_context.arguments(), run_options);
|
||||
TF_RETURN_IF_ERROR(run_result.status());
|
||||
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias =
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
|
||||
ctx, result, run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/0));
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -247,9 +247,32 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Reports whether output `output_num` is recorded in `input_output_alias` as
// aliasing an input parameter with alias kind kUserAlias.
//
// For a tuple-shaped result the output is looked up at shape index
// {output_num}. For any other result shape the empty (root) index is used and
// `output_num` is expected to be 0, which is enforced with DCHECK_EQ. A result
// shape that has no tuple elements always yields false.
bool MustAliasOutput(const xla::HloInputOutputAliasConfig& input_output_alias,
                     int output_num) {
  const auto& result_shape = input_output_alias.shape();
  const bool result_is_tuple = result_shape.IsTuple();

  if (!result_is_tuple) {
    DCHECK_EQ(output_num, 0)
        << "output_num must be 0 for non-tuple shapes but is " << output_num;
  }

  // No tuple elements in the result shape: report "not aliased".
  if (result_shape.tuple_shapes_size() == 0) {
    return false;
  }

  // A default-constructed ShapeIndex is the empty index used for non-tuples.
  xla::ShapeIndex output_index;
  if (result_is_tuple) {
    output_index = {output_num};
  }

  if (!input_output_alias.OutputHasAlias(output_index)) {
    return false;
  }
  return input_output_alias.GetAliasedParameter(output_index).value().kind ==
         xla::HloInputOutputAliasConfig::kUserAlias;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix) {
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
||||
@ -343,8 +366,16 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
DCHECK(output.buffer({output_num}).is_null())
|
||||
<< "Expected output buffer to be aliased, but it is not nil.";
|
||||
}
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
if (allocate_xla_tensors_) {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
return errors::Unimplemented(
|
||||
"Aliasing is not yet supported for allocate_xla_tensors_.");
|
||||
}
|
||||
Tensor* output_tensor;
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
|
||||
@ -359,8 +390,18 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
CHECK_EQ(output_tensor->TotalBytes(), 0);
|
||||
}
|
||||
} else {
|
||||
bool is_aliased = false;
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
int xla_param = input_output_alias.GetAliasedParameter({output_num})
|
||||
.value()
|
||||
.parameter_number;
|
||||
DCHECK(arg_ptrs_[xla_param] != nullptr);
|
||||
buffer = arg_ptrs_[xla_param]->buffer({});
|
||||
is_aliased = true;
|
||||
}
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
ctx->expected_output_dtype(i), shape,
|
||||
/*unref_buffer=*/!is_aliased, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
@ -424,7 +465,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
write.type, write.shape, buffer, allocator);
|
||||
write.type, write.shape, /*unref_buffer=*/true, buffer, allocator);
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
}
|
||||
++output_num;
|
||||
|
||||
@ -149,10 +149,10 @@ class XlaComputationLaunchContext {
|
||||
//
|
||||
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
|
||||
// missing and adjusts input indices accordingly.
|
||||
Status PopulateOutputs(OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* kernel,
|
||||
xla::ScopedShapedBuffer output,
|
||||
int missing_ctx_input_prefix);
|
||||
Status PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias);
|
||||
|
||||
// Return the argument list. Only valid after PopulateInputs() has been
|
||||
// called.
|
||||
@ -193,12 +193,15 @@ class XlaTensorBuffer : public TensorBuffer {
|
||||
}
|
||||
|
||||
// Builds a Tensor of type `dtype` and shape `shape` backed by a newly
// allocated XlaTensorBuffer that wraps `buffer.opaque()` / `buffer.size()`
// and is given `allocator`.
//
// `unref_buffer`: when true, Unref() is called once on the new
// XlaTensorBuffer before returning; when false, that call is skipped.
// NOTE(review): presumably skipping it keeps an extra reference alive so an
// aliased buffer is never released through this tensor — confirm against the
// callers and XlaTensorBuffer's destructor.
//
// NOTE(review): this span is a rendered diff with the +/- markers stripped, so
// pre-change and post-change lines appear side by side; see the notes below.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
|
||||
// NOTE(review): the next line looks like the pre-change parameter list,
// superseded by the two parameter lines that follow it — confirm.
se::DeviceMemoryBase buffer, Allocator* allocator) {
|
||||
bool unref_buffer, se::DeviceMemoryBase buffer,
|
||||
Allocator* allocator) {
|
||||
// Byte size implied by the element count and the element type.
size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
|
||||
auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
|
||||
buffer.size(), allocator);
|
||||
// NOTE(review): presumably Tensor takes its own reference on
// `tensor_buffer` here — TODO confirm.
Tensor t(dtype, shape, tensor_buffer);
|
||||
// NOTE(review): the unconditional Unref() on the next line looks like the
// pre-change statement, replaced by the conditional form below — confirm.
tensor_buffer->Unref();
|
||||
// Drop this function's reference only when the caller asked for it.
if (unref_buffer) {
|
||||
tensor_buffer->Unref();
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
|
||||
@ -103,9 +103,13 @@ StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the shape of the output of the alias config, as held by `alias_`.
const Shape& HloInputOutputAliasConfig::shape() const {
  return alias_.shape();
}
|
||||
|
||||
string HloInputOutputAliasConfig::ToString() const {
|
||||
std::vector<string> pieces;
|
||||
pieces.push_back("HloInputOutputAliasConfig");
|
||||
pieces.push_back(
|
||||
absl::StrFormat(" Output shape: %s", alias_.shape().ToString()));
|
||||
|
||||
ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
|
||||
const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM";
|
||||
|
||||
@ -117,6 +117,9 @@ class HloInputOutputAliasConfig {
|
||||
|
||||
Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;
|
||||
|
||||
// Returns the shape of the output of the alias config.
|
||||
const Shape& shape() const;
|
||||
|
||||
string ToString() const;
|
||||
|
||||
private:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user