Fix support for functions to grappler items.
PiperOrigin-RevId: 179429486
This commit is contained in:
parent
13a8558846
commit
566df46de2
@ -159,6 +159,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -450,8 +450,11 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Instantiate all the missing attributes with their default values.
|
// Instantiate all the missing attributes with their default values.
|
||||||
Status attr_status =
|
Status attr_status = AddDefaultAttrsToGraphDef(
|
||||||
AddDefaultAttrsToGraphDef(&new_item->graph, *OpRegistry::Global(), 0);
|
&new_item->graph,
|
||||||
|
FunctionLibraryDefinition(OpRegistry::Global(),
|
||||||
|
new_item->graph.library()),
|
||||||
|
0);
|
||||||
if (!attr_status.ok()) {
|
if (!attr_status.ok()) {
|
||||||
LOG(ERROR) << "Failed to instantiate default attribute values: "
|
LOG(ERROR) << "Failed to instantiate default attribute values: "
|
||||||
<< attr_status.error_message();
|
<< attr_status.error_message();
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/gradients/grad_testutil.h"
|
#include "tensorflow/cc/gradients/grad_testutil.h"
|
||||||
#include "tensorflow/cc/ops/functional_ops.h"
|
#include "tensorflow/cc/ops/functional_ops.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||||
@ -253,6 +254,31 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest_FileNotAccessible) {
|
|||||||
ASSERT_TRUE(item == nullptr);
|
ASSERT_TRUE(item == nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GrapplerItemBuilderTest, GraphWithFunctions) {
|
||||||
|
MetaGraphDef meta_graph;
|
||||||
|
// y = XTimesTwo(x)
|
||||||
|
constexpr char device[] = "/cpu:0";
|
||||||
|
*meta_graph.mutable_graph_def() = test::function::GDef(
|
||||||
|
{test::function::NDef("x", "Const", {}, {{"dtype", DT_FLOAT}}, device),
|
||||||
|
test::function::NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}},
|
||||||
|
device)},
|
||||||
|
// FunctionLib
|
||||||
|
{
|
||||||
|
test::function::XTimesTwo(),
|
||||||
|
});
|
||||||
|
|
||||||
|
CollectionDef train_op;
|
||||||
|
train_op.mutable_node_list()->add_value("y");
|
||||||
|
(*meta_graph.mutable_collection_def())["train_op"] = train_op;
|
||||||
|
|
||||||
|
ItemConfig cfg;
|
||||||
|
cfg.inline_functions = false;
|
||||||
|
|
||||||
|
std::unique_ptr<GrapplerItem> item =
|
||||||
|
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
|
||||||
|
ASSERT_TRUE(item != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
x
Reference in New Issue
Block a user