Prune the FunctionLibraryDefinition in ClusterFLR::ConstructFunctionGraph().
This change avoids unnecessarily copying FunctionDef protos that are unused when constructing the graph for a remote function. PiperOrigin-RevId: 274359615
This commit is contained in:
parent
6204b3b27f
commit
deadb15945
@ -38,8 +38,20 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
|
||||
const FunctionLibraryDefinition& flib_def, GraphDef* gdef,
|
||||
std::vector<string>* send_keys, std::vector<string>* recv_keys) {
|
||||
const string& target = options.target;
|
||||
const string& func_name = sig.name();
|
||||
const FunctionDef* func_def = flib_def.Find(sig.name());
|
||||
if (func_def == nullptr) {
|
||||
return errors::InvalidArgument("Function ", func_name,
|
||||
" not found in flib_def.");
|
||||
}
|
||||
|
||||
Graph g(flib_def);
|
||||
// Build a smaller flib_def containing only the functions used by the given
|
||||
// function, plus that function itself.
|
||||
FunctionLibraryDefinition pruned_flib_def =
|
||||
flib_def.ReachableDefinitions(*func_def);
|
||||
TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def));
|
||||
|
||||
Graph g(pruned_flib_def);
|
||||
|
||||
std::vector<Node*> input_nodes;
|
||||
input_nodes.reserve(sig.input_arg_size());
|
||||
@ -82,8 +94,8 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
|
||||
}
|
||||
|
||||
NodeDef function_node_def;
|
||||
function_node_def.set_name(sig.name());
|
||||
function_node_def.set_op(sig.name());
|
||||
function_node_def.set_name(func_name);
|
||||
function_node_def.set_op(func_name);
|
||||
i = 0;
|
||||
function_node_def.set_device(target);
|
||||
for (const auto& p : attrs) {
|
||||
@ -112,7 +124,7 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
|
||||
|
||||
auto output_node_builder =
|
||||
NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
|
||||
.Input(sig.name(), i, dtypes[0])
|
||||
.Input(func_name, i, dtypes[0])
|
||||
.Attr("tensor_name", out.name())
|
||||
.Attr("send_device", target)
|
||||
.Attr("recv_device", target)
|
||||
@ -144,9 +156,9 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
|
||||
// inlined graph.
|
||||
inline_options.uniquify_frame_names = false;
|
||||
std::unique_ptr<FunctionBody> function_body;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*flib_def.Find(sig.name()), attrs,
|
||||
&flib_def, &function_body));
|
||||
TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, &g, function_node,
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def,
|
||||
&function_body));
|
||||
TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node,
|
||||
function_body.get(), inline_options));
|
||||
|
||||
g.ToGraphDef(gdef);
|
||||
|
Loading…
x
Reference in New Issue
Block a user