diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc index 3492949dafa..a8217bd3d11 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc @@ -38,8 +38,20 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( const FunctionLibraryDefinition& flib_def, GraphDef* gdef, std::vector* send_keys, std::vector* recv_keys) { const string& target = options.target; + const string& func_name = sig.name(); + const FunctionDef* func_def = flib_def.Find(sig.name()); + if (func_def == nullptr) { + return errors::InvalidArgument("Function ", func_name, + " not found in flib_def."); + } - Graph g(flib_def); + // Build a smaller flib_def containing only the functions used by the given + // function, plus that function itself. + FunctionLibraryDefinition pruned_flib_def = + flib_def.ReachableDefinitions(*func_def); + TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def)); + + Graph g(pruned_flib_def); std::vector input_nodes; input_nodes.reserve(sig.input_arg_size()); @@ -82,8 +94,8 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( } NodeDef function_node_def; - function_node_def.set_name(sig.name()); - function_node_def.set_op(sig.name()); + function_node_def.set_name(func_name); + function_node_def.set_op(func_name); i = 0; function_node_def.set_device(target); for (const auto& p : attrs) { @@ -112,7 +124,7 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( auto output_node_builder = NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send") - .Input(sig.name(), i, dtypes[0]) + .Input(func_name, i, dtypes[0]) .Attr("tensor_name", out.name()) .Attr("send_device", target) .Attr("recv_device", target) @@ -144,9 +156,9 @@ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( // inlined graph. inline_options.uniquify_frame_names = false; std::unique_ptr function_body; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*flib_def.Find(sig.name()), attrs, - &flib_def, &function_body)); - TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, &g, function_node, + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def, + &function_body)); + TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node, function_body.get(), inline_options)); g.ToGraphDef(gdef);