[TF2XLA] [NFC] Simplify GetVariableInfosFromInputs and BuildXlaCompilerArguments
PiperOrigin-RevId: 329848962 Change-Id: I26ae860a588e73046457257227f45f03158a20b1
This commit is contained in:
parent
d65352ddd3
commit
72d8f7ba27
@ -164,6 +164,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
static Status CompileToLocalExecutable(
|
||||
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
|
||||
const XlaPlatformInfo& platform_info,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
|
||||
xla::LocalClient** client,
|
||||
@ -195,11 +196,6 @@ static Status CompileToLocalExecutable(
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info, has_ref_vars, &tf_allocator_adapter);
|
||||
|
||||
std::map<int, Tensor> constant_args;
|
||||
for (int i : constants) {
|
||||
constant_args.insert({i, ctx->input(i)});
|
||||
}
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = true;
|
||||
// Optimization: where possible, have the computation return a naked array
|
||||
@ -209,10 +205,11 @@ static Status CompileToLocalExecutable(
|
||||
!platform_info.is_on_xla_device() &&
|
||||
may_alias_resource_update;
|
||||
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_args, variable_infos, ctx, &args));
|
||||
return cache->Compile(options, function, args, compile_options,
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
|
||||
variable_infos);
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
return cache->Compile(options, function, *args, compile_options,
|
||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||
: XlaCompilationCache::CompileMode::kStrict,
|
||||
compilation_result, executable);
|
||||
@ -222,6 +219,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
VLOG(1) << "XlaLocalLaunchOpBase::Compute "
|
||||
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
|
||||
|
||||
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
|
||||
xla::LocalClient* client;
|
||||
const XlaCompiler::CompilationResult* compilation_result;
|
||||
xla::LocalExecutable* executable;
|
||||
@ -229,10 +227,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
{
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
|
||||
ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
|
||||
inputs, resources_, &variable_infos));
|
||||
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
|
||||
Status s = CompileToLocalExecutable(
|
||||
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
|
||||
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
|
||||
variable_infos, constants_, /*lazy=*/false,
|
||||
/*may_alias_resource_update=*/true, &client, &compilation_result,
|
||||
&executable);
|
||||
@ -389,6 +388,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
xla::LocalExecutable* executable;
|
||||
ResourceVarsSnapshot variables;
|
||||
|
||||
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
|
||||
bool cannot_compile_cluster;
|
||||
{
|
||||
mutex_lock guard(cannot_compile_cluster_mu_);
|
||||
@ -401,13 +401,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
} else {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
|
||||
ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
|
||||
inputs, resources_, &variable_infos));
|
||||
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
|
||||
|
||||
// Do not alias resource updates as locking variables in XlaCompile and
|
||||
// unlocking them in XlaRun may lead to deadlocks.
|
||||
Status status = CompileToLocalExecutable(
|
||||
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
|
||||
ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
|
||||
constants_,
|
||||
/*lazy=*/!must_compile_,
|
||||
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
|
||||
|
@ -103,26 +103,16 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult** result,
|
||||
XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args,
|
||||
xla::LocalExecutable** executable) {
|
||||
std::map<int, Tensor> constant_arguments;
|
||||
|
||||
std::vector<int> constant_input_indices;
|
||||
TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
|
||||
&ctx->op_kernel(), &constant_input_indices, ctx->function_library()));
|
||||
CHECK(absl::c_is_sorted(constant_input_indices));
|
||||
|
||||
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
||||
const Tensor& device_tensor = ctx->input(i);
|
||||
|
||||
if (!constant_arguments.count(i)) {
|
||||
if (absl::c_binary_search(constant_input_indices, i)) {
|
||||
if (ctx->input_memory_type(i) != HOST_MEMORY) {
|
||||
return errors::Internal(
|
||||
"Expected constant argument not in host memory");
|
||||
}
|
||||
constant_arguments[i] = device_tensor;
|
||||
}
|
||||
}
|
||||
if (!absl::c_all_of(constant_input_indices, [&](int idx) {
|
||||
return ctx->input_memory_type(idx) == HOST_MEMORY;
|
||||
})) {
|
||||
return errors::Internal("Unexpected device placement for a constant input");
|
||||
}
|
||||
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
|
||||
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
@ -150,19 +140,23 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
compile_options.always_return_tuple = false;
|
||||
|
||||
std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args;
|
||||
{
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
|
||||
GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
|
||||
inputs, variables_indices, &variable_infos));
|
||||
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
TF_RETURN_IF_ERROR(SnapshotResourceVariables(
|
||||
ctx, variables_indices, variable_infos, variable_args));
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arguments, variable_infos, ctx, &args));
|
||||
|
||||
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_input_indices, inputs, variable_infos);
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
}
|
||||
|
||||
return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result,
|
||||
return (*cache)->CompileSingleOp(options, *args, ctx, compile_options, result,
|
||||
executable);
|
||||
}
|
||||
|
||||
|
@ -79,19 +79,22 @@ VariableInfo::~VariableInfo() {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a vector of VariableInfo instances for the resource variable inputs
|
||||
// to the kernel with context `ctx`. The input indices for the resource
|
||||
// variable inputs are in `variable_indices`.
|
||||
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result) {
|
||||
Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result) {
|
||||
result->clear();
|
||||
result->reserve(variable_indices.size());
|
||||
for (int var_idx : variable_indices) {
|
||||
Var* variable = nullptr;
|
||||
ResourceHandle handle = HandleFromInput(ctx, var_idx);
|
||||
TF_RETURN_IF_ERROR(
|
||||
LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
|
||||
ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
|
||||
if (handle.device() != dev->attributes().name()) {
|
||||
return errors::InvalidArgument("Trying to access resource ",
|
||||
handle.name(), " located in device ",
|
||||
dev->name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
|
||||
handle.container(), handle.name(), &variable, [](Var** ptr) {
|
||||
// This var is uninitialized for now.
|
||||
*ptr = new Var(DT_INVALID);
|
||||
return Status::OK();
|
||||
@ -101,6 +104,15 @@ Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
|
||||
std::vector<const Tensor*> inputs;
|
||||
inputs.reserve(ctx->num_inputs());
|
||||
for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
|
||||
inputs.push_back(&ctx->input(input_idx));
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
Status LockVariables(absl::Span<VariableInfo> variables) {
|
||||
std::vector<int> lock_order(variables.size());
|
||||
std::iota(lock_order.begin(), lock_order.end(), 0);
|
||||
@ -548,11 +560,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
const std::map<int, Tensor>& must_be_constant_args,
|
||||
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
|
||||
std::vector<XlaCompiler::Argument>* args) {
|
||||
args->resize(ctx->num_inputs());
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
absl::Span<int const> must_be_constant_idxs,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<VariableInfo const> variable_args) {
|
||||
CHECK(absl::c_is_sorted(must_be_constant_idxs));
|
||||
std::vector<XlaCompiler::Argument> out;
|
||||
out.resize(inputs.size());
|
||||
|
||||
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
|
||||
for (const VariableInfo& info : variable_args) {
|
||||
@ -562,33 +577,20 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
variable_info_lookup.emplace(info.index(), &info);
|
||||
}
|
||||
|
||||
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
|
||||
XlaCompiler::Argument& arg = (*args)[input_num];
|
||||
for (int64 input_num = 0; input_num < inputs.size(); ++input_num) {
|
||||
const Tensor* input = inputs[input_num];
|
||||
|
||||
if (must_be_constant_args.count(input_num) > 0) {
|
||||
XlaCompiler::Argument& arg = out[input_num];
|
||||
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||
// Handles compile-time constants.
|
||||
const Tensor& input = must_be_constant_args.at(input_num);
|
||||
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
|
||||
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.type = input.dtype();
|
||||
arg.shape = input.shape();
|
||||
arg.constant_value = input;
|
||||
} else if (variable_info_lookup.count(input_num) == 0) {
|
||||
// Handles the non-constant arguments.
|
||||
const Tensor& input = ctx->input(input_num);
|
||||
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
|
||||
if (input.NumElements() > 0) {
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
} else {
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.constant_value = input;
|
||||
}
|
||||
arg.type = input.dtype();
|
||||
arg.shape = input.shape();
|
||||
} else {
|
||||
arg.type = input->dtype();
|
||||
arg.shape = input->shape();
|
||||
arg.constant_value = *input;
|
||||
} else if (variable_info_lookup.count(input_num)) {
|
||||
// Handles resource variables.
|
||||
const Tensor& input = ctx->input(input_num);
|
||||
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
|
||||
TF_RET_CHECK(input->dtype() == DT_RESOURCE);
|
||||
const VariableInfo& variable = *variable_info_lookup[input_num];
|
||||
arg.name = std::string(variable.name());
|
||||
arg.kind = XlaCompiler::Argument::kResource;
|
||||
@ -607,10 +609,21 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
arg.type = DT_INVALID;
|
||||
arg.shape = TensorShape();
|
||||
}
|
||||
} else {
|
||||
// Normal inputs.
|
||||
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
|
||||
if (input->NumElements() > 0) {
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
} else {
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.constant_value = *input;
|
||||
}
|
||||
arg.type = input->dtype();
|
||||
arg.shape = input->shape();
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -109,12 +109,16 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
Status LockVariables(absl::Span<VariableInfo> variables)
|
||||
TF_EXCLUSIVE_LOCK_FUNCTION();
|
||||
|
||||
// Returns a vector of VariableInfo instances for the resource variable inputs
|
||||
// to the kernel with context `ctx`. The input indices for the resource
|
||||
// Returns a vector of VariableInfo instances for the resource variable inputs,
|
||||
// given that *all* inputs are in `inputs`. The input indices for the resource
|
||||
// variable inputs are in `variable_indices`.
|
||||
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result);
|
||||
Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result);
|
||||
|
||||
// Returns pointers to inputs stored in `ctx`.
|
||||
std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx);
|
||||
|
||||
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
|
||||
// ShapedBuffers suitable for passing to an XLA computation.
|
||||
@ -136,10 +140,10 @@ class XlaComputationLaunchContext {
|
||||
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
|
||||
// op.
|
||||
// Precondition: variables in `variable_args` are locked.
|
||||
static Status BuildXlaCompilerArguments(
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
|
||||
std::vector<XlaCompiler::Argument>* args);
|
||||
static xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
||||
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<VariableInfo const> variable_args);
|
||||
|
||||
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
|
||||
// `variables` is a map from TensorFlow argument number to resource variable.
|
||||
|
Loading…
x
Reference in New Issue
Block a user