[TF2XLA] [NFC] Simplify GetBodyAndConstantsAndResources
PiperOrigin-RevId: 329850852 Change-Id: I33b9dd04f14104e25e60cc31adc8f6dd68f24e8b
This commit is contained in:
parent
bedf4eb166
commit
5104953f4c
@ -530,16 +530,11 @@ bool CanCreateXlaKernel(const NodeDef& node_def) {
|
||||
}
|
||||
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
const NameAttrList& function,
|
||||
const FunctionBody** fbody,
|
||||
std::vector<int>* constant_arg_indices,
|
||||
std::vector<int>* resource_arg_indices) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If node_def is not instantiable, e.g., the function does not exist,
|
||||
// simply bail out.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
|
||||
*fbody = flr->GetFunctionBody(handle);
|
||||
|
@ -267,14 +267,13 @@ class RecursiveCompilabilityChecker {
|
||||
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
const XlaOpRegistry::DeviceRegistration& registration);
|
||||
|
||||
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
||||
// runtime, returns this function's body in `fbody` as well as the indices
|
||||
// of its constant and resource arguments.
|
||||
// Given a FunctionLibraryRuntime and a `function`, returns this function's body
|
||||
// in `fbody` as well as the indices of its constant and resource arguments.
|
||||
// `fbody` is owned by `flr`.
|
||||
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
||||
// They are sorted in ascending order on this function's return.
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
const NameAttrList& function,
|
||||
const FunctionBody** fbody,
|
||||
std::vector<int>* constant_arg_indices,
|
||||
std::vector<int>* resource_arg_indices);
|
||||
|
@ -38,10 +38,12 @@ Status ForceXlaConstantsOnHostPass::Run(
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node->def(), &function));
|
||||
|
||||
// Force all constants to be on the host memory.
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, node->def(), &fbody, &constant_arg_indices,
|
||||
&resource_arg_indices));
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
VLOG(3) << "Found constant arg indices: "
|
||||
<< absl::StrJoin(constant_arg_indices, ", ");
|
||||
|
||||
|
@ -122,11 +122,13 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
}
|
||||
|
||||
// Get function body, constant args, and resource args.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
// Set input and output memory types.
|
||||
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
|
||||
@ -176,8 +178,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
}
|
||||
|
||||
// Create the kernel.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
Device* dev = flr->device();
|
||||
Status s;
|
||||
auto props = std::make_shared<NodeProperties>(
|
||||
|
Loading…
Reference in New Issue
Block a user