diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 99f13180722..2ca9b720ee1 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -159,6 +159,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", ], ) diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index ca3c1a66672..866f87688c8 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -450,8 +450,11 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } // Instantiate all the missing attributes with their default values. - Status attr_status = - AddDefaultAttrsToGraphDef(&new_item->graph, *OpRegistry::Global(), 0); + Status attr_status = AddDefaultAttrsToGraphDef( + &new_item->graph, + FunctionLibraryDefinition(OpRegistry::Global(), + new_item->graph.library()), + 0); if (!attr_status.ok()) { LOG(ERROR) << "Failed to instantiate default attribute values: " << attr_status.error_message(); diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 4272179d3cb..09d9aa4ef19 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/cc/gradients/grad_testutil.h" #include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" @@ -253,6 +254,31 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest_FileNotAccessible) { ASSERT_TRUE(item == nullptr); } +TEST_F(GrapplerItemBuilderTest, GraphWithFunctions) { + MetaGraphDef meta_graph; + // y = XTimesTwo(x) + constexpr char device[] = "/cpu:0"; + *meta_graph.mutable_graph_def() = test::function::GDef( + {test::function::NDef("x", "Const", {}, {{"dtype", DT_FLOAT}}, device), + test::function::NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, + device)}, + // FunctionLib + { + test::function::XTimesTwo(), + }); + + CollectionDef train_op; + train_op.mutable_node_list()->add_value("y"); + (*meta_graph.mutable_collection_def())["train_op"] = train_op; + + ItemConfig cfg; + cfg.inline_functions = false; + + std::unique_ptr item = + GrapplerItemFromMetaGraphDef("0", meta_graph, cfg); + ASSERT_TRUE(item != nullptr); +} + } // namespace } // namespace grappler } // namespace tensorflow