[TF2XLA] [NFC] Simplify GetVariableInfosFromInputs and BuildXlaCompilerArguments

PiperOrigin-RevId: 329848962
Change-Id: I26ae860a588e73046457257227f45f03158a20b1
This commit is contained in:
George Karpenkov 2020-09-02 21:36:16 -07:00 committed by TensorFlower Gardener
parent d65352ddd3
commit 72d8f7ba27
4 changed files with 92 additions and 80 deletions

View File

@ -164,6 +164,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_infos,
absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
xla::LocalClient** client,
@ -195,11 +196,6 @@ static Status CompileToLocalExecutable(
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info, has_ref_vars, &tf_allocator_adapter);
std::map<int, Tensor> constant_args;
for (int i : constants) {
constant_args.insert({i, ctx->input(i)});
}
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
// Optimization: where possible, have the computation return a naked array
@ -209,10 +205,11 @@ static Status CompileToLocalExecutable(
!platform_info.is_on_xla_device() &&
may_alias_resource_update;
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_args, variable_infos, ctx, &args));
return cache->Compile(options, function, args, compile_options,
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
variable_infos);
TF_RETURN_IF_ERROR(args.status());
return cache->Compile(options, function, *args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict,
compilation_result, executable);
@ -222,6 +219,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaLocalLaunchOpBase::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
xla::LocalClient* client;
const XlaCompiler::CompilationResult* compilation_result;
xla::LocalExecutable* executable;
@ -229,10 +227,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
std::vector<VariableInfo> variable_infos;
{
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
inputs, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
variable_infos, constants_, /*lazy=*/false,
/*may_alias_resource_update=*/true, &client, &compilation_result,
&executable);
@ -389,6 +388,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
ResourceVarsSnapshot variables;
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
bool cannot_compile_cluster;
{
mutex_lock guard(cannot_compile_cluster_mu_);
@ -401,13 +401,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
} else {
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
inputs, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
// Do not alias resource updates as locking variables in XlaCompile and
// unlocking them in XlaRun may lead to deadlocks.
Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
constants_,
/*lazy=*/!must_compile_,
/*may_alias_resource_update=*/false, &client, &kernel, &executable);

View File

@ -103,26 +103,16 @@ Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaCompiler::CompilationResult** result,
XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
std::vector<int> constant_input_indices;
TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
&ctx->op_kernel(), &constant_input_indices, ctx->function_library()));
CHECK(absl::c_is_sorted(constant_input_indices));
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& device_tensor = ctx->input(i);
if (!constant_arguments.count(i)) {
if (absl::c_binary_search(constant_input_indices, i)) {
if (ctx->input_memory_type(i) != HOST_MEMORY) {
return errors::Internal(
"Expected constant argument not in host memory");
}
constant_arguments[i] = device_tensor;
}
}
if (!absl::c_all_of(constant_input_indices, [&](int idx) {
return ctx->input_memory_type(idx) == HOST_MEMORY;
})) {
return errors::Internal("Unexpected device placement for a constant input");
}
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
@ -150,19 +140,23 @@ Status XlaCompileOnDemandOp::Compile(
compile_options.always_return_tuple = false;
std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
std::vector<XlaCompiler::Argument> args;
xla::StatusOr<std::vector<XlaCompiler::Argument>> args;
{
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(
GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
inputs, variables_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
TF_RETURN_IF_ERROR(SnapshotResourceVariables(
ctx, variables_indices, variable_infos, variable_args));
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_infos, ctx, &args));
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_input_indices, inputs, variable_infos);
TF_RETURN_IF_ERROR(args.status());
}
return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result,
return (*cache)->CompileSingleOp(options, *args, ctx, compile_options, result,
executable);
}

View File

