[TF2XLA] [NFC] Refactor datastructures for resource variables to not require snapshotting for compilation
Previously, `BuildXlaCompilerArguments` required taking a snapshot of all resource variables in order to start compiling. In this CL it can operate with a span of pointers to resource variables instead (we CHECK at runtime that the lock is held). That refactoring allows to launch the XLA compilation without creating an extra reference to the underlying Tensor of the passed resource variables. PiperOrigin-RevId: 317706126 Change-Id: I37a97601a08f165b23b4745e1b032bf91c21c313
This commit is contained in:
parent
0a946a0069
commit
9ca89a201c
tensorflow/compiler/jit
@ -108,8 +108,7 @@ class XlaExecutableClosure {
|
||||
explicit XlaExecutableClosure(
|
||||
xla::LocalClient* client, xla::LocalExecutable* executable,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
std::map<int, OptionalTensor> resource_var_snapshots,
|
||||
int num_constant_args)
|
||||
ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
|
||||
: client_(client),
|
||||
executable_(executable),
|
||||
compilation_result_(compilation_result),
|
||||
@ -124,7 +123,7 @@ class XlaExecutableClosure {
|
||||
const XlaCompiler::CompilationResult* compilation_result() const {
|
||||
return compilation_result_;
|
||||
}
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots() const {
|
||||
const ResourceVarsSnapshot& resource_var_snapshots() const {
|
||||
return resource_var_snapshots_;
|
||||
}
|
||||
int num_constant_args() const { return num_constant_args_; }
|
||||
@ -133,7 +132,7 @@ class XlaExecutableClosure {
|
||||
xla::LocalClient* client_;
|
||||
xla::LocalExecutable* executable_;
|
||||
const XlaCompiler::CompilationResult* compilation_result_;
|
||||
std::map<int, OptionalTensor> resource_var_snapshots_;
|
||||
ResourceVarsSnapshot resource_var_snapshots_;
|
||||
int num_constant_args_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
|
||||
@ -276,10 +275,10 @@ static Status BuildCompilationCache(OpKernelContext* ctx,
|
||||
|
||||
static Status CompileToLocalExecutable(
|
||||
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
|
||||
const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
|
||||
const XlaPlatformInfo& platform_info,
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
|
||||
std::map<int, OptionalTensor>* variables,
|
||||
const XlaCompiler::CompilationResult** kernel,
|
||||
const XlaCompiler::CompilationResult** compilation_result,
|
||||
xla::LocalExecutable** executable) {
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
@ -299,7 +298,6 @@ static Status CompileToLocalExecutable(
|
||||
// this is more obviously correct.)
|
||||
core::ScopedUnref cache_ref(cache);
|
||||
|
||||
TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
|
||||
*client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
@ -337,11 +335,11 @@ static Status CompileToLocalExecutable(
|
||||
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_args, *variables, ctx, &args));
|
||||
constant_args, variable_infos, ctx, &args));
|
||||
return cache->Compile(options, function, args, compile_options,
|
||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||
: XlaCompilationCache::CompileMode::kStrict,
|
||||
kernel, executable);
|
||||
compilation_result, executable);
|
||||
}
|
||||
|
||||
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
@ -349,16 +347,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
|
||||
|
||||
xla::LocalClient* client;
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
const XlaCompiler::CompilationResult* compilation_result;
|
||||
xla::LocalExecutable* executable;
|
||||
std::map<int, OptionalTensor> variables;
|
||||
|
||||
ResourceVarsSnapshot variables;
|
||||
{
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
|
||||
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
|
||||
Status s = CompileToLocalExecutable(
|
||||
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
|
||||
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
|
||||
&executable);
|
||||
variable_infos, constants_, /*lazy=*/false, &client,
|
||||
&compilation_result, &executable);
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
|
||||
variable_infos, &variables));
|
||||
}
|
||||
|
||||
se::Stream* stream =
|
||||
@ -373,7 +377,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
client, allocator,
|
||||
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
|
||||
platform_info_.UseMultipleStreams());
|
||||
launch_context.PopulateInputs(ctx, kernel, variables,
|
||||
launch_context.PopulateInputs(ctx, compilation_result, variables,
|
||||
/*missing_ctx_input_prefix=*/0);
|
||||
|
||||
// Execute the computation.
|
||||
@ -413,7 +417,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
OP_REQUIRES_OK(
|
||||
ctx, launch_context.PopulateOutputs(
|
||||
ctx, kernel, run_result.ConsumeValueOrDie(),
|
||||
ctx, compilation_result, run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
|
||||
VLOG(1) << "Done";
|
||||
}
|
||||
@ -494,7 +498,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
xla::LocalClient* client;
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
xla::LocalExecutable* executable;
|
||||
std::map<int, OptionalTensor> variables;
|
||||
ResourceVarsSnapshot variables;
|
||||
|
||||
bool cannot_compile_cluster;
|
||||
{
|
||||
@ -506,9 +510,16 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
cannot_compile_cluster) {
|
||||
executable = nullptr;
|
||||
} else {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
|
||||
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
|
||||
Status status = CompileToLocalExecutable(
|
||||
ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
|
||||
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
|
||||
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
|
||||
constants_,
|
||||
/*lazy=*/!must_compile_, &client, &kernel, &executable);
|
||||
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
|
||||
variable_infos, &variables));
|
||||
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
}
|
||||
|
@ -28,32 +28,23 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
|
||||
std::map<int, OptionalTensor> variables;
|
||||
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
||||
// Returns argument indices corresponding to the resource variable inputs of
|
||||
// kernel context `ctx`.
|
||||
static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
|
||||
std::vector<int> out;
|
||||
for (int64 i = 0; i < ctx->num_inputs(); i++) {
|
||||
if (ctx->input(i).dtype() == DT_RESOURCE) {
|
||||
core::RefCountPtr<Var> variable;
|
||||
ResourceHandle handle = HandleFromInput(ctx, i);
|
||||
OptionalTensor& optional = variables[i];
|
||||
optional.name = handle.name();
|
||||
if (LookupResource(ctx, handle, &variable).ok()) {
|
||||
tf_shared_lock lock(*variable->mu());
|
||||
optional.present = true;
|
||||
optional.value = *variable->tensor();
|
||||
out.push_back(i);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
return variables;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
const XlaDevice::Metadata& metadata,
|
||||
const XlaCompiler::CompilationResult* result,
|
||||
xla::LocalExecutable* executable) {
|
||||
std::map<int, OptionalTensor> variables = GetVariables(ctx);
|
||||
|
||||
xla::LocalExecutable* executable,
|
||||
const ResourceVarsSnapshot& variable_args) {
|
||||
xla::LocalClient* client = metadata.client();
|
||||
|
||||
// Builds an XLA allocator for the device.
|
||||
@ -62,7 +53,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
/*allocate_xla_tensors=*/true,
|
||||
/*use_multiple_streams=*/metadata.UseMultipleStreams());
|
||||
|
||||
launch_context.PopulateInputs(ctx, result, variables,
|
||||
launch_context.PopulateInputs(ctx, result, variable_args,
|
||||
/*missing_ctx_input_prefix=*/0);
|
||||
|
||||
se::Stream* stream =
|
||||
@ -87,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
|
||||
ctx, result, run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -115,7 +106,7 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
|
||||
Status XlaCompileOnDemandOp::Compile(
|
||||
OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
|
||||
const XlaCompiler::CompilationResult** result,
|
||||
xla::LocalExecutable** executable) {
|
||||
ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
|
||||
std::map<int, Tensor> constant_arguments;
|
||||
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
||||
const Tensor& device_tensor = ctx->input(i);
|
||||
@ -190,12 +181,18 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
// rather than a one-element tuple.
|
||||
compile_options.always_return_tuple = false;
|
||||
|
||||
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
|
||||
|
||||
std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
|
||||
{
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
TF_RETURN_IF_ERROR(SnapshotResourceVariables(
|
||||
ctx, variables_indices, variable_infos, variable_args));
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arguments, variable_args, ctx, &args));
|
||||
constant_arguments, variable_infos, ctx, &args));
|
||||
}
|
||||
|
||||
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
|
||||
executable);
|
||||
@ -206,8 +203,10 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
|
||||
xla::LocalExecutable* executable;
|
||||
const XlaDevice::Metadata* metadata;
|
||||
OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
|
||||
OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
|
||||
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
|
||||
ResourceVarsSnapshot variable_args;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
Compile(ctx, *metadata, &result, &variable_args, &executable));
|
||||
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -47,10 +48,12 @@ class XlaCompileOnDemandOp : public OpKernel {
|
||||
bool* result);
|
||||
Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
|
||||
const XlaCompiler::CompilationResult** result,
|
||||
ResourceVarsSnapshot* variable_args,
|
||||
xla::LocalExecutable** executable);
|
||||
Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
|
||||
const XlaCompiler::CompilationResult* result,
|
||||
xla::LocalExecutable* executable);
|
||||
xla::LocalExecutable* executable,
|
||||
const ResourceVarsSnapshot& variable_args);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -52,7 +52,8 @@ const char kPossibleNonVariableResourceHintMessage[] =
|
||||
"resource inputs to XLA.";
|
||||
} // anonymous namespace
|
||||
|
||||
VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}
|
||||
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
|
||||
: index_(index), name_(name), var_(var) {}
|
||||
VariableInfo::VariableInfo(VariableInfo&& other)
|
||||
: index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
|
||||
other.index_ = -1;
|
||||
@ -87,8 +88,8 @@ VariableInfo::~VariableInfo() {
|
||||
// Returns a vector of VariableInfo instances for the resource variable inputs
|
||||
// to the kernel with context `ctx`. The input indices for the resource
|
||||
// variable inputs are in `variable_indices`.
|
||||
static Status GetVariableInfosFromCtxInputs(
|
||||
OpKernelContext* ctx, absl::Span<const int> variable_indices,
|
||||
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result) {
|
||||
std::vector<const ResourceHandle*> resource_handles;
|
||||
absl::c_transform(
|
||||
@ -96,7 +97,6 @@ static Status GetVariableInfosFromCtxInputs(
|
||||
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
|
||||
|
||||
std::vector<core::RefCountPtr<Var>> variables;
|
||||
|
||||
Status s = LookupResources(ctx, resource_handles, &variables);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
|
||||
@ -109,7 +109,9 @@ static Status GetVariableInfosFromCtxInputs(
|
||||
// *Release* the variable because we're going to unref it later in
|
||||
// ~VariableInfo.
|
||||
Var* variable = variables[i].release();
|
||||
result->emplace_back(variable_indices[i], variable);
|
||||
int input_idx = variable_indices[i];
|
||||
std::string var_name = HandleFromInput(ctx, input_idx).name();
|
||||
result->emplace_back(input_idx, var_name, variable);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -162,21 +164,12 @@ Status LockVariables(absl::Span<VariableInfo> variables) {
|
||||
|
||||
Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::map<int, OptionalTensor>* result) {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
ResourceVarsSnapshot* result) {
|
||||
for (int i = 0; i < variable_indices.size(); i++) {
|
||||
if (variable_infos[i].var()) {
|
||||
OptionalTensor& tensor = (*result)[variable_indices[i]];
|
||||
tensor.name = HandleFromInput(ctx, variable_indices[i]).name();
|
||||
tensor.present = true;
|
||||
tensor.value = *variable_infos[i].var()->tensor();
|
||||
} else {
|
||||
(*result)[variable_indices[i]] = OptionalTensor();
|
||||
}
|
||||
Var* var = variable_infos[i].var();
|
||||
(*result)[variable_indices[i]] =
|
||||
var ? absl::make_optional(*var->tensor()) : absl::nullopt;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -197,8 +190,7 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
void XlaComputationLaunchContext::PopulateInputs(
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix) {
|
||||
const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_ptrs_ =
|
||||
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
|
||||
@ -210,7 +202,7 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
CHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
|
||||
const Tensor* t = variables.count(arg_num)
|
||||
? &(variables.at(arg_num).value)
|
||||
? &(variables.at(arg_num).value())
|
||||
: &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
CHECK(t);
|
||||
|
||||
@ -262,7 +254,7 @@ static const Tensor* FindAliasedTensorForOutput(
|
||||
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
absl::Span<const int> input_mapping,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||
const ResourceVarsSnapshot& resource_var_snapshots) {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
int xla_param = input_output_alias.GetAliasedParameter({output_num})
|
||||
.value()
|
||||
@ -274,8 +266,8 @@ static const Tensor* FindAliasedTensorForOutput(
|
||||
// entry time.
|
||||
if (input_tensor->dtype() == DT_RESOURCE) {
|
||||
auto& v = resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
|
||||
CHECK(v.present);
|
||||
return &v.value;
|
||||
CHECK(v.has_value());
|
||||
return &v.value();
|
||||
}
|
||||
return input_tensor;
|
||||
}
|
||||
@ -298,9 +290,9 @@ static Tensor GetOrCreateTensorForOutput(
|
||||
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
absl::Span<const int> input_mapping,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots,
|
||||
DataType output_dtype, const TensorShape& output_shape,
|
||||
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
|
||||
const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
|
||||
const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
|
||||
Allocator* output_allocator) {
|
||||
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
input_mapping, resource_var_snapshots)) {
|
||||
@ -431,13 +423,13 @@ static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
|
||||
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
|
||||
// not a Tensor.
|
||||
Var* variable = nullptr;
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
|
||||
ctx, HandleFromInput(ctx, actual_input_index), &variable,
|
||||
const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, variable);
|
||||
variable_infos.emplace_back(actual_input_index, handle.name(), variable);
|
||||
}
|
||||
return variable_infos;
|
||||
}
|
||||
@ -447,7 +439,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||
const ResourceVarsSnapshot& resource_var_snapshots) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
@ -564,12 +556,21 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
|
||||
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
|
||||
std::vector<XlaCompiler::Argument>* args) {
|
||||
args->resize(ctx->num_inputs());
|
||||
|
||||
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
|
||||
for (const VariableInfo& info : variable_args) {
|
||||
CHECK(!info.var() || info.lock_held())
|
||||
<< "Need to hold the lock on resource variables "
|
||||
"before calling BuildXlaCompilerArguments";
|
||||
variable_info_lookup.emplace(info.index(), &info);
|
||||
}
|
||||
|
||||
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
|
||||
XlaCompiler::Argument& arg = (*args)[input_num];
|
||||
|
||||
if (constant_args.count(input_num) > 0) {
|
||||
// Handles compile-time constants.
|
||||
const Tensor& input = constant_args.at(input_num);
|
||||
@ -578,7 +579,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
arg.type = input.dtype();
|
||||
arg.shape = input.shape();
|
||||
arg.constant_value = input;
|
||||
} else if (variable_args.count(input_num) == 0) {
|
||||
} else if (variable_info_lookup.count(input_num) == 0) {
|
||||
// Handles the non-constant arguments.
|
||||
const Tensor& input = ctx->input(input_num);
|
||||
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
|
||||
@ -594,14 +595,14 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
// Handles resource variables.
|
||||
const Tensor& input = ctx->input(input_num);
|
||||
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
|
||||
const OptionalTensor& variable = variable_args.at(input_num);
|
||||
arg.name = variable.name;
|
||||
const VariableInfo& variable = *variable_info_lookup[input_num];
|
||||
arg.name = std::string(variable.name());
|
||||
arg.kind = XlaCompiler::Argument::kResource;
|
||||
arg.resource_kind = XlaResource::kVariable;
|
||||
if (variable.present) {
|
||||
const Tensor& value = variable.value;
|
||||
arg.type = value.dtype();
|
||||
arg.shape = value.shape();
|
||||
if (variable.var()) {
|
||||
const Tensor* value = variable.var()->tensor();
|
||||
arg.type = value->dtype();
|
||||
arg.shape = value->shape();
|
||||
arg.initialized = true;
|
||||
} else {
|
||||
// The values of uninitialized variables are not passed as inputs, since
|
||||
|
@ -34,36 +34,17 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Struct that represents a possibly-absent Tensor.
|
||||
struct OptionalTensor {
|
||||
string name; // A descriptive name
|
||||
bool present = false; // Is the tensor present?
|
||||
Tensor value; // If present, what is the Tensor's value?
|
||||
};
|
||||
|
||||
// Takes a snapshot of the values of resource variable arguments, whose indices
|
||||
// are specified in `variable_indices` argument. We snapshot tensors that back
|
||||
// resource variables since concurrent updates may modify the shape, and it is
|
||||
// important that the shapes used for compilation match the true shapes of the
|
||||
// buffers.
|
||||
//
|
||||
// We snapshot the entire set of resource variables as one atomic operation.
|
||||
// This models Read->* dependencies between resource variable operations. See
|
||||
// jit/resource_operation_safety_analysis for details.
|
||||
//
|
||||
// Returns a map of TensorFlow argument index to resource variable. If a
|
||||
// resource variable is not initialized, the corresponding OptionalTensor
|
||||
// will have its `present` field set to false.
|
||||
Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::map<int, OptionalTensor>* result);
|
||||
// Snapshot of resource variables for a TF kernel invocation, mapping from
|
||||
// parameter number to values at execution time. If the resource variable is not
|
||||
// initialized, the value will not be present.
|
||||
using ResourceVarsSnapshot = absl::flat_hash_map<int, absl::optional<Tensor>>;
|
||||
|
||||
// Information about the state of a variable passed as input to the _XlaCompile
|
||||
// and _XlaRun operators. Unlocks the resource variable and decrements its
|
||||
// refcount on destruction.
|
||||
class VariableInfo {
|
||||
public:
|
||||
explicit VariableInfo(int index, Var* var);
|
||||
explicit VariableInfo(int index, absl::string_view name, Var* var);
|
||||
VariableInfo(VariableInfo&& other);
|
||||
|
||||
VariableInfo& operator=(VariableInfo&& other);
|
||||
@ -79,6 +60,9 @@ class VariableInfo {
|
||||
// "empty", i.e. it does not track a resource variable.
|
||||
Var* var() const { return var_; }
|
||||
|
||||
// Returns the variable name.
|
||||
absl::string_view name() const { return name_; }
|
||||
|
||||
// Returns true if the resource variable lock was successfully acquired by
|
||||
// this thread.
|
||||
bool lock_held() const { return lock_held_; }
|
||||
@ -88,6 +72,7 @@ class VariableInfo {
|
||||
|
||||
private:
|
||||
int index_;
|
||||
std::string name_;
|
||||
Var* var_;
|
||||
|
||||
// We can't use a optional<mutex_lock> here because it confuses the compiler's
|
||||
@ -96,6 +81,20 @@ class VariableInfo {
|
||||
bool lock_held_ = false;
|
||||
};
|
||||
|
||||
// Takes a snapshot of the values of resource variable arguments, whose indices
|
||||
// are specified in `variable_indices` argument. We snapshot tensors that back
|
||||
// resource variables since concurrent updates may modify the shape, and it is
|
||||
// important that the shapes used for compilation match the true shapes of the
|
||||
// buffers.
|
||||
//
|
||||
// We snapshot the entire set of resource variables as one atomic operation.
|
||||
// This models Read->* dependencies between resource variable operations. See
|
||||
// jit/resource_operation_safety_analysis for details.
|
||||
Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
ResourceVarsSnapshot* result);
|
||||
|
||||
// Acquires the mutexes for all the variables in `variables` using a
|
||||
// deadlock-safe protocol (acquire the mutexes in increasing-address order).
|
||||
//
|
||||
@ -104,6 +103,13 @@ class VariableInfo {
|
||||
Status LockVariables(absl::Span<VariableInfo> variables)
|
||||
TF_EXCLUSIVE_LOCK_FUNCTION();
|
||||
|
||||
// Returns a vector of VariableInfo instances for the resource variable inputs
|
||||
// to the kernel with context `ctx`. The input indices for the resource
|
||||
// variable inputs are in `variable_indices`.
|
||||
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result);
|
||||
|
||||
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
|
||||
// ShapedBuffers suitable for passing to an XLA computation.
|
||||
class XlaComputationLaunchContext {
|
||||
@ -123,9 +129,10 @@ class XlaComputationLaunchContext {
|
||||
|
||||
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
|
||||
// op.
|
||||
// Precondition: variables in `variable_args` are locked.
|
||||
static Status BuildXlaCompilerArguments(
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
|
||||
std::vector<XlaCompiler::Argument>* args);
|
||||
|
||||
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
|
||||
@ -137,7 +144,7 @@ class XlaComputationLaunchContext {
|
||||
// (in other words, no inputs actually required by the kernel can be missing).
|
||||
void PopulateInputs(OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
const ResourceVarsSnapshot& variables,
|
||||
int missing_ctx_input_prefix);
|
||||
|
||||
// Given the XLA output in `output`, populate all outputs of `ctx`. Also
|
||||
@ -155,7 +162,7 @@ class XlaComputationLaunchContext {
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots);
|
||||
const ResourceVarsSnapshot& resource_var_snapshots);
|
||||
|
||||
// Return the argument list. Only valid after PopulateInputs() has been
|
||||
// called.
|
||||
|
Loading…
Reference in New Issue
Block a user