[TF2XLA] [NFC] Simplify GetBodyAndConstantsAndResources
PiperOrigin-RevId: 329850852 Change-Id: I33b9dd04f14104e25e60cc31adc8f6dd68f24e8b
This commit is contained in:
parent
bedf4eb166
commit
5104953f4c
@ -530,16 +530,11 @@ bool CanCreateXlaKernel(const NodeDef& node_def) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||||
const NodeDef& node_def,
|
const NameAttrList& function,
|
||||||
const FunctionBody** fbody,
|
const FunctionBody** fbody,
|
||||||
std::vector<int>* constant_arg_indices,
|
std::vector<int>* constant_arg_indices,
|
||||||
std::vector<int>* resource_arg_indices) {
|
std::vector<int>* resource_arg_indices) {
|
||||||
FunctionLibraryRuntime::Handle handle;
|
FunctionLibraryRuntime::Handle handle;
|
||||||
// If node_def is not instantiable, e.g., the function does not exist,
|
|
||||||
// simply bail out.
|
|
||||||
NameAttrList function;
|
|
||||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
|
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
|
||||||
*fbody = flr->GetFunctionBody(handle);
|
*fbody = flr->GetFunctionBody(handle);
|
||||||
|
@ -267,14 +267,13 @@ class RecursiveCompilabilityChecker {
|
|||||||
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||||
const XlaOpRegistry::DeviceRegistration& registration);
|
const XlaOpRegistry::DeviceRegistration& registration);
|
||||||
|
|
||||||
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
// Given a FunctionLibraryRuntime and a `function`, returns this function's body
|
||||||
// runtime, returns this function's body in `fbody` as well as the indices
|
// in `fbody` as well as the indices of its constant and resource arguments.
|
||||||
// of its constant and resource arguments.
|
|
||||||
// `fbody` is owned by `flr`.
|
// `fbody` is owned by `flr`.
|
||||||
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
||||||
// They are sorted in ascending order on this function's return.
|
// They are sorted in ascending order on this function's return.
|
||||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||||
const NodeDef& node_def,
|
const NameAttrList& function,
|
||||||
const FunctionBody** fbody,
|
const FunctionBody** fbody,
|
||||||
std::vector<int>* constant_arg_indices,
|
std::vector<int>* constant_arg_indices,
|
||||||
std::vector<int>* resource_arg_indices);
|
std::vector<int>* resource_arg_indices);
|
||||||
|
@ -38,10 +38,12 @@ Status ForceXlaConstantsOnHostPass::Run(
|
|||||||
std::vector<int> constant_arg_indices;
|
std::vector<int> constant_arg_indices;
|
||||||
std::vector<int> resource_arg_indices;
|
std::vector<int> resource_arg_indices;
|
||||||
|
|
||||||
|
NameAttrList function;
|
||||||
|
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node->def(), &function));
|
||||||
|
|
||||||
// Force all constants to be on the host memory.
|
// Force all constants to be on the host memory.
|
||||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||||
flr, node->def(), &fbody, &constant_arg_indices,
|
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||||
&resource_arg_indices));
|
|
||||||
VLOG(3) << "Found constant arg indices: "
|
VLOG(3) << "Found constant arg indices: "
|
||||||
<< absl::StrJoin(constant_arg_indices, ", ");
|
<< absl::StrJoin(constant_arg_indices, ", ");
|
||||||
|
|
||||||
|
@ -122,11 +122,13 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get function body, constant args, and resource args.
|
// Get function body, constant args, and resource args.
|
||||||
|
NameAttrList function;
|
||||||
|
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||||
const FunctionBody* fbody = nullptr;
|
const FunctionBody* fbody = nullptr;
|
||||||
std::vector<int> constant_arg_indices;
|
std::vector<int> constant_arg_indices;
|
||||||
std::vector<int> resource_arg_indices;
|
std::vector<int> resource_arg_indices;
|
||||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||||
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
|
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||||
|
|
||||||
// Set input and output memory types.
|
// Set input and output memory types.
|
||||||
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
|
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
|
||||||
@ -176,8 +178,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create the kernel.
|
// Create the kernel.
|
||||||
NameAttrList function;
|
|
||||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
|
||||||
Device* dev = flr->device();
|
Device* dev = flr->device();
|
||||||
Status s;
|
Status s;
|
||||||
auto props = std::make_shared<NodeProperties>(
|
auto props = std::make_shared<NodeProperties>(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user