[TF2XLA] [NFC] Simplify GetVariableInfosFromInputs and BuildXlaCompilerArguments
PiperOrigin-RevId: 329848962
Change-Id: I26ae860a588e73046457257227f45f03158a20b1
parent d65352ddd3
commit 72d8f7ba27
@@ -164,6 +164,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
 static Status CompileToLocalExecutable(
     OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
     const XlaPlatformInfo& platform_info,
+    absl::Span<const Tensor* const> inputs,
     absl::Span<VariableInfo const> variable_infos,
     absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
     xla::LocalClient** client,
@@ -195,11 +196,6 @@ static Status CompileToLocalExecutable(
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
       platform_info, has_ref_vars, &tf_allocator_adapter);
 
-  std::map<int, Tensor> constant_args;
-  for (int i : constants) {
-    constant_args.insert({i, ctx->input(i)});
-  }
-
   XlaCompiler::CompileOptions compile_options;
   compile_options.is_entry_computation = true;
   // Optimization: where possible, have the computation return a naked array
@@ -209,10 +205,11 @@ static Status CompileToLocalExecutable(
                                           !platform_info.is_on_xla_device() &&
                                           may_alias_resource_update;
 
-  std::vector<XlaCompiler::Argument> args;
-  TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
-      constant_args, variable_infos, ctx, &args));
-  return cache->Compile(options, function, args, compile_options,
+  xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
+      XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
+                                                             variable_infos);
+  TF_RETURN_IF_ERROR(args.status());
+  return cache->Compile(options, function, *args, compile_options,
                         lazy ? XlaCompilationCache::CompileMode::kLazy
                              : XlaCompilationCache::CompileMode::kStrict,
                         compilation_result, executable);
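The hunk above switches BuildXlaCompilerArguments from an out-parameter to a StatusOr return. A standalone sketch of that consumption pattern, using absl::StatusOr as a stand-in for xla::StatusOr; BuildArgs and Consume are illustrative names, not part of this commit:

#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical builder: returns either a value or an error, like the new
// BuildXlaCompilerArguments.
absl::StatusOr<std::vector<int>> BuildArgs(bool ok) {
  if (!ok) return absl::InternalError("failed to build arguments");
  return std::vector<int>{1, 2, 3};  // Implicit conversion into StatusOr.
}

// Hypothetical caller: mirrors TF_RETURN_IF_ERROR(args.status()) followed by
// dereferencing *args at the use site.
absl::Status Consume() {
  absl::StatusOr<std::vector<int>> args = BuildArgs(/*ok=*/true);
  if (!args.ok()) return args.status();  // Propagate failure early.
  int first = (*args)[0];                // operator* yields the wrapped value.
  (void)first;
  return absl::OkStatus();
}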
@@ -222,6 +219,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   VLOG(1) << "XlaLocalLaunchOpBase::Compute "
           << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
 
+  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
   xla::LocalClient* client;
   const XlaCompiler::CompilationResult* compilation_result;
   xla::LocalExecutable* executable;
@@ -229,10 +227,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   std::vector<VariableInfo> variable_infos;
   {
     OP_REQUIRES_OK(
-        ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
+        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
+                                        inputs, resources_, &variable_infos));
     OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
     Status s = CompileToLocalExecutable(
-        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
+        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
         variable_infos, constants_, /*lazy=*/false,
         /*may_alias_resource_update=*/true, &client, &compilation_result,
         &executable);
@@ -389,6 +388,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
   xla::LocalExecutable* executable;
   ResourceVarsSnapshot variables;
 
+  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
   bool cannot_compile_cluster;
   {
     mutex_lock guard(cannot_compile_cluster_mu_);
@@ -401,13 +401,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
   } else {
     std::vector<VariableInfo> variable_infos;
     OP_REQUIRES_OK(
-        ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
+        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
+                                        inputs, resources_, &variable_infos));
     OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
 
     // Do not alias resource updates as locking variables in XlaCompile and
     // unlocking them in XlaRun may lead to deadlocks.
     Status status = CompileToLocalExecutable(
-        ctx, function_, has_ref_vars_, platform_info_, variable_infos,
+        ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
         constants_,
         /*lazy=*/!must_compile_,
         /*may_alias_resource_update=*/false, &client, &kernel, &executable);
@@ -103,26 +103,16 @@ Status XlaCompileOnDemandOp::Compile(
     OpKernelContext* ctx, const XlaCompiler::CompilationResult** result,
     XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args,
     xla::LocalExecutable** executable) {
-  std::map<int, Tensor> constant_arguments;
-
   std::vector<int> constant_input_indices;
   TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
       &ctx->op_kernel(), &constant_input_indices, ctx->function_library()));
-  CHECK(absl::c_is_sorted(constant_input_indices));
-
-  for (int64 i = 0; i < ctx->num_inputs(); ++i) {
-    const Tensor& device_tensor = ctx->input(i);
-
-    if (!constant_arguments.count(i)) {
-      if (absl::c_binary_search(constant_input_indices, i)) {
-        if (ctx->input_memory_type(i) != HOST_MEMORY) {
-          return errors::Internal(
-              "Expected constant argument not in host memory");
-        }
-        constant_arguments[i] = device_tensor;
-      }
-    }
+  if (!absl::c_all_of(constant_input_indices, [&](int idx) {
+        return ctx->input_memory_type(idx) == HOST_MEMORY;
+      })) {
+    return errors::Internal("Unexpected device placement for a constant input");
   }
+  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
 
   // We store information about the JIT-compiled XLA computation
   // in the ResourceMgr.
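The hunk above collapses a nested per-input loop into a single absl::c_all_of over just the constant input indices. A standalone sketch of that check; the enum and the memory_types vector are stand-ins for TensorFlow's MemoryType and ctx->input_memory_type():

#include <vector>

#include "absl/algorithm/container.h"

// Stand-in for TensorFlow's MemoryType.
enum MemoryType { HOST_MEMORY, DEVICE_MEMORY };

// Returns true iff every constant input lives in host memory, mirroring the
// absl::c_all_of check introduced above; memory_types[idx] stands in for
// ctx->input_memory_type(idx).
bool AllConstantsOnHost(const std::vector<int>& constant_input_indices,
                        const std::vector<MemoryType>& memory_types) {
  return absl::c_all_of(constant_input_indices, [&](int idx) {
    return memory_types[idx] == HOST_MEMORY;
  });
}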
@@ -150,19 +140,23 @@ Status XlaCompileOnDemandOp::Compile(
   compile_options.always_return_tuple = false;
 
   std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
-  std::vector<XlaCompiler::Argument> args;
+  xla::StatusOr<std::vector<XlaCompiler::Argument>> args;
   {
     std::vector<VariableInfo> variable_infos;
     TF_RETURN_IF_ERROR(
-        GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
+        GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
+                                   inputs, variables_indices, &variable_infos));
 
     TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
     TF_RETURN_IF_ERROR(SnapshotResourceVariables(
         ctx, variables_indices, variable_infos, variable_args));
-    TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
-        constant_arguments, variable_infos, ctx, &args));
+    args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
+        constant_input_indices, inputs, variable_infos);
+    TF_RETURN_IF_ERROR(args.status());
   }
 
-  return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result,
+  return (*cache)->CompileSingleOp(options, *args, ctx, compile_options, result,
                                    executable);
 }
@@ -79,19 +79,22 @@ VariableInfo::~VariableInfo() {
   }
 }
 
-// Returns a vector of VariableInfo instances for the resource variable inputs
-// to the kernel with context `ctx`. The input indices for the resource
-// variable inputs are in `variable_indices`.
-Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
-                                     absl::Span<const int> variable_indices,
-                                     std::vector<VariableInfo>* result) {
+Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
+                                  absl::Span<const Tensor* const> inputs,
+                                  absl::Span<const int> variable_indices,
+                                  std::vector<VariableInfo>* result) {
   result->clear();
   result->reserve(variable_indices.size());
   for (int var_idx : variable_indices) {
     Var* variable = nullptr;
-    ResourceHandle handle = HandleFromInput(ctx, var_idx);
-    TF_RETURN_IF_ERROR(
-        LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
+    ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
+    if (handle.device() != dev->attributes().name()) {
+      return errors::InvalidArgument("Trying to access resource ",
+                                     handle.name(), " located in device ",
+                                     dev->name());
+    }
+    TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
+        handle.container(), handle.name(), &variable, [](Var** ptr) {
           // This var is uninitialized for now.
           *ptr = new Var(DT_INVALID);
           return Status::OK();
@@ -101,6 +104,15 @@ Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
   return Status::OK();
 }
 
+std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
+  std::vector<const Tensor*> inputs;
+  inputs.reserve(ctx->num_inputs());
+  for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
+    inputs.push_back(&ctx->input(input_idx));
+  }
+  return inputs;
+}
+
 Status LockVariables(absl::Span<VariableInfo> variables) {
   std::vector<int> lock_order(variables.size());
   std::iota(lock_order.begin(), lock_order.end(), 0);
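InputsFromContext, added above, is what lets the other helpers drop their OpKernelContext dependency: pointers are gathered once and passed around as a span. A standalone analogue; FakeContext is a hypothetical stand-in for OpKernelContext:

#include <vector>

struct Tensor {};  // Stand-in for tensorflow::Tensor.

// Hypothetical stand-in for the parts of OpKernelContext used here.
struct FakeContext {
  std::vector<Tensor> tensors;
  int num_inputs() const { return static_cast<int>(tensors.size()); }
  const Tensor& input(int i) const { return tensors[i]; }
};

// Mirrors InputsFromContext: gather stable pointers to all inputs once, so
// later helpers can accept absl::Span<const Tensor* const> instead of a
// kernel context.
std::vector<const Tensor*> GatherInputs(const FakeContext& ctx) {
  std::vector<const Tensor*> inputs;
  inputs.reserve(ctx.num_inputs());
  for (int i = 0; i < ctx.num_inputs(); ++i) {
    inputs.push_back(&ctx.input(i));
  }
  return inputs;
}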
@@ -548,11 +560,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
   return Status::OK();
 }
 
-Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
-    const std::map<int, Tensor>& must_be_constant_args,
-    absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
-    std::vector<XlaCompiler::Argument>* args) {
-  args->resize(ctx->num_inputs());
+xla::StatusOr<std::vector<XlaCompiler::Argument>>
+XlaComputationLaunchContext::BuildXlaCompilerArguments(
+    absl::Span<int const> must_be_constant_idxs,
+    absl::Span<const Tensor* const> inputs,
+    absl::Span<VariableInfo const> variable_args) {
+  CHECK(absl::c_is_sorted(must_be_constant_idxs));
+  std::vector<XlaCompiler::Argument> out;
+  out.resize(inputs.size());
 
   absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
   for (const VariableInfo& info : variable_args) {
@@ -562,33 +577,20 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
     variable_info_lookup.emplace(info.index(), &info);
   }
 
-  for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
-    XlaCompiler::Argument& arg = (*args)[input_num];
+  for (int64 input_num = 0; input_num < inputs.size(); ++input_num) {
+    const Tensor* input = inputs[input_num];
 
-    if (must_be_constant_args.count(input_num) > 0) {
+    XlaCompiler::Argument& arg = out[input_num];
+    if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
       // Handles compile-time constants.
-      const Tensor& input = must_be_constant_args.at(input_num);
-      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
+      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
       arg.kind = XlaCompiler::Argument::kConstant;
-      arg.type = input.dtype();
-      arg.shape = input.shape();
-      arg.constant_value = input;
-    } else if (variable_info_lookup.count(input_num) == 0) {
-      // Handles the non-constant arguments.
-      const Tensor& input = ctx->input(input_num);
-      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
-      if (input.NumElements() > 0) {
-        arg.kind = XlaCompiler::Argument::kParameter;
-      } else {
-        arg.kind = XlaCompiler::Argument::kConstant;
-        arg.constant_value = input;
-      }
-      arg.type = input.dtype();
-      arg.shape = input.shape();
-    } else {
+      arg.type = input->dtype();
+      arg.shape = input->shape();
+      arg.constant_value = *input;
+    } else if (variable_info_lookup.count(input_num)) {
       // Handles resource variables.
-      const Tensor& input = ctx->input(input_num);
-      TF_RET_CHECK(input.dtype() == DT_RESOURCE);
+      TF_RET_CHECK(input->dtype() == DT_RESOURCE);
       const VariableInfo& variable = *variable_info_lookup[input_num];
       arg.name = std::string(variable.name());
       arg.kind = XlaCompiler::Argument::kResource;
@@ -607,10 +609,21 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
         arg.type = DT_INVALID;
         arg.shape = TensorShape();
       }
+    } else {
+      // Normal inputs.
+      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
+      if (input->NumElements() > 0) {
+        arg.kind = XlaCompiler::Argument::kParameter;
+      } else {
+        arg.kind = XlaCompiler::Argument::kConstant;
+        arg.constant_value = *input;
+      }
+      arg.type = input->dtype();
+      arg.shape = input->shape();
     }
   }
 
-  return Status::OK();
+  return out;
 }
 
 }  // namespace tensorflow
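After this change, BuildXlaCompilerArguments classifies each input position three ways: compile-time constant (binary search over the sorted index list, which is why the function begins with CHECK(absl::c_is_sorted(must_be_constant_idxs))), resource variable, or ordinary parameter. A standalone sketch of that decision; Classify and its containers are illustrative, not from this commit:

#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"

enum class ArgKind { kConstant, kResource, kParameter };

// Mirrors the three-way branch above. Binary search is only correct because
// the constant index list is sorted, which is what the CHECK enforces.
ArgKind Classify(int input_num, const std::vector<int>& sorted_constant_idxs,
                 const absl::flat_hash_set<int>& resource_idxs) {
  if (absl::c_binary_search(sorted_constant_idxs, input_num)) {
    return ArgKind::kConstant;
  }
  if (resource_idxs.contains(input_num)) {
    return ArgKind::kResource;
  }
  return ArgKind::kParameter;  // The "Normal inputs" branch.
}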
@@ -109,12 +109,16 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
 Status LockVariables(absl::Span<VariableInfo> variables)
     TF_EXCLUSIVE_LOCK_FUNCTION();
 
-// Returns a vector of VariableInfo instances for the resource variable inputs
-// to the kernel with context `ctx`. The input indices for the resource
+// Returns a vector of VariableInfo instances for the resource variable inputs,
+// given that *all* inputs are in `inputs`. The input indices for the resource
 // variable inputs are in `variable_indices`.
-Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
-                                     absl::Span<const int> variable_indices,
-                                     std::vector<VariableInfo>* result);
+Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
+                                  absl::Span<const Tensor* const> inputs,
+                                  absl::Span<const int> variable_indices,
+                                  std::vector<VariableInfo>* result);
+
+// Returns pointers to inputs stored in `ctx`.
+std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx);
 
 // Helper class to perform the marshalling of TensorFlow inputs and outputs to
 // ShapedBuffers suitable for passing to an XLA computation.
@@ -136,10 +140,10 @@ class XlaComputationLaunchContext {
   // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
   // op.
   // Precondition: variables in `variable_args` are locked.
-  static Status BuildXlaCompilerArguments(
-      const std::map<int, Tensor>& constant_args,
-      absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
-      std::vector<XlaCompiler::Argument>* args);
+  static xla::StatusOr<std::vector<XlaCompiler::Argument>>
+  BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
+                            absl::Span<const Tensor* const> inputs,
+                            absl::Span<VariableInfo const> variable_args);
 
   // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
   // `variables` is a map from TensorFlow argument number to resource variable.
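The new signatures consistently take absl::Span parameters in place of OpKernelContext or std::map. A standalone sketch of why that is convenient: a Span binds to vectors, arrays, and initializer lists without copying (Sum and Demo are illustrative names):

#include <vector>

#include "absl/types/span.h"

// Read-only view over contiguous ints; no ownership, no copy.
int Sum(absl::Span<const int> xs) {
  int total = 0;
  for (int x : xs) total += x;
  return total;
}

int Demo() {
  std::vector<int> v = {1, 2, 3};
  int from_vector = Sum(v);         // Implicit conversion from std::vector.
  int from_list = Sum({4, 5, 6});   // From an initializer list, too.
  return from_vector + from_list;
}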