@ -79,19 +79,22 @@ VariableInfo::~VariableInfo() {
}
}
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
absl::Span<const Tensor* const> inputs,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
result->clear();
result->reserve(variable_indices.size());
for (int var_idx : variable_indices) {
Var* variable = nullptr;
ResourceHandle handle = HandleFromInput(ctx, var_idx);
TF_RETURN_IF_ERROR(
LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
if (handle.device() != dev->attributes().name()) {
return errors::InvalidArgument("Trying to access resource ",
handle.name(), " located in device ",
dev->name());
}
TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
handle.container(), handle.name(), &variable, [](Var** ptr) {
// This var is uninitialized for now.
*ptr = new Var(DT_INVALID);
return Status::OK();
@ -101,6 +104,15 @@ Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
return Status::OK();
}
// Collects non-owning pointers to every input tensor held by `ctx`, in input
// order. The returned pointers borrow from `ctx` and are only valid while the
// kernel context is alive.
std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
  const int num_inputs = ctx->num_inputs();
  std::vector<const Tensor*> inputs;
  inputs.reserve(num_inputs);  // single allocation; size is known up front
  for (int i = 0; i < num_inputs; ++i) {
    inputs.push_back(&ctx->input(i));
  }
  return inputs;
}
Status LockVariables(absl::Span<VariableInfo> variables) {
std::vector<int> lock_order(variables.size());
std::iota(lock_order.begin(), lock_order.end(), 0);
@ -548,11 +560,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
return Status::OK();
}
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
const std::map<int, Tensor>& must_be_constant_args,
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args) {
args->resize(ctx->num_inputs());
xla::StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
absl::Span<int const> must_be_constant_idxs,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_args) {
CHECK(absl::c_is_sorted(must_be_constant_idxs));
std::vector<XlaCompiler::Argument> out;
out.resize(inputs.size());
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
for (const VariableInfo& info : variable_args) {
@ -562,33 +577,20 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
variable_info_lookup.emplace(info.index(), &info);
}
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
XlaCompiler::Argument& arg = (*args)[input_num];
for (int64 input_num = 0; input_num < inputs.size(); ++input_num) {
const Tensor* input = inputs[input_num];
if (must_be_constant_args.count(input_num) > 0) {
XlaCompiler::Argument& arg = out[input_num];
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
// Handles compile-time constants.
const Tensor& input = must_be_constant_args.at(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input.dtype();
arg.shape = input.shape();
arg.constant_value = input;
} else if (variable_info_lookup.count(input_num) == 0) {
// Handles the non-constant arguments.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
if (input.NumElements() > 0) {
arg.kind = XlaCompiler::Argument::kParameter;
} else {
arg.kind = XlaCompiler::Argument::kConstant;
arg.constant_value = input;
}
arg.type = input.dtype();
arg.shape = input.shape();
} else {
arg.type = input->dtype();
arg.shape = input->shape();
arg.constant_value = *input;
} else if (variable_info_lookup.count(input_num)) {
// Handles resource variables.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
TF_RET_CHECK(input->dtype() == DT_RESOURCE);
const VariableInfo& variable = *variable_info_lookup[input_num];
arg.name = std::string(variable.name());
arg.kind = XlaCompiler::Argument::kResource;
@ -607,10 +609,21 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
arg.type = DT_INVALID;
arg.shape = TensorShape();
}
} else {
// Normal inputs.
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
if (input->NumElements() > 0) {
arg.kind = XlaCompiler::Argument::kParameter;
} else {
arg.kind = XlaCompiler::Argument::kConstant;
arg.constant_value = *input;
}
arg.type = input->dtype();
arg.shape = input->shape();
}
}
return Status::OK();
return out;
}
} // namespace tensorflow

View File

@ -109,12 +109,16 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
Status LockVariables(absl::Span<VariableInfo> variables)
TF_EXCLUSIVE_LOCK_FUNCTION();
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// Returns a vector of VariableInfo instances for the resource variable inputs,
// given that *all* inputs are in `inputs`. The input indices for the resource
// variable inputs are in `variable_indices`.
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result);
Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
absl::Span<const Tensor* const> inputs,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result);
// Returns pointers to inputs stored in `ctx`.
std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx);
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.
@ -136,10 +140,10 @@ class XlaComputationLaunchContext {
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
// op.
// Precondition: variables in `variable_args` are locked.
static Status BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args);
static xla::StatusOr<std::vector<XlaCompiler::Argument>>
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_args);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.