Prune the FunctionLibraryDefinition in ClusterFLR::ConstructFunctionGraph().

This change avoids unnecessarily copying FunctionDef protos that are unused when constructing the graph for a remote function.

PiperOrigin-RevId: 274359615
This commit is contained in:
Derek Murray 2019-10-12 11:06:10 -07:00 committed by TensorFlower Gardener
parent 6204b3b27f
commit deadb15945

View File

@ -38,8 +38,20 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
const FunctionLibraryDefinition& flib_def, GraphDef* gdef,
std::vector<string>* send_keys, std::vector<string>* recv_keys) {
const string& target = options.target;
const string& func_name = sig.name();
const FunctionDef* func_def = flib_def.Find(sig.name());
if (func_def == nullptr) {
return errors::InvalidArgument("Function ", func_name,
" not found in flib_def.");
}
Graph g(flib_def);
// Build a smaller flib_def containing only the functions used by the given
// function, plus that function itself.
FunctionLibraryDefinition pruned_flib_def =
flib_def.ReachableDefinitions(*func_def);
TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def));
Graph g(pruned_flib_def);
std::vector<Node*> input_nodes;
input_nodes.reserve(sig.input_arg_size());
@ -82,8 +94,8 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
}
NodeDef function_node_def;
function_node_def.set_name(sig.name());
function_node_def.set_op(sig.name());
function_node_def.set_name(func_name);
function_node_def.set_op(func_name);
i = 0;
function_node_def.set_device(target);
for (const auto& p : attrs) {
@ -112,7 +124,7 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
auto output_node_builder =
NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
.Input(sig.name(), i, dtypes[0])
.Input(func_name, i, dtypes[0])
.Attr("tensor_name", out.name())
.Attr("send_device", target)
.Attr("recv_device", target)
@ -144,9 +156,9 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
// inlined graph.
inline_options.uniquify_frame_names = false;
std::unique_ptr<FunctionBody> function_body;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*flib_def.Find(sig.name()), attrs,
&flib_def, &function_body));
TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, &g, function_node,
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def,
&function_body));
TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node,
function_body.get(), inline_options));
g.ToGraphDef(gdef);