Rollback of rollback of [TF/XLA] Enable input/output aliasing in the TF2XLA bridge

The underlying bug has been fixed, so the original change is re-applied.

PiperOrigin-RevId: 321863222
Change-Id: I94c25f3243e33374ee089dd808c3f25704de2c92
George Karpenkov 2020-07-17 15:08:25 -07:00 committed by TensorFlower Gardener
parent 01d9e46f28
commit ec25e31817
6 changed files with 319 additions and 168 deletions
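For context, the user-visible effect of input/output aliasing is that a resource variable updated inside an XLA-compiled function can have its updated value written back into the buffer that backed the input, instead of a freshly allocated one. A minimal sketch in the spirit of the tests added by this change, using the public tf.* API rather than the internal modules the test imports (the function name update_var is illustrative):

import tensorflow as tf

v = tf.Variable(3.1)

@tf.function(experimental_compile=True)  # compile the function with XLA
def update_var(a, b):
  # With aliasing enabled, the buffer backing `v` may be donated to the
  # compiled computation and reused for the updated value.
  v.assign_add(a * b)

update_var(tf.constant(0.7), tf.constant(0.6))
print(v.numpy())  # approximately 3.52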


@@ -277,7 +277,8 @@ static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info,
absl::Span<VariableInfo const> variable_infos,
absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
xla::LocalClient** client,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable) {
// We store information about the JIT-compiled XLA computation
@@ -332,6 +333,9 @@ static Status CompileToLocalExecutable(
// Optimization: where possible, have the computation return a naked array
// rather than a one-element tuple.
compile_options.always_return_tuple = false;
compile_options.alias_resource_update = !has_ref_vars &&
!platform_info.is_on_xla_device() &&
may_alias_resource_update;
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
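The aliasing request above is deliberately conservative: it is only made when the cluster has no reference variables, is not running on an XLA device, and the calling op allows it. A hypothetical Python mirror of that predicate, purely for illustration (none of these names exist in TensorFlow):

def may_request_resource_aliasing(has_ref_vars, is_on_xla_device,
                                  may_alias_resource_update):
  # Clusters that touch TF1-style reference variables are excluded.
  if has_ref_vars:
    return False
  # On XLA devices the launch context allocates XlaTensors itself, so
  # aliasing is not requested there.
  if is_on_xla_device:
    return False
  # Finally, the specific op (XlaLaunch vs. XlaCompile/XlaRun) decides
  # whether donating resource buffers is safe at all.
  return may_alias_resource_update

assert may_request_resource_aliasing(False, False, True)
assert not may_request_resource_aliasing(True, False, True)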
@@ -350,20 +354,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
const XlaCompiler::CompilationResult* compilation_result;
xla::LocalExecutable* executable;
ResourceVarsSnapshot variables_snapshot;
std::vector<VariableInfo> variable_infos;
{
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
variable_infos, constants_, /*lazy=*/false, &client,
&compilation_result, &executable);
variable_infos, constants_, /*lazy=*/false,
/*may_alias_resource_update=*/true, &client, &compilation_result,
&executable);
OP_REQUIRES_OK(ctx, s);
OP_REQUIRES_OK(ctx,
SnapshotResourceVariables(ctx, resources_, variable_infos,
&variables_snapshot));
}
std::map<int, const Tensor*> resource_var_ptrs;
for (int i = 0; i < resources_.size(); i++) {
resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
}
se::Stream* stream =
@@ -374,12 +380,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
int device_ordinal = stream ? stream->parent()->device_ordinal()
: client->default_device_ordinal();
XlaComputationLaunchContext launch_context(
client, allocator,
client, allocator, device_ordinal,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
platform_info_.UseMultipleStreams());
launch_context.PopulateInputs(ctx, compilation_result, variables_snapshot,
/*missing_ctx_input_prefix=*/0);
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
/*missing_ctx_input_prefix=*/0,
input_output_alias);
OP_REQUIRES_OK(ctx, execution_inputs.status());
// Execute the computation.
VLOG(2) << "Executing computation.";
@@ -403,24 +416,24 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
Env* env = Env::Default();
auto start_time = env->NowMicros();
xla::StatusOr<xla::ScopedShapedBuffer> run_result;
xla::StatusOr<xla::ExecutionOutput> execution_output;
if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
run_result = executable->Run(launch_context.arguments(), run_options);
execution_output =
executable->Run(std::move(*execution_inputs), run_options);
} else {
run_result = executable->RunAsync(launch_context.arguments(), run_options);
execution_output =
executable->RunAsync(std::move(*execution_inputs), run_options);
}
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
OP_REQUIRES_OK(
ctx, launch_context.PopulateOutputs(
ctx, compilation_result, execution_output->ConsumeResult(),
/*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
input_output_alias, resource_var_ptrs));
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
OP_REQUIRES_OK(ctx,
launch_context.PopulateOutputs(
ctx, compilation_result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias,
variables_snapshot));
VLOG(1) << "Done";
}
@@ -516,10 +529,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
// Do not alias resource updates as locking variables in XlaCompile and
// unlocking them in XlaRun may lead to deadlocks.
Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
constants_,
/*lazy=*/!must_compile_, &client, &kernel, &executable);
/*lazy=*/!must_compile_,
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
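The comment above is about lock lifetime: XlaCompile would have to keep variable locks held until a separate XlaRun op releases them, and holding locks across op boundaries invites the classic lock-ordering deadlock. A generic, self-contained illustration of that hazard (plain Python threading, not TensorFlow code; acquisition timeouts are used so the sketch terminates instead of hanging):

import threading
import time

lock_a, lock_b = threading.Lock(), threading.Lock()

def worker(first, second, name):
  # Each worker holds one lock "across an op boundary" while trying to
  # take the other, mimicking two executions locking variables in
  # opposite orders.
  with first:
    time.sleep(0.1)
    acquired = second.acquire(timeout=1.0)
    print(name, "acquired second lock:", acquired)
    if acquired:
      second.release()

t1 = threading.Thread(target=worker, args=(lock_a, lock_b, "exec-1"))
t2 = threading.Thread(target=worker, args=(lock_b, lock_a, "exec-2"))
t1.start(); t2.start(); t1.join(); t2.join()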
@@ -587,14 +604,22 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
int device_ordinal = stream ? stream->parent()->device_ordinal()
: closure.client()->default_device_ordinal();
XlaComputationLaunchContext launch_context(
closure.client(), allocator,
closure.client(), allocator, device_ordinal,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
/*use_multiple_streams=*/platform_info_.UseMultipleStreams());
// We're missing the must-be-constant inputs, tell `PopulateInputs`
// about this. We don't actually need these inputs because they've
// already been baked into the compiled kernel.
const xla::HloInputOutputAliasConfig& input_output_alias =
closure.executable()->executable()->module().input_output_alias_config();
xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
std::map<int, const Tensor*> snapshot_ptrs;
{
tensorflow::profiler::TraceMe hlo_module_activity(
[&] {
@@ -604,13 +629,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
},
tensorflow::profiler::TraceMeLevel::kInfo);
launch_context.PopulateInputs(
ctx, closure.compilation_result(), closure.resource_var_snapshots(),
/*missing_ctx_input_prefix=*/closure.num_constant_args());
for (auto& p : closure.resource_var_snapshots()) {
snapshot_ptrs.emplace(p.first,
p.second.has_value() ? &p.second.value() : nullptr);
}
execution_inputs = launch_context.PopulateInputs(
ctx, closure.compilation_result(), snapshot_ptrs,
/*missing_ctx_input_prefix=*/closure.num_constant_args(),
input_output_alias);
OP_REQUIRES_OK(ctx, execution_inputs.status());
}
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(allocator);
@@ -631,21 +660,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
Env* env = Env::Default();
auto start_time = env->NowMicros();
xla::StatusOr<xla::ScopedShapedBuffer> run_result;
xla::StatusOr<xla::ExecutionOutput> execution_output;
if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
run_result =
closure.executable()->Run(launch_context.arguments(), run_options);
execution_output =
closure.executable()->Run(std::move(*execution_inputs), run_options);
} else {
run_result =
closure.executable()->RunAsync(launch_context.arguments(), run_options);
execution_output = closure.executable()->RunAsync(
std::move(*execution_inputs), run_options);
}
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
const xla::HloInputOutputAliasConfig& input_output_alias =
closure.executable()->executable()->module().input_output_alias_config();
tensorflow::profiler::TraceMe hlo_module_activity(
[&] {
@@ -653,12 +680,16 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
},
tensorflow::profiler::TraceMeLevel::kInfo);
xla::StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
ctx, *closure.compilation_result(), closure.num_constant_args());
OP_REQUIRES_OK(ctx, variable_infos.status());
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
OP_REQUIRES_OK(
ctx,
launch_context.PopulateOutputs(
ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
ctx, closure.compilation_result(), execution_output->ConsumeResult(),
/*missing_ctx_input_prefix=*/closure.num_constant_args(),
input_output_alias, closure.resource_var_snapshots()));
absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}
XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}


@@ -50,35 +50,47 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
client, client->backend().memory_allocator(),
client->default_device_ordinal(),
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
launch_context.PopulateInputs(ctx, result, variable_args,
/*missing_ctx_input_prefix=*/0);
std::map<int, const Tensor*> snapshot_ptrs;
for (auto& p : variable_args) {
snapshot_ptrs.emplace(p.first,
p.second.has_value() ? &p.second.value() : nullptr);
}
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
launch_context.PopulateInputs(ctx, result, snapshot_ptrs,
/*missing_ctx_input_prefix=*/0,
input_output_alias);
TF_RETURN_IF_ERROR(execution_inputs.status());
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
TF_RET_CHECK(stream);
VLOG(2) << "Executing computation: " << name();
for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
VLOG(2) << name() << ": " << *arg;
}
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::StatusOr<xla::ScopedShapedBuffer> run_result =
executable->Run(launch_context.arguments(), run_options);
xla::StatusOr<xla::ExecutionOutput> run_result =
executable->Run(execution_inputs.ConsumeValueOrDie(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
xla::ExecutionOutput execution_output = run_result.ConsumeValueOrDie();
xla::StatusOr<std::vector<VariableInfo>> variable_infos =
GatherVariableInfo(ctx, *result, 0);
TF_RETURN_IF_ERROR(variable_infos.status());
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(*variable_infos)));
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
ctx, result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
ctx, result, execution_output.ConsumeResult(),
/*missing_ctx_input_prefix=*/0, absl::MakeSpan(*variable_infos),
input_output_alias, snapshot_ptrs));
return Status::OK();
}


@@ -59,11 +59,13 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
return Status::OK();
}));
mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
OP_REQUIRES(
context,
!variable->is_initialized || variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
variable->is_initialized = true;
*variable->tensor() = value;
}
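With variables now created lazily as uninitialized placeholders (see the GetVariableInfosFromCtxInputs change below), the first assignment has to be allowed to establish the dtype; only assignments to an already-initialized variable must match it. A tiny hypothetical mirror of the relaxed check, illustrative only and not a TensorFlow API:

def assign_allowed(var_is_initialized, var_dtype, value_dtype):
  # An uninitialized placeholder accepts any dtype; an initialized
  # variable must keep its dtype.
  return (not var_is_initialized) or (var_dtype == value_dtype)

assert assign_allowed(False, "invalid", "float32")   # first assignment
assert assign_allowed(True, "float32", "float32")    # matching update
assert not assign_allowed(True, "float32", "int32")  # rejected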


@@ -91,29 +91,19 @@ VariableInfo::~VariableInfo() {
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
std::vector<const ResourceHandle*> resource_handles;
absl::c_transform(
variable_indices, std::back_inserter(resource_handles),
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
std::vector<core::RefCountPtr<Var>> variables;
Status s = LookupResources(ctx, resource_handles, &variables);
if (!s.ok()) {
errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
return s;
}
result->clear();
result->reserve(variable_indices.size());
for (int i = 0; i < variable_indices.size(); i++) {
// *Release* the variable because we're going to unref it later in
// ~VariableInfo.
Var* variable = variables[i].release();
int input_idx = variable_indices[i];
std::string var_name = HandleFromInput(ctx, input_idx).name();
result->emplace_back(input_idx, var_name, variable);
for (int var_idx : variable_indices) {
Var* variable = nullptr;
ResourceHandle handle = HandleFromInput(ctx, var_idx);
TF_RETURN_IF_ERROR(
LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
// This var is uninitialized for now.
*ptr = new Var(DT_INVALID);
return Status::OK();
}));
result->emplace_back(var_idx, handle.name(), variable);
}
return Status::OK();
}
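The rewritten lookup above no longer fails when a resource handle has not been initialized yet: it creates a placeholder Var(DT_INVALID) instead, and later code consults is_initialized. A rough Python analogue of that lookup-or-create pattern, with hypothetical names used only for illustration:

class FakeVar:
  def __init__(self, dtype):
    self.dtype = dtype
    self.is_initialized = dtype != "invalid"

def lookup_or_create(registry, handle):
  # Return the existing variable for `handle`, or create an
  # uninitialized placeholder so callers can still lock it and
  # initialize it later.
  if handle not in registry:
    registry[handle] = FakeVar("invalid")
  return registry[handle]

registry = {}
v = lookup_or_create(registry, "var0")
assert not v.is_initialized
assert lookup_or_create(registry, "var0") is v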
@@ -176,24 +166,43 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors, bool use_multiple_streams)
int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
: client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors),
use_multiple_streams_(use_multiple_streams) {
use_multiple_streams_(use_multiple_streams),
device_ordinal_(device_ordinal) {
if (use_multiple_streams_) {
CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
"be allocating XLA tensors!";
}
}
void XlaComputationLaunchContext::PopulateInputs(
// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
xla::ShapeIndex index,
se::DeviceMemoryBase& buffer,
bool donate_buffer, int device_ordinal,
se::DeviceMemoryAllocator* allocator) {
xla::MaybeOwningDeviceMemory* in_buffer =
execution_input.MutableBuffer(index);
if (donate_buffer) {
*in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
buffer = se::DeviceMemoryBase();
} else {
*in_buffer = buffer;
}
}
xla::StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_ptrs_ =
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
const std::map<int, const Tensor*>& resource_vars,
int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias) {
std::vector<xla::ExecutionInput> arguments;
arguments.reserve(compilation_result->xla_input_shapes.size());
xla::TransferManager* transfer_manager =
client_->backend().transfer_manager();
@@ -201,10 +210,28 @@ void XlaComputationLaunchContext::PopulateInputs(
int arg_num = compilation_result->input_mapping[i];
CHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
const Tensor* t = variables.count(arg_num)
? &(variables.at(arg_num).value())
const xla::Shape& device_shape =
transfer_manager->HostShapeToDeviceShape(shape);
bool is_resource_variable = resource_vars.count(arg_num);
bool is_updated_resource_variable =
is_resource_variable &&
absl::c_any_of(compilation_result->resource_updates,
[&](const XlaCompiler::ResourceUpdate& update) {
return update.input_index == i && update.modified;
});
const Tensor* t = is_resource_variable
? resource_vars.at(arg_num)
: &(ctx->input(arg_num - missing_ctx_input_prefix));
CHECK(t);
bool donate_buffer =
t->RefCountIsOne() && is_updated_resource_variable &&
input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
VLOG(3) << "Processing input: " << i
<< "; is_resource_variable=" << is_resource_variable
<< "; is_updated_resource_variable=" << is_updated_resource_variable
<< "; donate_buffer=" << donate_buffer;
if (use_multiple_streams_) {
CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
@@ -215,23 +242,28 @@ void XlaComputationLaunchContext::PopulateInputs(
ctx->op_device_context()->stream());
}
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
shape, transfer_manager->HostShapeToDeviceShape(shape))) {
arguments.emplace_back(device_shape, shape);
xla::ExecutionInput& execution_input = arguments.back();
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_.emplace_back(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = &arg_buffers_.back();
PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
donate_buffer, device_ordinal_,
xla_allocator_);
} else {
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
PopulateExecutionInputBuffer(execution_input, index, *buffer,
donate_buffer, device_ordinal_,
xla_allocator_);
});
}
}
return std::move(arguments);
}
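Buffer donation is the core of the aliasing change: an input buffer is handed over as owning memory only when the runtime can prove it is safe to let XLA reuse it for an output. A hypothetical Python mirror of the donate_buffer condition computed above, for illustration only:

def should_donate(tensor_ref_count, is_updated_resource_variable,
                  parameter_has_alias):
  # Donate only if: this kernel holds the sole reference to the tensor,
  # the tensor backs a resource variable that the computation writes,
  # and the compiled program declared an input/output alias for the
  # corresponding parameter.
  return (tensor_ref_count == 1 and
          is_updated_resource_variable and
          parameter_has_alias)

assert should_donate(1, True, True)
assert not should_donate(2, True, True)   # another holder still uses the buffer
assert not should_donate(1, False, True)  # plain input, never overwritten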
// Construct the tensor for given type and buffer.
// Construct the tensor for the given type and buffer.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
se::DeviceMemoryBase buffer, Allocator* allocator) {
size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
@@ -247,28 +279,26 @@ static Tensor GetOrCreateTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
Allocator* output_allocator) {
const std::map<int, const Tensor*>& resource_vars_snapshots,
DataType output_dtype, const TensorShape& output_shape,
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
? xla::ShapeIndex({output_num})
: xla::ShapeIndex({});
CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
input_output_alias.GetAliasedParameter(output_index)) {
VLOG(3) << "Found alias: " << alias->ToString();
int tf_param =
input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
const Tensor* input_tensor = &ctx->input(tf_param);
// If input tensor is a resource variable, alias to the snapshot we took at
// entry time.
if (input_tensor->dtype() == DT_RESOURCE) {
const absl::optional<Tensor>& v =
resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
CHECK(v.has_value());
return *v;
const Tensor input_tensor =
ctx->input(tf_param).dtype() != DT_RESOURCE
? ctx->input(tf_param)
: *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
if (output_buffer.opaque() == input_tensor.data()) {
return input_tensor;
}
return *input_tensor;
}
return MakeTensor(output_dtype, output_shape, output_buffer,
output_allocator);
@@ -291,12 +321,10 @@ static Status SetOutputForConstant(
OpKernelContext* ctx, se::Stream* stream,
const XlaCompiler::CompilationResult* compilation_result, int output_num) {
CHECK(compilation_result->outputs[output_num].is_constant);
// Output is a constant.
const Tensor& const_tensor =
compilation_result->outputs[output_num].constant_value;
Tensor* output_tensor;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
if (stream && const_tensor.TotalBytes() > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
// Manually allocate memory using an XlaTensorBuffer so we can allocate
// as much memory as the device requires (as given by
@@ -335,52 +363,55 @@ static Status SetOutputForConstant(
return Status::OK();
}
// Creates a list of updates resource variables.
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
int missing_ctx_input_prefix) {
std::vector<VariableInfo> variable_infos;
variable_infos.reserve(compilation_result->resource_updates.size());
static xla::StatusOr<Var*> GetOrCreateResourceVar(
OpKernelContext* ctx, const ResourceHandle& handle,
const XlaCompiler::ResourceUpdate& write) {
Var* variable = nullptr;
TF_RETURN_IF_ERROR(
LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
return variable;
}
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult& compilation_result,
int missing_ctx_input_prefix) {
std::vector<VariableInfo> out;
out.reserve(compilation_result.resource_updates.size());
for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
compilation_result.resource_updates[i];
int actual_input_index = write.input_index - missing_ctx_input_prefix;
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, handle.name(), variable);
TF_ASSIGN_OR_RETURN(Var * variable,
GetOrCreateResourceVar(ctx, handle, write));
out.emplace_back(actual_input_index, handle.name(), variable);
}
return variable_infos;
return std::move(out);
}
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
absl::Span<VariableInfo> variable_infos,
const xla::HloInputOutputAliasConfig& input_output_alias,
const ResourceVarsSnapshot& resource_var_snapshots) {
const std::map<int, const Tensor*>& resource_vars) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
Allocator* allocator = ctx->device()->GetAllocator({});
// Computation output should always be a tuple.
if (VLOG_IS_ON(2)) {
VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
VLOG(2) << "Result tuple shape (on device): "
<< output.on_device_shape().DebugString();
}
VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
VLOG(2) << "Result tuple shape (on device): "
<< output.on_device_shape().DebugString();
CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
// If the on-host-shape isn't a tuple, create a new single-element tuple
@@ -438,8 +469,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < ctx->num_outputs(); ++i) {
const TensorShape& shape = output_tensor_shapes[i];
const DataType& type = compilation_result->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
VLOG(2) << "Populating output for retval " << i << " shape "
<< shape.DebugString() << " type " << DataTypeString(type);
if (type == DT_VARIANT) {
return errors::Unimplemented(
"Support for TensorList crossing the XLA/TF boundary "
@@ -467,30 +498,37 @@ Status XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping, resource_var_snapshots,
compilation_result->input_mapping, resource_vars,
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
output.set_buffer(se::OwningDeviceMemory(), {output_num});
++output_num;
}
if (VLOG_IS_ON(3)) {
VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
}
}
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
TF_ASSIGN_OR_RETURN(
std::vector<VariableInfo> variable_infos,
GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
// input_index -> index into variable_infos.
absl::flat_hash_map<int, int> variable_info_lookup;
for (int i = 0; i < variable_infos.size(); i++) {
variable_info_lookup.emplace(variable_infos[i].index(), i);
}
// Apply variable updates, if any.
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
int actual_input_index = write.input_index - missing_ctx_input_prefix;
CHECK_GE(actual_input_index, 0);
CHECK_LT(actual_input_index, ctx->num_inputs());
Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
CHECK(var);
VLOG(2) << "Updating variable #" << i
<< " at input index: " << actual_input_index << " with shape "
<< write.shape.DebugString() << "; variable tensor has shape: "
<< var->tensor()->shape().DebugString();
if (var->is_initialized && var->tensor()->dtype() != write.type) {
return errors::Internal("Mismatched type in variable write");
}
@@ -504,14 +542,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
output.set_buffer(se::OwningDeviceMemory(), {output_num});
output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping, resource_var_snapshots, write.type,
compilation_result->input_mapping, resource_vars, write.type,
write.shape, buffer, allocator);
}
*variable_infos[i].var()->tensor() = output_tensor;
variable_infos[i].var()->is_initialized |= write.modified;
output.set_buffer(se::OwningDeviceMemory(), {output_num});
var->is_initialized |= write.modified;
*var->tensor() = output_tensor;
++output_num;
}
return Status::OK();
@@ -562,7 +600,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
arg.name = std::string(variable.name());
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kVariable;
if (variable.var()) {
if (variable.var() && variable.var()->is_initialized) {
const Tensor* value = variable.var()->tensor();
arg.type = value->dtype();
arg.shape = value->shape();


@@ -81,6 +81,12 @@ class VariableInfo {
bool lock_held_ = false;
};
// Creates a list of updated resource variables.
xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult& compilation_result,
int missing_ctx_input_prefix);
// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
@@ -124,7 +130,7 @@ class XlaComputationLaunchContext {
// objects.
XlaComputationLaunchContext(xla::LocalClient* client,
se::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors,
int device_ordinal, bool allocate_xla_tensors,
bool use_multiple_streams);
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
@@ -142,10 +148,12 @@ class XlaComputationLaunchContext {
// missing and adjusts input indices accordingly. All elements in kernel's
// input_mapping must be greater than or equal to `missing_ctx_input_prefix`
// (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const ResourceVarsSnapshot& variables,
int missing_ctx_input_prefix);
xla::StatusOr<std::vector<xla::ExecutionInput>> PopulateInputs(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, const Tensor*>& resource_vars,
int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias);
// Given the XLA output in `output`, populate all outputs of `ctx`. Also
// writes out the resource variable updates.
@@ -161,20 +169,16 @@ class XlaComputationLaunchContext {
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
absl::Span<VariableInfo> variable_infos,
const xla::HloInputOutputAliasConfig& input_output_alias,
const ResourceVarsSnapshot& resource_var_snapshots);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
const std::map<int, const Tensor*>& resource_vars);
private:
xla::LocalClient* client_;
se::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
bool use_multiple_streams_;
std::deque<xla::ShapedBuffer> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
int device_ordinal_;
};
// A simple TensorBuffer implementation that allows us to create Tensors that


@@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -403,6 +404,69 @@ class DefFunctionTest(test.TestCase):
    self.assertEqual(inner_retracings, 1)

  def testUpdateVariable(self):
    v = variables.Variable(3.1)

    @def_function.function(experimental_compile=True)
    def update_var(a, b):
      v.assign_add(a * b)

    update_var(constant_op.constant(0.7), constant_op.constant(0.6))
    self.assertAllClose(v, 3.52)

  def testUpdateVariableVector(self):
    v = variables.Variable([3.1, 3.1])

    @def_function.function(experimental_compile=True)
    def update_var(a, b):
      v.assign_add(a * b)

    update_var(
        constant_op.constant([0.7, 0.7]), constant_op.constant([0.6, 0.6]))
    self.assertAllClose(v, [3.52, 3.52])

  def testUpdateVariableInClass(self):

    class C(object):

      @def_function.function(experimental_compile=True)
      def update_var(self, a, b):
        if not hasattr(self, 'v'):
          self.v = variables.Variable(3.1)
        self.v.assign_add(a * b)

    c = C()

    @def_function.function
    def outer():
      c.update_var(constant_op.constant(0.7), constant_op.constant(0.6))

    outer()
    self.assertAllClose(c.v, 3.52)

  def testUpdateVariableMultipleOutputs(self):
    v = variables.Variable(3.1)

    @def_function.function(experimental_compile=True)
    def update_var(a, b):
      v.assign_add(a * b)
      return a * b + v

    out = update_var(constant_op.constant(0.7), constant_op.constant(0.6))
    self.assertAllClose(v, 3.52)
    self.assertAllClose(out, 3.94)

  def testReturnIdentity(self):

    @def_function.function(experimental_compile=True)
    def f(a, b):
      return (a, b)

    a = constant_op.constant([0.7])
    b = constant_op.constant([0.6])

    f(a, b)


if __name__ == '__main__':
  ops.enable_eager_execution()