Fix support for functions to grappler items.

PiperOrigin-RevId: 179429486
This commit is contained in:
A. Unique TensorFlower 2017-12-18 09:03:42 -08:00 committed by TensorFlower Gardener
parent 13a8558846
commit 566df46de2
3 changed files with 32 additions and 2 deletions

View File

@ -159,6 +159,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)

View File

@ -450,8 +450,11 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}
// Instantiate all the missing attributes with their default values.
Status attr_status =
AddDefaultAttrsToGraphDef(&new_item->graph, *OpRegistry::Global(), 0);
Status attr_status = AddDefaultAttrsToGraphDef(
&new_item->graph,
FunctionLibraryDefinition(OpRegistry::Global(),
new_item->graph.library()),
0);
if (!attr_status.ok()) {
LOG(ERROR) << "Failed to instantiate default attribute values: "
<< attr_status.error_message();

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@ -253,6 +254,31 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest_FileNotAccessible) {
ASSERT_TRUE(item == nullptr);
}
TEST_F(GrapplerItemBuilderTest, GraphWithFunctions) {
MetaGraphDef meta_graph;
// y = XTimesTwo(x)
constexpr char device[] = "/cpu:0";
*meta_graph.mutable_graph_def() = test::function::GDef(
{test::function::NDef("x", "Const", {}, {{"dtype", DT_FLOAT}}, device),
test::function::NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}},
device)},
// FunctionLib
{
test::function::XTimesTwo(),
});
CollectionDef train_op;
train_op.mutable_node_list()->add_value("y");
(*meta_graph.mutable_collection_def())["train_op"] = train_op;
ItemConfig cfg;
cfg.inline_functions = false;
std::unique_ptr<GrapplerItem> item =
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
ASSERT_TRUE(item != nullptr);
}
} // namespace
} // namespace grappler
} // namespace tensorflow