Rollback of rollback of [TF/XLA] Enable input/output aliasing in the TF2XLA bridge
The underlying bug was fixed.

PiperOrigin-RevId: 321863222
Change-Id: I94c25f3243e33374ee089dd808c3f25704de2c92
parent 01d9e46f28
commit ec25e31817
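This change re-lands input/output buffer aliasing for XLA computations launched through the TF2XLA bridge: when a compiled cluster updates a resource variable, the variable's device buffer can be donated to the computation and reused for the corresponding output, so `assign_add`-style updates happen in place instead of through a copy. Aliases are declared on the XLA module at compile time; below is a minimal sketch of the client-side declaration (illustrative code, not part of this commit; the builder API shown is assumed from the XLA client library of this era):

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Build f(v) = v + 1 and declare that the output may reuse the buffer of
// parameter 0, allowing the runtime to perform the update in place.
xla::XlaComputation BuildAliasedIncrement() {
  xla::XlaBuilder builder("aliased_increment");
  xla::XlaOp v = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "v");
  xla::XlaOp result = xla::Add(v, xla::ConstantR0<float>(&builder, 1.0f));
  // Output index {} aliases parameter 0 at index {}.
  builder.SetUpAlias(/*output_index=*/{}, /*param_number=*/0,
                     /*param_index=*/{});
  return builder.Build(result).ValueOrDie();
}
```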
@@ -277,7 +277,8 @@ static Status CompileToLocalExecutable(
     OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
     const XlaPlatformInfo& platform_info,
     absl::Span<VariableInfo const> variable_infos,
-    absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
+    absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
+    xla::LocalClient** client,
     const XlaCompiler::CompilationResult** compilation_result,
     xla::LocalExecutable** executable) {
   // We store information about the JIT-compiled XLA computation
@@ -332,6 +333,9 @@ static Status CompileToLocalExecutable(
   // Optimization: where possible, have the computation return a naked array
   // rather than a one-element tuple.
   compile_options.always_return_tuple = false;
+  compile_options.alias_resource_update = !has_ref_vars &&
+                                          !platform_info.is_on_xla_device() &&
+                                          may_alias_resource_update;

   std::vector<XlaCompiler::Argument> args;
   TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
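The three conjuncts gate when aliasing is safe: reference variables bypass the resource-variable locking discipline, on-XLA-device execution manages tensor buffers differently, and some call sites must opt out explicitly (see XlaCompileOp below). Restated as a standalone predicate, purely for illustration (a hypothetical helper, not code from the commit):

```cpp
// Hypothetical restatement of the alias_resource_update guard above.
bool ShouldAliasResourceUpdates(bool has_ref_vars, bool is_on_xla_device,
                                bool may_alias_resource_update) {
  // Safe only when the kernel holds exclusive locks on ordinary resource
  // variables for the whole run and the caller has not opted out.
  return !has_ref_vars && !is_on_xla_device && may_alias_resource_update;
}
```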
@@ -350,20 +354,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   const XlaCompiler::CompilationResult* compilation_result;
   xla::LocalExecutable* executable;

-  ResourceVarsSnapshot variables_snapshot;
+  std::vector<VariableInfo> variable_infos;
   {
-    std::vector<VariableInfo> variable_infos;
     OP_REQUIRES_OK(
         ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
     OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
     Status s = CompileToLocalExecutable(
         ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
-        variable_infos, constants_, /*lazy=*/false, &client,
-        &compilation_result, &executable);
+        variable_infos, constants_, /*lazy=*/false,
+        /*may_alias_resource_update=*/true, &client, &compilation_result,
+        &executable);
     OP_REQUIRES_OK(ctx, s);
-    OP_REQUIRES_OK(ctx,
-                   SnapshotResourceVariables(ctx, resources_, variable_infos,
-                                             &variables_snapshot));
+  }
+  std::map<int, const Tensor*> resource_var_ptrs;
+  for (int i = 0; i < resources_.size(); i++) {
+    resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
   }

   se::Stream* stream =
@@ -374,12 +380,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   se::DeviceMemoryAllocator* allocator =
       GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
+  int device_ordinal = stream ? stream->parent()->device_ordinal()
+                              : client->default_device_ordinal();
   XlaComputationLaunchContext launch_context(
-      client, allocator,
+      client, allocator, device_ordinal,
       /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
       platform_info_.UseMultipleStreams());
-  launch_context.PopulateInputs(ctx, compilation_result, variables_snapshot,
-                                /*missing_ctx_input_prefix=*/0);
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      executable->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
+      launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
+                                    /*missing_ctx_input_prefix=*/0,
+                                    input_output_alias);
+  OP_REQUIRES_OK(ctx, execution_inputs.status());

   // Execute the computation.
   VLOG(2) << "Executing computation.";
@@ -403,24 +416,24 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   Env* env = Env::Default();
   auto start_time = env->NowMicros();

-  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
+  xla::StatusOr<xla::ExecutionOutput> execution_output;
   if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
-    run_result = executable->Run(launch_context.arguments(), run_options);
+    execution_output =
+        executable->Run(std::move(*execution_inputs), run_options);
   } else {
-    run_result = executable->RunAsync(launch_context.arguments(), run_options);
+    execution_output =
+        executable->RunAsync(std::move(*execution_inputs), run_options);
   }
-  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

   auto elapsed = env->NowMicros() - start_time;
   VLOG(2) << "Elapsed time: " << elapsed << "us";
+  OP_REQUIRES_OK(
+      ctx, launch_context.PopulateOutputs(
+               ctx, compilation_result, execution_output->ConsumeResult(),
+               /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
+               input_output_alias, resource_var_ptrs));

-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      executable->executable()->module().input_output_alias_config();
-  OP_REQUIRES_OK(ctx,
-                 launch_context.PopulateOutputs(
-                     ctx, compilation_result, run_result.ConsumeValueOrDie(),
-                     /*missing_ctx_input_prefix=*/0, input_output_alias,
-                     variables_snapshot));
   VLOG(1) << "Done";
 }

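The move from `xla::ScopedShapedBuffer` arguments to `xla::ExecutionInput`/`xla::ExecutionOutput` is what carries ownership at run time: each input buffer slot holds an `xla::MaybeOwningDeviceMemory`, and an owning entry tells XLA it may consume that allocation and return it as (part of) the result. A sketch of donating a single buffer, mirroring the `PopulateExecutionInputBuffer` helper added later in this commit (types as used in this commit; treat the standalone function as illustrative):

```cpp
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

// Donate `buffer` as the sole (index {}) buffer of an ExecutionInput.
// Wrapping it as OwningDeviceMemory signals that XLA may recycle the
// allocation for an aliased output; a plain DeviceMemoryBase entry would
// stay caller-owned.
xla::ExecutionInput MakeDonatedInput(const xla::Shape& device_shape,
                                     const xla::Shape& host_shape,
                                     se::DeviceMemoryBase buffer,
                                     int device_ordinal,
                                     se::DeviceMemoryAllocator* allocator) {
  xla::ExecutionInput input(device_shape, host_shape);
  *input.MutableBuffer(xla::ShapeIndex{}) =
      se::OwningDeviceMemory(buffer, device_ordinal, allocator);
  return input;
}
```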
@@ -516,10 +529,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
     OP_REQUIRES_OK(
         ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
     OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
+
+    // Do not alias resource updates as locking variables in XlaCompile and
+    // unlocking them in XlaRun may lead to deadlocks.
     Status status = CompileToLocalExecutable(
         ctx, function_, has_ref_vars_, platform_info_, variable_infos,
         constants_,
-        /*lazy=*/!must_compile_, &client, &kernel, &executable);
+        /*lazy=*/!must_compile_,
+        /*may_alias_resource_update=*/false, &client, &kernel, &executable);
     OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                   variable_infos, &variables));
     if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
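Why XlaCompile opts out: it locks the variables, compiles, snapshots their values, and releases the locks; the paired XlaRun re-acquires locks only when writing outputs. Buffer donation would instead require holding the variable locks from XlaCompile through XlaRun, and holding locks across two ops invites lock-order inversions. A condensed picture of the hazard (a hypothetical interleaving, not code from the commit):

```cpp
// Thread A: XlaCompile(cluster1) locks v1   ... would have to stay locked
// Thread B: XlaCompile(cluster2) locks v2   ... would have to stay locked
// Thread A: XlaRun(cluster1) also needs v2  -> blocks on B
// Thread B: XlaRun(cluster2) also needs v1  -> blocks on A   => deadlock
//
// With may_alias_resource_update=false, XlaCompile unlocks after taking a
// snapshot, and each op acquires its locks in a single, self-contained step.
```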
@@ -587,14 +604,22 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   se::DeviceMemoryAllocator* allocator =
       GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
+  se::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+  int device_ordinal = stream ? stream->parent()->device_ordinal()
+                              : closure.client()->default_device_ordinal();
   XlaComputationLaunchContext launch_context(
-      closure.client(), allocator,
+      closure.client(), allocator, device_ordinal,
       /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
       /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

   // We're missing the must-be-constant inputs, tell `PopulateInputs`
   // about this. We don't actually need these inputs because they've
   // already been baked into the compiled kernel.
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      closure.executable()->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
+  std::map<int, const Tensor*> snapshot_ptrs;
   {
     tensorflow::profiler::TraceMe hlo_module_activity(
         [&] {
@@ -604,13 +629,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
         },
         tensorflow::profiler::TraceMeLevel::kInfo);

-    launch_context.PopulateInputs(
-        ctx, closure.compilation_result(), closure.resource_var_snapshots(),
-        /*missing_ctx_input_prefix=*/closure.num_constant_args());
+    for (auto& p : closure.resource_var_snapshots()) {
+      snapshot_ptrs.emplace(p.first,
+                            p.second.has_value() ? &p.second.value() : nullptr);
+    }
+    execution_inputs = launch_context.PopulateInputs(
+        ctx, closure.compilation_result(), snapshot_ptrs,
+        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
+        input_output_alias);
+    OP_REQUIRES_OK(ctx, execution_inputs.status());
   }

-  se::Stream* stream =
-      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(allocator);
@@ -631,21 +660,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
   Env* env = Env::Default();
   auto start_time = env->NowMicros();

-  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
+  xla::StatusOr<xla::ExecutionOutput> execution_output;
   if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
-    run_result =
-        closure.executable()->Run(launch_context.arguments(), run_options);
+    execution_output =
+        closure.executable()->Run(std::move(*execution_inputs), run_options);
   } else {
-    run_result =
-        closure.executable()->RunAsync(launch_context.arguments(), run_options);
+    execution_output = closure.executable()->RunAsync(
+        std::move(*execution_inputs), run_options);
   }
-  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

   auto elapsed = env->NowMicros() - start_time;
   VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      closure.executable()->executable()->module().input_output_alias_config();

   tensorflow::profiler::TraceMe hlo_module_activity(
       [&] {
@@ -653,12 +680,16 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
       },
       tensorflow::profiler::TraceMeLevel::kInfo);

+  xla::StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
+      ctx, *closure.compilation_result(), closure.num_constant_args());
+  OP_REQUIRES_OK(ctx, variable_infos.status());
+  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
   OP_REQUIRES_OK(
       ctx,
       launch_context.PopulateOutputs(
-          ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
+          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
           /*missing_ctx_input_prefix=*/closure.num_constant_args(),
-          input_output_alias, closure.resource_var_snapshots()));
+          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
 }

 XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
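Two structural moves in XlaRunOp support this: the stream is now looked up before the launch context is constructed, because donated buffers must be tagged with the stream's device ordinal, and resource variables are gathered and locked only around PopulateOutputs rather than being carried over from XlaCompile, keeping each op's locking self-contained.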
@@ -50,35 +50,47 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
       client, client->backend().memory_allocator(),
+      client->default_device_ordinal(),
       /*allocate_xla_tensors=*/true,
       /*use_multiple_streams=*/metadata.UseMultipleStreams());

-  launch_context.PopulateInputs(ctx, result, variable_args,
-                                /*missing_ctx_input_prefix=*/0);
+  std::map<int, const Tensor*> snapshot_ptrs;
+  for (auto& p : variable_args) {
+    snapshot_ptrs.emplace(p.first,
+                          p.second.has_value() ? &p.second.value() : nullptr);
+  }
+
+  const xla::HloInputOutputAliasConfig& input_output_alias =
+      executable->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
+      launch_context.PopulateInputs(ctx, result, snapshot_ptrs,
+                                    /*missing_ctx_input_prefix=*/0,
+                                    input_output_alias);
+  TF_RETURN_IF_ERROR(execution_inputs.status());

   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   TF_RET_CHECK(stream);

   VLOG(2) << "Executing computation: " << name();
-  for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
-    VLOG(2) << name() << ": " << *arg;
-  }
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(client->backend().memory_allocator());
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());

-  xla::StatusOr<xla::ScopedShapedBuffer> run_result =
-      executable->Run(launch_context.arguments(), run_options);
+  xla::StatusOr<xla::ExecutionOutput> run_result =
+      executable->Run(execution_inputs.ConsumeValueOrDie(), run_options);
   TF_RETURN_IF_ERROR(run_result.status());
+  xla::ExecutionOutput execution_output = run_result.ConsumeValueOrDie();

-  const xla::HloInputOutputAliasConfig& input_output_alias =
-      executable->executable()->module().input_output_alias_config();
+  xla::StatusOr<std::vector<VariableInfo>> variable_infos =
+      GatherVariableInfo(ctx, *result, 0);
+  TF_RETURN_IF_ERROR(variable_infos.status());
+  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(*variable_infos)));
   TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
-      ctx, result, run_result.ConsumeValueOrDie(),
-      /*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
+      ctx, result, execution_output.ConsumeResult(),
+      /*missing_ctx_input_prefix=*/0, absl::MakeSpan(*variable_infos),
+      input_output_alias, snapshot_ptrs));
   return Status::OK();
 }

@@ -59,11 +59,13 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
         return Status::OK();
       }));
   mutex_lock ml(*variable->mu());
-  OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
-              errors::InvalidArgument(
-                  "Trying to assign variable with wrong dtype. Expected ",
-                  DataTypeString(variable->tensor()->dtype()), " got ",
-                  DataTypeString(dtype_)));
+  OP_REQUIRES(
+      context,
+      !variable->is_initialized || variable->tensor()->dtype() == dtype_,
+      errors::InvalidArgument(
+          "Trying to assign variable with wrong dtype. Expected ",
+          DataTypeString(variable->tensor()->dtype()), " got ",
+          DataTypeString(dtype_)));
   variable->is_initialized = true;
   *variable->tensor() = value;
 }
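The dtype check is relaxed because variables can now reach this op uninitialized: GetVariableInfosFromCtxInputs (next hunk) creates placeholder variables with dtype DT_INVALID, so only an initialized variable can meaningfully conflict with the assigned value. The relaxed condition restated as a hypothetical helper (assuming the surrounding TensorFlow headers and namespaces):

```cpp
// Only an initialized variable has a dtype worth enforcing; the first
// assignment to an uninitialized Var defines its dtype.
Status CheckAssignDtype(Var* variable, DataType assigned_dtype) {
  if (variable->is_initialized &&
      variable->tensor()->dtype() != assigned_dtype) {
    return errors::InvalidArgument(
        "Trying to assign variable with wrong dtype. Expected ",
        DataTypeString(variable->tensor()->dtype()), " got ",
        DataTypeString(assigned_dtype));
  }
  return Status::OK();
}
```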
@@ -91,29 +91,19 @@ VariableInfo::~VariableInfo() {
 Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
                                      absl::Span<const int> variable_indices,
                                      std::vector<VariableInfo>* result) {
-  std::vector<const ResourceHandle*> resource_handles;
-  absl::c_transform(
-      variable_indices, std::back_inserter(resource_handles),
-      [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
-
-  std::vector<core::RefCountPtr<Var>> variables;
-  Status s = LookupResources(ctx, resource_handles, &variables);
-  if (!s.ok()) {
-    errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
-    return s;
-  }
-
   result->clear();
   result->reserve(variable_indices.size());
-  for (int i = 0; i < variable_indices.size(); i++) {
-    // *Release* the variable because we're going to unref it later in
-    // ~VariableInfo.
-    Var* variable = variables[i].release();
-    int input_idx = variable_indices[i];
-    std::string var_name = HandleFromInput(ctx, input_idx).name();
-    result->emplace_back(input_idx, var_name, variable);
+  for (int var_idx : variable_indices) {
+    Var* variable = nullptr;
+    ResourceHandle handle = HandleFromInput(ctx, var_idx);
+    TF_RETURN_IF_ERROR(
+        LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
+          // This var is uninitialized for now.
+          *ptr = new Var(DT_INVALID);
+          return Status::OK();
+        }));
+    result->emplace_back(var_idx, handle.name(), variable);
   }

   return Status::OK();
 }

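Switching from LookupResources to LookupOrCreateResource changes the failure mode: a handle whose variable has never been assigned no longer aborts the lookup; it yields a placeholder Var(DT_INVALID) with is_initialized false, and downstream code (BuildXlaCompilerArguments, XlaAssignVariableOp) is what now distinguishes initialized from uninitialized variables. A hypothetical caller, assuming the surrounding TensorFlow headers:

```cpp
// Looking up input 0's variable now succeeds even before first assignment.
Status DescribeFirstVariable(OpKernelContext* ctx) {
  std::vector<VariableInfo> infos;
  TF_RETURN_IF_ERROR(GetVariableInfosFromCtxInputs(ctx, {0}, &infos));
  Var* var = infos[0].var();
  VLOG(1) << "initialized=" << var->is_initialized
          << " dtype=" << DataTypeString(var->tensor()->dtype());
  return Status::OK();
}
```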
@@ -176,24 +166,43 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,

 XlaComputationLaunchContext::XlaComputationLaunchContext(
     xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
-    bool allocate_xla_tensors, bool use_multiple_streams)
+    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
     : client_(client),
       xla_allocator_(xla_allocator),
       allocate_xla_tensors_(allocate_xla_tensors),
-      use_multiple_streams_(use_multiple_streams) {
+      use_multiple_streams_(use_multiple_streams),
+      device_ordinal_(device_ordinal) {
   if (use_multiple_streams_) {
     CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                     "be allocating XLA tensors!";
   }
 }

-void XlaComputationLaunchContext::PopulateInputs(
+// Fills in `execution_input` with `buffer` for `index`.
+static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
+                                         xla::ShapeIndex index,
+                                         se::DeviceMemoryBase& buffer,
+                                         bool donate_buffer, int device_ordinal,
+                                         se::DeviceMemoryAllocator* allocator) {
+  xla::MaybeOwningDeviceMemory* in_buffer =
+      execution_input.MutableBuffer(index);
+  if (donate_buffer) {
+    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
+    buffer = se::DeviceMemoryBase();
+  } else {
+    *in_buffer = buffer;
+  }
+}
+
+xla::StatusOr<std::vector<xla::ExecutionInput>>
+XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx,
     const XlaCompiler::CompilationResult* compilation_result,
-    const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
-  // Build ShapedBuffers that point directly to the Tensor buffers.
-  arg_ptrs_ =
-      std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
+    const std::map<int, const Tensor*>& resource_vars,
+    int missing_ctx_input_prefix,
+    const xla::HloInputOutputAliasConfig& input_output_alias) {
+  std::vector<xla::ExecutionInput> arguments;
+  arguments.reserve(compilation_result->xla_input_shapes.size());

   xla::TransferManager* transfer_manager =
       client_->backend().transfer_manager();
@@ -201,10 +210,28 @@ void XlaComputationLaunchContext::PopulateInputs(
     int arg_num = compilation_result->input_mapping[i];
     CHECK_GE(arg_num, missing_ctx_input_prefix);
     const xla::Shape& shape = compilation_result->xla_input_shapes[i];
-    const Tensor* t = variables.count(arg_num)
-                          ? &(variables.at(arg_num).value())
+    const xla::Shape& device_shape =
+        transfer_manager->HostShapeToDeviceShape(shape);
+
+    bool is_resource_variable = resource_vars.count(arg_num);
+    bool is_updated_resource_variable =
+        is_resource_variable &&
+        absl::c_any_of(compilation_result->resource_updates,
+                       [&](const XlaCompiler::ResourceUpdate& update) {
+                         return update.input_index == i && update.modified;
+                       });
+
+    const Tensor* t = is_resource_variable
+                          ? resource_vars.at(arg_num)
                           : &(ctx->input(arg_num - missing_ctx_input_prefix));
     CHECK(t);
+    bool donate_buffer =
+        t->RefCountIsOne() && is_updated_resource_variable &&
+        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
+    VLOG(3) << "Processing input: " << i
+            << "; is_resource_variable=" << is_resource_variable
+            << "; is_updated_resource_variable=" << is_updated_resource_variable
+            << "; donate_buffer=" << donate_buffer;
+
     if (use_multiple_streams_) {
       CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
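Donation is deliberately conservative; all three conditions must hold, since donating a buffer that another Tensor still references would let the computation overwrite live data. The predicate, pulled out for readability (a hypothetical helper mirroring the inline expression above, assuming the surrounding TensorFlow headers):

```cpp
// i is the position of this argument in xla_input_shapes/input_mapping.
bool CanDonateInputBuffer(const Tensor& t, bool is_updated_resource_variable,
                          const xla::HloInputOutputAliasConfig& aliases,
                          int i) {
  return t.RefCountIsOne() &&             // no other Tensor shares the buffer
         is_updated_resource_variable &&  // the cluster writes this variable
         aliases.ParameterHasAlias(i, xla::ShapeIndex{});  // alias declared
}
```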
@@ -215,23 +242,28 @@ void XlaComputationLaunchContext::PopulateInputs(
                ctx->op_device_context()->stream());
     }

-    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
-            shape, transfer_manager->HostShapeToDeviceShape(shape))) {
+    arguments.emplace_back(device_shape, shape);
+    xla::ExecutionInput& execution_input = arguments.back();
+    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
       se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
-      arg_buffers_.emplace_back(
-          /*on_host_shape=*/shape, /*on_device_shape=*/shape,
-          client_->platform(), client_->default_device_ordinal());
-      arg_buffers_.back().set_buffer(dmem, /*index=*/{});
-      arg_ptrs_[i] = &arg_buffers_.back();
+      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
+                                   donate_buffer, device_ordinal_,
+                                   xla_allocator_);
     } else {
-      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
       CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
-      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
+      xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
+          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+            PopulateExecutionInputBuffer(execution_input, index, *buffer,
+                                         donate_buffer, device_ordinal_,
+                                         xla_allocator_);
+          });
     }
   }
+  return std::move(arguments);
 }

-// Construct the tensor for given type and buffer.
+// Construct the tensor for the given type and buffer.
 static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                          se::DeviceMemoryBase buffer, Allocator* allocator) {
   size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
@@ -247,28 +279,26 @@ static Tensor GetOrCreateTensorForOutput(
     int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
     const xla::HloInputOutputAliasConfig& input_output_alias,
     absl::Span<const int> input_mapping,
-    const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
-    const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
-    Allocator* output_allocator) {
+    const std::map<int, const Tensor*>& resource_vars_snapshots,
+    DataType output_dtype, const TensorShape& output_shape,
+    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
   xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                      ? xla::ShapeIndex({output_num})
                                      : xla::ShapeIndex({});

   CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
   if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
           input_output_alias.GetAliasedParameter(output_index)) {
+    VLOG(3) << "Found alias: " << alias->ToString();
     int tf_param =
         input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
-    const Tensor* input_tensor = &ctx->input(tf_param);
-
-    // If input tensor is a resource variable, alias to the snapshot we took at
-    // entry time.
-    if (input_tensor->dtype() == DT_RESOURCE) {
-      const absl::optional<Tensor>& v =
-          resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
-      CHECK(v.has_value());
-      return *v;
+    const Tensor input_tensor =
+        ctx->input(tf_param).dtype() != DT_RESOURCE
+            ? ctx->input(tf_param)
+            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
+    if (output_buffer.opaque() == input_tensor.data()) {
+      return input_tensor;
     }
-    return *input_tensor;
   }
   return MakeTensor(output_dtype, output_shape, output_buffer,
                     output_allocator);
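On the output side, aliasing is detected by address identity: if the buffer XLA returned for an output is the same allocation as the input tensor's data, GetOrCreateTensorForOutput hands back the original Tensor instead of wrapping the buffer a second time, keeping reference counts and allocator bookkeeping consistent. Note that the snapshot map can hold a nullptr for an uninitialized variable; the dereference on the aliased path assumes a snapshot exists for any parameter the module declared as aliased.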
@@ -291,12 +321,10 @@ static Status SetOutputForConstant(
     OpKernelContext* ctx, se::Stream* stream,
     const XlaCompiler::CompilationResult* compilation_result, int output_num) {
   CHECK(compilation_result->outputs[output_num].is_constant);
-  // Output is a constant.
   const Tensor& const_tensor =
       compilation_result->outputs[output_num].constant_value;
   Tensor* output_tensor;
-  const size_t total_bytes = const_tensor.TotalBytes();
-  if (stream && total_bytes > 0) {
+  if (stream && const_tensor.TotalBytes() > 0) {
     // Copy host -> device. (Empty tensors don't have backing buffers.)
     // Manually allocate memory using an XlaTensorBuffer so we can allocate
     // as much memory as the device requires (as given by
@@ -335,52 +363,55 @@ static Status SetOutputForConstant(
   return Status::OK();
 }

-// Creates a list of updates resource variables.
-static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
-    OpKernelContext* ctx,
-    const XlaCompiler::CompilationResult* compilation_result,
-    int missing_ctx_input_prefix) {
-  std::vector<VariableInfo> variable_infos;
-  variable_infos.reserve(compilation_result->resource_updates.size());
+static xla::StatusOr<Var*> GetOrCreateResourceVar(
+    OpKernelContext* ctx, const ResourceHandle& handle,
+    const XlaCompiler::ResourceUpdate& write) {
+  Var* variable = nullptr;
+  TF_RETURN_IF_ERROR(
+      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
+        *ptr = new Var(write.type);
+        return Status::OK();
+      }));
+  return variable;
+}

-  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
+xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult& compilation_result,
+    int missing_ctx_input_prefix) {
+  std::vector<VariableInfo> out;
+  out.reserve(compilation_result.resource_updates.size());
+  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
     const XlaCompiler::ResourceUpdate& write =
-        compilation_result->resource_updates[i];
+        compilation_result.resource_updates[i];
     int actual_input_index = write.input_index - missing_ctx_input_prefix;
     if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
       return errors::Internal("Invalid input index for variable write.");
     }
-
-    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
-    // not a Tensor.
-    Var* variable = nullptr;
     const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
-    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
-                                                   [&write](Var** ptr) {
-                                                     *ptr = new Var(write.type);
-                                                     return Status::OK();
-                                                   }));
-    variable_infos.emplace_back(actual_input_index, handle.name(), variable);
+    TF_ASSIGN_OR_RETURN(Var * variable,
+                        GetOrCreateResourceVar(ctx, handle, write));
+    out.emplace_back(actual_input_index, handle.name(), variable);
   }
-  return variable_infos;
+  return std::move(out);
 }

 Status XlaComputationLaunchContext::PopulateOutputs(
     OpKernelContext* ctx,
     const XlaCompiler::CompilationResult* compilation_result,
     ScopedShapedBuffer output, int missing_ctx_input_prefix,
+    absl::Span<VariableInfo> variable_infos,
     const xla::HloInputOutputAliasConfig& input_output_alias,
-    const ResourceVarsSnapshot& resource_var_snapshots) {
+    const std::map<int, const Tensor*>& resource_vars) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   Allocator* allocator = ctx->device()->GetAllocator({});

   // Computation output should always be a tuple.
-  if (VLOG_IS_ON(2)) {
-    VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
-    VLOG(2) << "Result tuple shape (on device): "
-            << output.on_device_shape().DebugString();
-  }
+  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
+  VLOG(2) << "Result tuple shape (on device): "
+          << output.on_device_shape().DebugString();
   CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

   // If the on-host-shape isn't a tuple, create a new single-element tuple
@@ -438,8 +469,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
   for (int i = 0; i < ctx->num_outputs(); ++i) {
     const TensorShape& shape = output_tensor_shapes[i];
     const DataType& type = compilation_result->outputs[i].type;
-    VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
-            << DataTypeString(type);
+    VLOG(2) << "Populating output for retval " << i << " shape "
+            << shape.DebugString() << " type " << DataTypeString(type);
     if (type == DT_VARIANT) {
       return errors::Unimplemented(
           "Support for TensorList crossing the XLA/TF boundary "
@@ -467,30 +498,37 @@ Status XlaComputationLaunchContext::PopulateOutputs(
       se::DeviceMemoryBase buffer = output.buffer({output_num});
       Tensor output_tensor = GetOrCreateTensorForOutput(
           output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-          compilation_result->input_mapping, resource_var_snapshots,
+          compilation_result->input_mapping, resource_vars,
          ctx->expected_output_dtype(i), shape, buffer, allocator);
-      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      ctx->set_output(i, output_tensor);
     }
+    output.set_buffer(se::OwningDeviceMemory(), {output_num});
     ++output_num;
   }

-    if (VLOG_IS_ON(3)) {
-      VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
-    }
   }

-  // Apply variable updates, if any.
-  VLOG(2) << "Applying variable updates";
-  TF_ASSIGN_OR_RETURN(
-      std::vector<VariableInfo> variable_infos,
-      GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
-  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
+  // input_index -> index into variable_infos.
+  absl::flat_hash_map<int, int> variable_info_lookup;
+  for (int i = 0; i < variable_infos.size(); i++) {
+    variable_info_lookup.emplace(variable_infos[i].index(), i);
+  }

+  // Apply variable updates, if any.
   for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
     const XlaCompiler::ResourceUpdate& write =
         compilation_result->resource_updates[i];
-    if (variable_infos[i].var()->tensor()->dtype() != write.type) {
+    int actual_input_index = write.input_index - missing_ctx_input_prefix;
+    CHECK_GE(actual_input_index, 0);
+    CHECK_LT(actual_input_index, ctx->num_inputs());
+    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
+    CHECK(var);
+
+    VLOG(2) << "Updating variable #" << i
+            << " at input index: " << actual_input_index << " with shape "
+            << write.shape.DebugString() << "; variable tensor has shape: "
+            << var->tensor()->shape().DebugString();
+
+    if (var->is_initialized && var->tensor()->dtype() != write.type) {
       return errors::Internal("Mismatched type in variable write");
     }

@@ -504,14 +542,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
       }
     } else {
       se::DeviceMemoryBase buffer = output.buffer({output_num});
-      output.set_buffer(se::OwningDeviceMemory(), {output_num});
       output_tensor = GetOrCreateTensorForOutput(
           output_num, ctx, missing_ctx_input_prefix, input_output_alias,
-          compilation_result->input_mapping, resource_var_snapshots, write.type,
+          compilation_result->input_mapping, resource_vars, write.type,
           write.shape, buffer, allocator);
     }
-    *variable_infos[i].var()->tensor() = output_tensor;
-    variable_infos[i].var()->is_initialized |= write.modified;
+    output.set_buffer(se::OwningDeviceMemory(), {output_num});
+    var->is_initialized |= write.modified;
+    *var->tensor() = output_tensor;
     ++output_num;
   }
   return Status::OK();
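PopulateOutputs no longer gathers and locks variables itself; callers pass in an already-locked `absl::Span<VariableInfo>`. Because that span is keyed by TF input index while resource updates are ordered by the compiler, the `variable_info_lookup` hash map translates `write.input_index` into a span position, and the initialized-dtype check matches the relaxed semantics introduced in XlaAssignVariableOp above.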
@@ -562,7 +600,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
       arg.name = std::string(variable.name());
       arg.kind = XlaCompiler::Argument::kResource;
       arg.resource_kind = XlaResource::kVariable;
-      if (variable.var()) {
+      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
@@ -81,6 +81,12 @@ class VariableInfo {
   bool lock_held_ = false;
 };

+// Creates a list of updated resource variables.
+xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
+    OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult& compilation_result,
+    int missing_ctx_input_prefix);
+
 // Takes a snapshot of the values of resource variable arguments, whose indices
 // are specified in `variable_indices` argument. We snapshot tensors that back
 // resource variables since concurrent updates may modify the shape, and it is
@@ -124,7 +130,7 @@ class XlaComputationLaunchContext {
   // objects.
   XlaComputationLaunchContext(xla::LocalClient* client,
                               se::DeviceMemoryAllocator* xla_allocator,
-                              bool allocate_xla_tensors,
+                              int device_ordinal, bool allocate_xla_tensors,
                               bool use_multiple_streams);

   // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
@@ -142,10 +148,12 @@ class XlaComputationLaunchContext {
   // missing and adjusts input indices accordingly. All elements in kernel's
   // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
   // (in other words, no inputs actually required by the kernel can be missing).
-  void PopulateInputs(OpKernelContext* ctx,
-                      const XlaCompiler::CompilationResult* compilation_result,
-                      const ResourceVarsSnapshot& variables,
-                      int missing_ctx_input_prefix);
+  xla::StatusOr<std::vector<xla::ExecutionInput>> PopulateInputs(
+      OpKernelContext* ctx,
+      const XlaCompiler::CompilationResult* compilation_result,
+      const std::map<int, const Tensor*>& resource_vars,
+      int missing_ctx_input_prefix,
+      const xla::HloInputOutputAliasConfig& input_output_alias);

   // Given the XLA output in `output`, populate all outputs of `ctx`. Also
   // writes out the resource variable updates.
@@ -161,20 +169,16 @@ class XlaComputationLaunchContext {
       OpKernelContext* ctx,
       const XlaCompiler::CompilationResult* compilation_result,
       xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
+      absl::Span<VariableInfo> variable_infos,
       const xla::HloInputOutputAliasConfig& input_output_alias,
-      const ResourceVarsSnapshot& resource_var_snapshots);
+      const std::map<int, const Tensor*>& resource_vars);

-  // Return the argument list. Only valid after PopulateInputs() has been
-  // called.
-  const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }

  private:
   xla::LocalClient* client_;
   se::DeviceMemoryAllocator* xla_allocator_;
   bool allocate_xla_tensors_;
   bool use_multiple_streams_;
-  std::deque<xla::ShapedBuffer> arg_buffers_;
-  std::vector<xla::ShapedBuffer*> arg_ptrs_;
+  int device_ordinal_;
 };

 // A simple TensorBuffer implementation that allows us to create Tensors that
@@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -403,6 +404,69 @@ class DefFunctionTest(test.TestCase):

     self.assertEqual(inner_retracings, 1)

+  def testUpdateVariable(self):
+    v = variables.Variable(3.1)
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+
+    update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+    self.assertAllClose(v, 3.52)
+
+  def testUpdateVariableVector(self):
+    v = variables.Variable([3.1, 3.1])
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+
+    update_var(
+        constant_op.constant([0.7, 0.7]), constant_op.constant([0.6, 0.6]))
+    self.assertAllClose(v, [3.52, 3.52])
+
+  def testUpdateVariableInClass(self):
+
+    class C(object):
+
+      @def_function.function(experimental_compile=True)
+      def update_var(self, a, b):
+        if not hasattr(self, 'v'):
+          self.v = variables.Variable(3.1)
+        self.v.assign_add(a * b)
+
+    c = C()
+
+    @def_function.function
+    def outer():
+      c.update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+
+    outer()
+    self.assertAllClose(c.v, 3.52)
+
+  def testUpdateVariableMultipleOutputs(self):
+    v = variables.Variable(3.1)
+
+    @def_function.function(experimental_compile=True)
+    def update_var(a, b):
+      v.assign_add(a * b)
+      return a * b + v
+
+    out = update_var(constant_op.constant(0.7), constant_op.constant(0.6))
+    self.assertAllClose(v, 3.52)
+    self.assertAllClose(out, 3.94)
+
+  def testReturnIdentity(self):
+
+    @def_function.function(experimental_compile=True)
+    def f(a, b):
+      return (a, b)
+
+    a = constant_op.constant([0.7])
+    b = constant_op.constant([0.6])
+
+    f(a, b)


 if __name__ == '__main__':
   ops.enable_eager_execution()
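The new tests pin down the aliasing semantics end to end: 3.1 + 0.7 * 0.6 = 3.52 checks a scalar in-place update, the vector and class variants check the same behavior through different tracing paths, and testUpdateVariableMultipleOutputs checks that a returned value observes the updated variable (out = 0.7 * 0.6 + 3.52 = 3.94). testReturnIdentity exercises pass-through outputs, where an output is simply an input and must not be confused with a variable update.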