Rollback of rollback of [TF/XLA] Enable input/output aliasing in the TF2XLA bridge

The underlying bug was fixed

PiperOrigin-RevId: 321863222
Change-Id: I94c25f3243e33374ee089dd808c3f25704de2c92
George Karpenkov 2020-07-17 15:08:25 -07:00 committed by TensorFlower Gardener
parent 01d9e46f28
commit ec25e31817
6 changed files with 319 additions and 168 deletions
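In broad terms, this change re-lands input/output aliasing in the TF2XLA bridge: when a compiled cluster updates a resource variable, the variable's device buffer can be handed to XLA as a donatable input and the update written in place, instead of being routed through a ResourceVarsSnapshot copy. The compile options only request aliasing when it is safe to do so. A minimal, self-contained sketch of that gating predicate follows (illustrative names only, not TensorFlow source; the real code sets compile_options.alias_resource_update in CompileToLocalExecutable below):

// Sketch of the gating logic: aliasing is requested only when there are no
// ref-variables, the op is not running on an XLA device, and the caller has
// said it can tolerate aliased resource updates.
bool ShouldAliasResourceUpdates(bool has_ref_vars, bool is_on_xla_device,
                                bool may_alias_resource_update) {
  return !has_ref_vars && !is_on_xla_device && may_alias_resource_update;
}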

View File

@@ -277,7 +277,8 @@ static Status CompileToLocalExecutable(
     OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
     const XlaPlatformInfo& platform_info,
     absl::Span<VariableInfo const> variable_infos,
-    absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
+    absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
+    xla::LocalClient** client,
     const XlaCompiler::CompilationResult** compilation_result,
     xla::LocalExecutable** executable) {
   // We store information about the JIT-compiled XLA computation
@@ -332,6 +333,9 @@ static Status CompileToLocalExecutable(
   // Optimization: where possible, have the computation return a naked array
   // rather than a one-element tuple.
   compile_options.always_return_tuple = false;
+  compile_options.alias_resource_update = !has_ref_vars &&
+                                          !platform_info.is_on_xla_device() &&
+                                          may_alias_resource_update;
 
   std::vector<XlaCompiler::Argument> args;
   TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
@@ -350,20 +354,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   const XlaCompiler::CompilationResult* compilation_result;
   xla::LocalExecutable* executable;
-  ResourceVarsSnapshot variables_snapshot;
+  std::vector<VariableInfo> variable_infos;
   {
-    std::vector<VariableInfo> variable_infos;
     OP_REQUIRES_OK(
         ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
     OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
     Status s = CompileToLocalExecutable(
         ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
-        variable_infos, constants_, /*lazy=*/false, &client,
-        &compilation_result, &executable);
+        variable_infos, constants_, /*lazy=*/false,
+        /*may_alias_resource_update=*/true, &client, &compilation_result,
+        &executable);
     OP_REQUIRES_OK(ctx, s);
-    OP_REQUIRES_OK(ctx,
-                   SnapshotResourceVariables(ctx, resources_, variable_infos,
-                                             &variables_snapshot));
+  }
+
+  std::map<int, const Tensor*> resource_var_ptrs;
+  for (int i = 0; i < resources_.size(); i++) {
+    resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
   }
 
   se::Stream* stream =
@@ -374,12 +380,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   se::DeviceMemoryAllocator* allocator =
       GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
+  int device_ordinal = stream ? stream->parent()->device_ordinal()
+                              : client->default_device_ordinal();
   XlaComputationLaunchContext launch_context(
-      client, allocator,
+      client, allocator, device_ordinal,
       /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
       platform_info_.UseMultipleStreams());
-  launch_context.PopulateInputs(ctx, compilation_result, variables_snapshot,
-                                /*missing_ctx_input_prefix=*/0);
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      executable->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
+      launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
+                                    /*missing_ctx_input_prefix=*/0,
+                                    input_output_alias);
+  OP_REQUIRES_OK(ctx, execution_inputs.status());
 
   // Execute the computation.
   VLOG(2) << "Executing computation.";
@@ -403,24 +416,24 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
-  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
+  xla::StatusOr<xla::ExecutionOutput> execution_output;
   if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
-    run_result = executable->Run(launch_context.arguments(), run_options);
+    execution_output =
+        executable->Run(std::move(*execution_inputs), run_options);
   } else {
-    run_result = executable->RunAsync(launch_context.arguments(), run_options);
+    execution_output =
+        executable->RunAsync(std::move(*execution_inputs), run_options);
   }
-  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());
 
   auto elapsed = env->NowMicros() - start_time;
   VLOG(2) << "Elapsed time: " << elapsed << "us";
-
-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      executable->executable()->module().input_output_alias_config();
-  OP_REQUIRES_OK(ctx,
-                 launch_context.PopulateOutputs(
-                     ctx, compilation_result, run_result.ConsumeValueOrDie(),
-                     /*missing_ctx_input_prefix=*/0, input_output_alias,
-                     variables_snapshot));
+  OP_REQUIRES_OK(
+      ctx, launch_context.PopulateOutputs(
+               ctx, compilation_result, execution_output->ConsumeResult(),
+               /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
+               input_output_alias, resource_var_ptrs));
 
   VLOG(1) << "Done";
 }
@@ -516,10 +529,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
     OP_REQUIRES_OK(
        ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
     OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
+    // Do not alias resource updates as locking variables in XlaCompile and
+    // unlocking them in XlaRun may lead to deadlocks.
     Status status = CompileToLocalExecutable(
         ctx, function_, has_ref_vars_, platform_info_, variable_infos,
         constants_,
-        /*lazy=*/!must_compile_, &client, &kernel, &executable);
+        /*lazy=*/!must_compile_,
+        /*may_alias_resource_update=*/false, &client, &kernel, &executable);
     OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                   variable_infos, &variables));
     if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
@@ -587,14 +604,22 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   se::DeviceMemoryAllocator* allocator =
       GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
+  se::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+  int device_ordinal = stream ? stream->parent()->device_ordinal()
+                              : closure.client()->default_device_ordinal();
   XlaComputationLaunchContext launch_context(
-      closure.client(), allocator,
+      closure.client(), allocator, device_ordinal,
       /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
       /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
   // We're missing the must-be-constant inputs, tell `PopulateInputs`
   // about this. We don't actually need these inputs because they've
   // already been baked into the compiled kernel.
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      closure.executable()->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
+  std::map<int, const Tensor*> snapshot_ptrs;
   {
     tensorflow::profiler::TraceMe hlo_module_activity(
         [&] {
@@ -604,13 +629,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
         },
         tensorflow::profiler::TraceMeLevel::kInfo);
-    launch_context.PopulateInputs(
-        ctx, closure.compilation_result(), closure.resource_var_snapshots(),
-        /*missing_ctx_input_prefix=*/closure.num_constant_args());
+    for (auto& p : closure.resource_var_snapshots()) {
+      snapshot_ptrs.emplace(p.first,
+                            p.second.has_value() ? &p.second.value() : nullptr);
+    }
+    execution_inputs = launch_context.PopulateInputs(
+        ctx, closure.compilation_result(), snapshot_ptrs,
+        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
+        input_output_alias);
+    OP_REQUIRES_OK(ctx, execution_inputs.status());
   }
 
-  se::Stream* stream =
-      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(allocator);
@@ -631,21 +660,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
-  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
+  xla::StatusOr<xla::ExecutionOutput> execution_output;
   if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
-    run_result =
-        closure.executable()->Run(launch_context.arguments(), run_options);
+    execution_output =
+        closure.executable()->Run(std::move(*execution_inputs), run_options);
   } else {
-    run_result =
-        closure.executable()->RunAsync(launch_context.arguments(), run_options);
+    execution_output = closure.executable()->RunAsync(
+        std::move(*execution_inputs), run_options);
   }
-  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());
 
   auto elapsed = env->NowMicros() - start_time;
   VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
-
-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      closure.executable()->executable()->module().input_output_alias_config();
 
   tensorflow::profiler::TraceMe hlo_module_activity(
       [&] {
@@ -653,12 +680,16 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
       },
       tensorflow::profiler::TraceMeLevel::kInfo);
+  xla::StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
+      ctx, *closure.compilation_result(), closure.num_constant_args());
+  OP_REQUIRES_OK(ctx, variable_infos.status());
+  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
   OP_REQUIRES_OK(
       ctx,
       launch_context.PopulateOutputs(
-          ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
+          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
-          input_output_alias, closure.resource_var_snapshots()));
+          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
 }
 
 XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
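The two launch paths above request aliasing differently: XlaLocalLaunchBase compiles and runs in a single op while holding the variable locks, so it passes may_alias_resource_update=true; the XlaCompile/XlaRun pair passes false, since the added comment notes that locking variables in XlaCompile and unlocking them in XlaRun may lead to deadlocks. A small, self-contained model of that decision (hypothetical helper, not TensorFlow source):

// Models how the ops choose the may_alias_resource_update argument.
struct LaunchPath {
  // true for the one-shot XlaLaunch path, false for the XlaCompile/XlaRun pair.
  bool compiles_and_runs_in_one_op;
};

bool MayAliasResourceUpdate(const LaunchPath& path) {
  // The split path would have to keep variables locked across two ops, which
  // the commit's comment flags as a potential deadlock, so it opts out.
  return path.compiles_and_runs_in_one_op;
}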

View File

@@ -50,35 +50,47 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
       client, client->backend().memory_allocator(),
+      client->default_device_ordinal(),
       /*allocate_xla_tensors=*/true,
       /*use_multiple_streams=*/metadata.UseMultipleStreams());
-  launch_context.PopulateInputs(ctx, result, variable_args,
-                                /*missing_ctx_input_prefix=*/0);
+  std::map<int, const Tensor*> snapshot_ptrs;
+  for (auto& p : variable_args) {
+    snapshot_ptrs.emplace(p.first,
+                          p.second.has_value() ? &p.second.value() : nullptr);
+  }
+
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      executable->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
+      launch_context.PopulateInputs(ctx, result, snapshot_ptrs,
+                                    /*missing_ctx_input_prefix=*/0,
+                                    input_output_alias);
+  TF_RETURN_IF_ERROR(execution_inputs.status());
 
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   TF_RET_CHECK(stream);
 
   VLOG(2) << "Executing computation: " << name();
-  for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
-    VLOG(2) << name() << ": " << *arg;
-  }
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(client->backend().memory_allocator());
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());
 
-  xla::StatusOr<xla::ScopedShapedBuffer> run_result =
-      executable->Run(launch_context.arguments(), run_options);
+  xla::StatusOr<xla::ExecutionOutput> run_result =
+      executable->Run(execution_inputs.ConsumeValueOrDie(), run_options);
   TF_RETURN_IF_ERROR(run_result.status());
+  xla::ExecutionOutput execution_output = run_result.ConsumeValueOrDie();
 
-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      executable->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<VariableInfo>> variable_infos =
+      GatherVariableInfo(ctx, *result, 0);
+  TF_RETURN_IF_ERROR(variable_infos.status());
+  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(*variable_infos)));
   TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
-      ctx, result, run_result.ConsumeValueOrDie(),
-      /*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
+      ctx, result, execution_output.ConsumeResult(),
+      /*missing_ctx_input_prefix=*/0, absl::MakeSpan(*variable_infos),
+      input_output_alias, snapshot_ptrs));
   return Status::OK();
 }

View File

@@ -59,11 +59,13 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
                             return Status::OK();
                           }));
   mutex_lock ml(*variable->mu());
-  OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
-              errors::InvalidArgument(
-                  "Trying to assign variable with wrong dtype. Expected ",
-                  DataTypeString(variable->tensor()->dtype()), " got ",
-                  DataTypeString(dtype_)));
+  OP_REQUIRES(
+      context,
+      !variable->is_initialized || variable->tensor()->dtype() == dtype_,
+      errors::InvalidArgument(
+          "Trying to assign variable with wrong dtype. Expected ",
+          DataTypeString(variable->tensor()->dtype()), " got ",
+          DataTypeString(dtype_)));
   variable->is_initialized = true;
   *variable->tensor() = value;
 }
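Because GetVariableInfosFromCtxInputs (further below) may now create placeholder variables with dtype DT_INVALID, assignment on an XLA device has to accept the first write to an uninitialized variable; only initialized variables must keep their dtype. A self-contained model of the relaxed check (illustrative types, not the op itself):

// Mirrors the OP_REQUIRES predicate above: uninitialized variables accept any
// dtype on first assignment, initialized ones must match.
struct VarState {
  bool is_initialized = false;
  int dtype = -1;  // stands in for DT_INVALID
};

bool AssignAllowed(const VarState& var, int assign_dtype) {
  return !var.is_initialized || var.dtype == assign_dtype;
}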

View File

@@ -91,29 +91,19 @@ VariableInfo::~VariableInfo() {
 Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
                                      absl::Span<const int> variable_indices,
                                      std::vector<VariableInfo>* result) {
-  std::vector<const ResourceHandle*> resource_handles;
-  absl::c_transform(
-      variable_indices, std::back_inserter(resource_handles),
-      [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
-  std::vector<core::RefCountPtr<Var>> variables;
-  Status s = LookupResources(ctx, resource_handles, &variables);
-  if (!s.ok()) {
-    errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
-    return s;
-  }
   result->clear();
   result->reserve(variable_indices.size());
-  for (int i = 0; i < variable_indices.size(); i++) {
-    // *Release* the variable because we're going to unref it later in
-    // ~VariableInfo.
-    Var* variable = variables[i].release();
-    int input_idx = variable_indices[i];
-    std::string var_name = HandleFromInput(ctx, input_idx).name();
-    result->emplace_back(input_idx, var_name, variable);
+  for (int var_idx : variable_indices) {
+    Var* variable = nullptr;
+    ResourceHandle handle = HandleFromInput(ctx, var_idx);
+    TF_RETURN_IF_ERROR(
+        LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
+          // This var is uninitialized for now.
+          *ptr = new Var(DT_INVALID);
+          return Status::OK();
+        }));
+    result->emplace_back(var_idx, handle.name(), variable);
   }
   return Status::OK();
 }
@@ -176,24 +166,43 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
 XlaComputationLaunchContext::XlaComputationLaunchContext(
     xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
-    bool allocate_xla_tensors, bool use_multiple_streams)
+    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
     : client_(client),
       xla_allocator_(xla_allocator),
       allocate_xla_tensors_(allocate_xla_tensors),
-      use_multiple_streams_(use_multiple_streams) {
+      use_multiple_streams_(use_multiple_streams),
+      device_ordinal_(device_ordinal) {
   if (use_multiple_streams_) {
     CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                     "be allocating XLA tensors!";
   }
 }
 
-void XlaComputationLaunchContext::PopulateInputs(
+// Fills in `execution_input` with `buffer` for `index`.
+static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
+                                         xla::ShapeIndex index,
+                                         se::DeviceMemoryBase& buffer,
+                                         bool donate_buffer, int device_ordinal,
+                                         se::DeviceMemoryAllocator* allocator) {
+  xla::MaybeOwningDeviceMemory* in_buffer =
+      execution_input.MutableBuffer(index);
+  if (donate_buffer) {
+    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
+    buffer = se::DeviceMemoryBase();
+  } else {
+    *in_buffer = buffer;
+  }
+}
+
+xla::StatusOr<std::vector<xla::ExecutionInput>>
+XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx,
     const XlaCompiler::CompilationResult* compilation_result,
-    const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
-  // Build ShapedBuffers that point directly to the Tensor buffers.
-  arg_ptrs_ =
-      std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
+    const std::map<int, const Tensor*>& resource_vars,
+    int missing_ctx_input_prefix,
+    const xla::HloInputOutputAliasConfig& input_output_alias) {
+  std::vector<xla::ExecutionInput> arguments;
+  arguments.reserve(compilation_result->xla_input_shapes.size());
 
   xla::TransferManager* transfer_manager =
       client_->backend().transfer_manager();
@@ -201,10 +210,28 @@ void XlaComputationLaunchContext::PopulateInputs(
     int arg_num = compilation_result->input_mapping[i];
     CHECK_GE(arg_num, missing_ctx_input_prefix);
     const xla::Shape& shape = compilation_result->xla_input_shapes[i];
-    const Tensor* t = variables.count(arg_num)
-                          ? &(variables.at(arg_num).value())
+    const xla::Shape& device_shape =
+        transfer_manager->HostShapeToDeviceShape(shape);
+
+    bool is_resource_variable = resource_vars.count(arg_num);
+    bool is_updated_resource_variable =
+        is_resource_variable &&
+        absl::c_any_of(compilation_result->resource_updates,
+                       [&](const XlaCompiler::ResourceUpdate& update) {
+                         return update.input_index == i && update.modified;
+                       });
+
+    const Tensor* t = is_resource_variable
+                          ? resource_vars.at(arg_num)
                           : &(ctx->input(arg_num - missing_ctx_input_prefix));
     CHECK(t);
 
+    bool donate_buffer =
+        t->RefCountIsOne() && is_updated_resource_variable &&
+        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
+    VLOG(3) << "Processing input: " << i
+            << "; is_resource_variable=" << is_resource_variable
+            << "; is_updated_resource_variable=" << is_updated_resource_variable
+            << "; donate_buffer=" << donate_buffer;
+
     if (use_multiple_streams_) {
       CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
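The donation decision above combines three facts: the tensor's buffer is uniquely referenced, the argument is a resource variable the computation actually modifies, and the compiled module declared an input/output alias for that parameter. A minimal, self-contained restatement (illustrative names, not TensorFlow source):

// Mirrors the donate_buffer computation in PopulateInputs: donate only when
// every condition holds; otherwise the buffer is passed as a borrowed
// (non-owning) reference.
struct InputFacts {
  bool ref_count_is_one;         // t->RefCountIsOne()
  bool is_updated_resource_var;  // a resource_updates entry marks it modified
  bool parameter_has_alias;      // input_output_alias.ParameterHasAlias(i, {})
};

bool ShouldDonateBuffer(const InputFacts& f) {
  return f.ref_count_is_one && f.is_updated_resource_var &&
         f.parameter_has_alias;
}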
@@ -215,23 +242,28 @@ void XlaComputationLaunchContext::PopulateInputs(
           ctx->op_device_context()->stream());
     }
 
-    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
-            shape, transfer_manager->HostShapeToDeviceShape(shape))) {
+    arguments.emplace_back(device_shape, shape);
+    xla::ExecutionInput& execution_input = arguments.back();
+    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
       se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
-      arg_buffers_.emplace_back(
-          /*on_host_shape=*/shape, /*on_device_shape=*/shape,
-          client_->platform(), client_->default_device_ordinal());
-      arg_buffers_.back().set_buffer(dmem, /*index=*/{});
-      arg_ptrs_[i] = &arg_buffers_.back();
+      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
+                                   donate_buffer, device_ordinal_,
+                                   xla_allocator_);
     } else {
-      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
       CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
-      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
+      xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
+          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+            PopulateExecutionInputBuffer(execution_input, index, *buffer,
+                                         donate_buffer, device_ordinal_,
+                                         xla_allocator_);
+          });
     }
   }
+  return std::move(arguments);
 }
 
-// Construct the tensor for given type and buffer.
+// Construct the tensor for the given type and buffer.
 static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                          se::DeviceMemoryBase buffer, Allocator* allocator) {
   size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
@@ -247,28 +279,26 @@ static Tensor GetOrCreateTensorForOutput(
     int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
     const xla::HloInputOutputAliasConfig& input_output_alias,
     absl::Span<const int> input_mapping,
-    const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
-    const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
-    Allocator* output_allocator) {
+    const std::map<int, const Tensor*>& resource_vars_snapshots,
+    DataType output_dtype, const TensorShape& output_shape,
+    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
   xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                      ? xla::ShapeIndex({output_num})
                                      : xla::ShapeIndex({});
   CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
   if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
           input_output_alias.GetAliasedParameter(output_index)) {
+    VLOG(3) << "Found alias: " << alias->ToString();
     int tf_param =
         input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
-    const Tensor* input_tensor = &ctx->input(tf_param);
-
-    // If input tensor is a resource variable, alias to the snapshot we took at
-    // entry time.
-    if (input_tensor->dtype() == DT_RESOURCE) {
-      const absl::optional<Tensor>& v =
-          resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
-      CHECK(v.has_value());
-      return *v;
+    const Tensor input_tensor =
+        ctx->input(tf_param).dtype() != DT_RESOURCE
+            ? ctx->input(tf_param)
+            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
+    if (output_buffer.opaque() == input_tensor.data()) {
+      return input_tensor;
     }
-    return *input_tensor;
   }
   return MakeTensor(output_dtype, output_shape, output_buffer,
                     output_allocator);
@@ -291,12 +321,10 @@ static Status SetOutputForConstant(
     OpKernelContext* ctx, se::Stream* stream,
     const XlaCompiler::CompilationResult* compilation_result, int output_num) {
   CHECK(compilation_result->outputs[output_num].is_constant);
-  // Output is a constant.
   const Tensor& const_tensor =
       compilation_result->outputs[output_num].constant_value;
   Tensor* output_tensor;
-  const size_t total_bytes = const_tensor.TotalBytes();
-  if (stream && total_bytes > 0) {
+  if (stream && const_tensor.TotalBytes() > 0) {
     // Copy host -> device. (Empty tensors don't have backing buffers.)
     // Manually allocate memory using an XlaTensorBuffer so we can allocate
     // as much memory as the device requires (as given by
@@ -335,52 +363,55 @@ static Status SetOutputForConstant(
   return Status::OK();
 }
 
-// Creates a list of updates resource variables.
-static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
-    OpKernelContext* ctx,
-    const XlaCompiler::CompilationResult* compilation_result,
-    int missing_ctx_input_prefix) {
-  std::vector<VariableInfo> variable_infos;
-  variable_infos.reserve(compilation_result->resource_updates.size());
+static xla::StatusOr<Var*> GetOrCreateResourceVar(
+    OpKernelContext* ctx, const ResourceHandle& handle,
+    const XlaCompiler::ResourceUpdate& write) {
+  Var* variable = nullptr;
+  TF_RETURN_IF_ERROR(
+      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
+        *ptr = new Var(write.type);
+        return Status::OK();
+      }));
+  return variable;
+}
 
-  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
+xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult& compilation_result,
+    int missing_ctx_input_prefix) {
+  std::vector<VariableInfo> out;
+  out.reserve(compilation_result.resource_updates.size());
+  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
     const XlaCompiler::ResourceUpdate& write =
-        compilation_result->resource_updates[i];
+        compilation_result.resource_updates[i];
     int actual_input_index = write.input_index - missing_ctx_input_prefix;
     if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
       return errors::Internal("Invalid input index for variable write.");
     }
-
-    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
-    // not a Tensor.
-    Var* variable = nullptr;
     const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
-    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
-                                                   [&write](Var** ptr) {
-                                                     *ptr = new Var(write.type);
-                                                     return Status::OK();
-                                                   }));
-    variable_infos.emplace_back(actual_input_index, handle.name(), variable);
+    TF_ASSIGN_OR_RETURN(Var * variable,
+                        GetOrCreateResourceVar(ctx, handle, write));
+    out.emplace_back(actual_input_index, handle.name(), variable);
   }
-  return variable_infos;
+  return std::move(out);
 }
 
 Status XlaComputationLaunchContext::PopulateOutputs(
     OpKernelContext* ctx,
     const XlaCompiler::CompilationResult* compilation_result,
     ScopedShapedBuffer output, int missing_ctx_input_prefix,
+    absl::Span<VariableInfo> variable_infos,
     const xla::HloInputOutputAliasConfig& input_output_alias,
-    const ResourceVarsSnapshot& resource_var_snapshots) {
+    const std::map<int, const Tensor*>& resource_vars) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   Allocator* allocator = ctx->device()->GetAllocator({});
 
   // Computation output should always be a tuple.
-  if (VLOG_IS_ON(2)) {
-    VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
-    VLOG(2) << "Result tuple shape (on device): "
-            << output.on_device_shape().DebugString();
-  }
+  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
+  VLOG(2) << "Result tuple shape (on device): "
+          << output.on_device_shape().DebugString();
   CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
 
   // If the on-host-shape isn't a tuple, create a new single-element tuple
@@ -438,8 +469,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
   for (int i = 0; i < ctx->num_outputs(); ++i) {
     const TensorShape& shape = output_tensor_shapes[i];
     const DataType& type = compilation_result->outputs[i].type;
-    VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
-            << DataTypeString(type);
+    VLOG(2) << "Populating output for retval " << i << " shape "
+            << shape.DebugString() << " type " << DataTypeString(type);
     if (type == DT_VARIANT) {
       return errors::Unimplemented(
           "Support for TensorList crossing the XLA/TF boundary "
@@ -467,30 +498,37 @@ Status XlaComputationLaunchContext::PopulateOutputs(
         se::DeviceMemoryBase buffer = output.buffer({output_num});
         Tensor output_tensor = GetOrCreateTensorForOutput(
             output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-            compilation_result->input_mapping, resource_var_snapshots,
+            compilation_result->input_mapping, resource_vars,
             ctx->expected_output_dtype(i), shape, buffer, allocator);
-        output.set_buffer(se::OwningDeviceMemory(), {output_num});
        ctx->set_output(i, output_tensor);
       }
+      output.set_buffer(se::OwningDeviceMemory(), {output_num});
       ++output_num;
     }
+
+    if (VLOG_IS_ON(3)) {
+      VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
+    }
   }
 
-  // Apply variable updates, if any.
-  VLOG(2) << "Applying variable updates";
-  TF_ASSIGN_OR_RETURN(
-      std::vector<VariableInfo> variable_infos,
-      GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
-  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
+  // input_index -> index into variable_infos.
+  absl::flat_hash_map<int, int> variable_info_lookup;
+  for (int i = 0; i < variable_infos.size(); i++) {
+    variable_info_lookup.emplace(variable_infos[i].index(), i);
+  }
 
+  // Apply variable updates, if any.
   for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
     const XlaCompiler::ResourceUpdate& write =
         compilation_result->resource_updates[i];
-    if (variable_infos[i].var()->tensor()->dtype() != write.type) {
+    int actual_input_index = write.input_index - missing_ctx_input_prefix;
+    CHECK_GE(actual_input_index, 0);
+    CHECK_LT(actual_input_index, ctx->num_inputs());
+    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
+    CHECK(var);
+
+    VLOG(2) << "Updating variable #" << i
+            << " at input index: " << actual_input_index << " with shape "
+            << write.shape.DebugString() << "; variable tensor has shape: "
+            << var->tensor()->shape().DebugString();
+
+    if (var->is_initialized && var->tensor()->dtype() != write.type) {
       return errors::Internal("Mismatched type in variable write");
     }
@@ -504,14 +542,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
       }
     } else {
       se::DeviceMemoryBase buffer = output.buffer({output_num});
-      output.set_buffer(se::OwningDeviceMemory(), {output_num});
       output_tensor = GetOrCreateTensorForOutput(
           output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-          compilation_result->input_mapping, resource_var_snapshots, write.type,
+          compilation_result->input_mapping, resource_vars, write.type,
          write.shape, buffer, allocator);
     }
-    *variable_infos[i].var()->tensor() = output_tensor;
-    variable_infos[i].var()->is_initialized |= write.modified;
+    output.set_buffer(se::OwningDeviceMemory(), {output_num});
+    var->is_initialized |= write.modified;
+    *var->tensor() = output_tensor;
     ++output_num;
   }
   return Status::OK();
@@ -562,7 +600,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
       arg.name = std::string(variable.name());
       arg.kind = XlaCompiler::Argument::kResource;
       arg.resource_kind = XlaResource::kVariable;
-      if (variable.var()) {
+      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
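On the output side, GetOrCreateTensorForOutput above now reuses the aliased input tensor only when XLA actually wrote the result into the buffer backing that input (the output_buffer.opaque() == input_tensor.data() check); otherwise it wraps the fresh output buffer in a new tensor. A compact, self-contained model of that choice (illustrative types only, not TensorFlow source):

// Models the aliasing fast path: reuse the input tensor when the output landed
// in the same device buffer, otherwise build a tensor around the new buffer.
struct BufferView {
  const void* data;  // base address of the device buffer
};

enum class OutputSource { kReusedAliasedInput, kFreshBuffer };

OutputSource ClassifyOutput(BufferView output_buffer, BufferView aliased_input) {
  return output_buffer.data == aliased_input.data
             ? OutputSource::kReusedAliasedInput
             : OutputSource::kFreshBuffer;
}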

View File

@@ -81,6 +81,12 @@ class VariableInfo {
   bool lock_held_ = false;
 };
 
+// Creates a list of updated resource variables.
+xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult& compilation_result,
+    int missing_ctx_input_prefix);
+
 // Takes a snapshot of the values of resource variable arguments, whose indices
 // are specified in `variable_indices` argument. We snapshot tensors that back
 // resource variables since concurrent updates may modify the shape, and it is
@@ -124,7 +130,7 @@ class XlaComputationLaunchContext {
   // objects.
   XlaComputationLaunchContext(xla::LocalClient* client,
                               se::DeviceMemoryAllocator* xla_allocator,
-                              bool allocate_xla_tensors,
+                              int device_ordinal, bool allocate_xla_tensors,
                               bool use_multiple_streams);
 
   // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
@@ -142,10 +148,12 @@ class XlaComputationLaunchContext {
   // missing and adjusts input indices accordingly. All elements in kernel's
   // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
   // (in other words, no inputs actually required by the kernel can be missing).
-  void PopulateInputs(OpKernelContext* ctx,
-                      const XlaCompiler::CompilationResult* compilation_result,
-                      const ResourceVarsSnapshot& variables,
-                      int missing_ctx_input_prefix);
+  xla::StatusOr<std::vector<xla::ExecutionInput>> PopulateInputs(
+      OpKernelContext* ctx,
+      const XlaCompiler::CompilationResult* compilation_result,
+      const std::map<int, const Tensor*>& resource_vars,
+      int missing_ctx_input_prefix,
+      const xla::HloInputOutputAliasConfig& input_output_alias);
 
   // Given the XLA output in `output`, populate all outputs of `ctx`. Also
   // writes out the resource variable updates.
@@ -161,20 +169,16 @@ class XlaComputationLaunchContext {
       OpKernelContext* ctx,
       const XlaCompiler::CompilationResult* compilation_result,
       xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
+      absl::Span<VariableInfo> variable_infos,
       const xla::HloInputOutputAliasConfig& input_output_alias,
-      const ResourceVarsSnapshot& resource_var_snapshots);
-
-  // Return the argument list. Only valid after PopulateInputs() has been
-  // called.
-  const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
+      const std::map<int, const Tensor*>& resource_vars);
 
  private:
   xla::LocalClient* client_;
   se::DeviceMemoryAllocator* xla_allocator_;
   bool allocate_xla_tensors_;
   bool use_multiple_streams_;
-  std::deque<xla::ShapedBuffer> arg_buffers_;
-  std::vector<xla::ShapedBuffer*> arg_ptrs_;
+  int device_ordinal_;
 };
 
 // A simple TensorBuffer implementation that allows us to create Tensors that

View File

@@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -403,6 +404,69 @@ class DefFunctionTest(test.TestCase):
     self.assertEqual(inner_retracings, 1)
 
+  def testUpdateVariable(self):
+    v = variables.Variable(3.1)
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+
+    update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+    self.assertAllClose(v, 3.52)
+
+  def testUpdateVariableVector(self):
+    v = variables.Variable([3.1, 3.1])
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+
+    update_var(
+        constant_op.constant([0.7, 0.7]), constant_op.constant([0.6, 0.6]))
+    self.assertAllClose(v, [3.52, 3.52])
+
+  def testUpdateVariableInClass(self):
+
+    class C(object):
+
+      @def_function.function(experimental_compile=True)
+      def update_var(self, a, b):
+        if not hasattr(self, 'v'):
+          self.v = variables.Variable(3.1)
+        self.v.assign_add(a * b)
+
+    c = C()
+
+    @def_function.function
+    def outer():
+      c.update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+
+    outer()
+    self.assertAllClose(c.v, 3.52)
+
+  def testUpdateVariableMultipleOutputs(self):
+    v = variables.Variable(3.1)
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+      return a * b + v
+
+    out = update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+    self.assertAllClose(v, 3.52)
+    self.assertAllClose(out, 3.94)
+
+  def testReturnIdentity(self):
+
+    @def_function.function(experimental_compile=True)
+    def f(a, b):
+      return (a, b)
+
+    a = constant_op.constant([0.7])
+    b = constant_op.constant([0.6])
+
+    f(a, b)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()