[TF2XLA] [NFC] Simplify GetBodyAndConstantsAndResources

PiperOrigin-RevId: 329850852
Change-Id: I33b9dd04f14104e25e60cc31adc8f6dd68f24e8b
This commit is contained in:
George Karpenkov 2020-09-02 21:58:34 -07:00 committed by TensorFlower Gardener
parent bedf4eb166
commit 5104953f4c
4 changed files with 11 additions and 15 deletions

View File

@ -530,16 +530,11 @@ bool CanCreateXlaKernel(const NodeDef& node_def) {
}
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
const NameAttrList& function,
const FunctionBody** fbody,
std::vector<int>* constant_arg_indices,
std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
// If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
TF_RETURN_IF_ERROR(
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
*fbody = flr->GetFunctionBody(handle);

View File

@ -267,14 +267,13 @@ class RecursiveCompilabilityChecker {
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
const XlaOpRegistry::DeviceRegistration& registration);
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// Given a FunctionLibraryRuntime and a `function`, returns this function's body
// in `fbody` as well as the indices of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
const NameAttrList& function,
const FunctionBody** fbody,
std::vector<int>* constant_arg_indices,
std::vector<int>* resource_arg_indices);

View File

@ -38,10 +38,12 @@ Status ForceXlaConstantsOnHostPass::Run(
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node->def(), &function));
// Force all constants to be on the host memory.
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node->def(), &fbody, &constant_arg_indices,
&resource_arg_indices));
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
VLOG(3) << "Found constant arg indices: "
<< absl::StrJoin(constant_arg_indices, ", ");

View File

@ -122,11 +122,13 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
}
// Get function body, constant args, and resource args.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
@ -176,8 +178,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
}
// Create the kernel.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
auto props = std::make_shared<NodeProperties>(