Fix support for functions to grappler items.
PiperOrigin-RevId: 179429486
This commit is contained in:
parent
13a8558846
commit
566df46de2
@ -159,6 +159,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
],
|
||||
)
|
||||
|
@ -450,8 +450,11 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
}
|
||||
|
||||
// Instantiate all the missing attributes with their default values.
|
||||
Status attr_status =
|
||||
AddDefaultAttrsToGraphDef(&new_item->graph, *OpRegistry::Global(), 0);
|
||||
Status attr_status = AddDefaultAttrsToGraphDef(
|
||||
&new_item->graph,
|
||||
FunctionLibraryDefinition(OpRegistry::Global(),
|
||||
new_item->graph.library()),
|
||||
0);
|
||||
if (!attr_status.ok()) {
|
||||
LOG(ERROR) << "Failed to instantiate default attribute values: "
|
||||
<< attr_status.error_message();
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/gradients/grad_testutil.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
@ -253,6 +254,31 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest_FileNotAccessible) {
|
||||
ASSERT_TRUE(item == nullptr);
|
||||
}
|
||||
|
||||
TEST_F(GrapplerItemBuilderTest, GraphWithFunctions) {
|
||||
MetaGraphDef meta_graph;
|
||||
// y = XTimesTwo(x)
|
||||
constexpr char device[] = "/cpu:0";
|
||||
*meta_graph.mutable_graph_def() = test::function::GDef(
|
||||
{test::function::NDef("x", "Const", {}, {{"dtype", DT_FLOAT}}, device),
|
||||
test::function::NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}},
|
||||
device)},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
|
||||
CollectionDef train_op;
|
||||
train_op.mutable_node_list()->add_value("y");
|
||||
(*meta_graph.mutable_collection_def())["train_op"] = train_op;
|
||||
|
||||
ItemConfig cfg;
|
||||
cfg.inline_functions = false;
|
||||
|
||||
std::unique_ptr<GrapplerItem> item =
|
||||
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
|
||||
ASSERT_TRUE(item != nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user