Also pass the ConfigProto through distributed function calls both in the standard graph registration mode and in the new eager master setup. The PFLR stores a std::optional<ConfigProto> instead of a pointer, because it may be created with a pointer that would dangle after its creation. At the same time, we need to know if a ConfigProto was available at creation time, which is why it's a std::optional. In contrast, the FLR gets a pointer directly because it is given a valid pointer that will outlast it in all cases. PiperOrigin-RevId: 272763578
504 lines
19 KiB
C++
504 lines
19 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
|
|
|
#include "absl/strings/match.h"
|
|
#include "absl/strings/str_cat.h"
|
|
#include "absl/strings/string_view.h"
|
|
#include "tensorflow/cc/framework/ops.h"
|
|
#include "tensorflow/cc/ops/data_flow_ops.h"
|
|
#include "tensorflow/cc/ops/function_ops.h"
|
|
#include "tensorflow/cc/ops/functional_ops.h"
|
|
#include "tensorflow/cc/ops/list_ops.h"
|
|
#include "tensorflow/cc/ops/standard_ops.h"
|
|
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
|
#include "tensorflow/core/common_runtime/function.h"
|
|
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
|
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
|
#include "tensorflow/core/framework/function.h"
|
|
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
|
#include "tensorflow/core/framework/node_def.pb.h"
|
|
#include "tensorflow/core/graph/graph.h"
|
|
#include "tensorflow/core/lib/core/status.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
#include "tensorflow/core/public/version.h"
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
// Asserts (non-fatally) that `status` is an error whose message contains
// `str`; on mismatch the actual message is printed alongside the expectation.
void ExpectErrorContains(const Status& status, absl::string_view str) {
  EXPECT_NE(Status::OK(), status);
  const bool message_matches = absl::StrContains(status.error_message(), str);
  EXPECT_TRUE(message_matches)
      << "expected error: " << status.error_message() << " to contain: " << str;
}
|
|
|
|
// A well-formed config -- named and unnamed feeds and fetches, each with a
// node name and a non-negative output index -- passes validation.
TEST(ValidateConfig, Good) {
  tf2xla::Config cfg;

  tf2xla::Feed* named_feed = cfg.add_feed();
  auto* named_feed_id = named_feed->mutable_id();
  named_feed_id->set_node_name("foo");
  named_feed_id->set_output_index(123);
  named_feed->set_name("foo_debug");

  auto* unnamed_feed_id = cfg.add_feed()->mutable_id();
  unnamed_feed_id->set_node_name("bar");
  unnamed_feed_id->set_output_index(0);

  tf2xla::Fetch* named_fetch = cfg.add_fetch();
  auto* named_fetch_id = named_fetch->mutable_id();
  named_fetch_id->set_node_name("baz");
  named_fetch_id->set_output_index(456);
  named_fetch->set_name("baz_debug");

  auto* unnamed_fetch_id = cfg.add_fetch()->mutable_id();
  unnamed_fetch_id->set_node_name("banana");
  unnamed_fetch_id->set_output_index(0);

  TF_EXPECT_OK(ValidateConfig(cfg));
}
|
|
|
|
// An entirely empty config is rejected for having no fetches.
TEST(ValidateConfig, BadEmpty) {
  const tf2xla::Config empty_cfg;
  ExpectErrorContains(ValidateConfig(empty_cfg), "fetches must be specified");
}
|
|
|
|
// A valid feed on its own is not enough: at least one fetch is required.
TEST(ValidateConfig, BadNoFetch) {
  tf2xla::Config cfg;
  cfg.add_feed()->mutable_id()->set_node_name("foo");
  ExpectErrorContains(ValidateConfig(cfg), "fetches must be specified");
}
|
|
|
|
// A feed added with no fields set is expected to be reported for its empty
// node_name.
TEST(ValidateConfig, BadFeedNodeName) {
  tf2xla::Config cfg;
  cfg.add_feed();  // Deliberately left blank.
  ExpectErrorContains(ValidateConfig(cfg), "node_name must be non-empty");
}
|
|
|
|
// A feed whose output_index is negative is rejected.
TEST(ValidateConfig, BadFeedOutputIndex) {
  tf2xla::Config cfg;
  auto* feed_id = cfg.add_feed()->mutable_id();
  feed_id->set_node_name("foo");
  feed_id->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(cfg), "output_index must be positive");
}
|
|
|
|
// With a valid feed in place, a fetch left without a node_name is rejected.
TEST(ValidateConfig, BadFetchNodeName) {
  tf2xla::Config cfg;
  cfg.add_feed()->mutable_id()->set_node_name("foo");
  cfg.add_fetch();  // Deliberately left blank.
  ExpectErrorContains(ValidateConfig(cfg), "node_name must be non-empty");
}
|
|
|
|
// With a valid feed in place, a fetch whose output_index is negative is
// rejected.
TEST(ValidateConfig, BadFetchOutputIndex) {
  tf2xla::Config cfg;
  cfg.add_feed()->mutable_id()->set_node_name("foo");
  auto* fetch_id = cfg.add_fetch()->mutable_id();
  fetch_id->set_node_name("bar");
  fetch_id->set_output_index(-1);
  ExpectErrorContains(ValidateConfig(cfg), "output_index must be positive");
}
|
|
|
|
// Two feeds on different nodes that share the name "dup" are rejected.
TEST(ValidateConfig, DuplicateFeedName) {
  tf2xla::Config cfg;
  for (const char* node_name : {"foo", "bar"}) {
    tf2xla::Feed* feed = cfg.add_feed();
    feed->mutable_id()->set_node_name(node_name);
    feed->set_name("dup");
  }
  ExpectErrorContains(ValidateConfig(cfg), "duplicate feed name");
}
|
|
|
|
// Two fetches on different nodes that share the name "dup" are rejected.
TEST(ValidateConfig, DuplicateFetchName) {
  tf2xla::Config cfg;
  cfg.add_feed()->mutable_id()->set_node_name("foo");
  for (const char* node_name : {"bar", "baz"}) {
    tf2xla::Fetch* fetch = cfg.add_fetch();
    fetch->mutable_id()->set_node_name(node_name);
    fetch->set_name("dup");
  }
  ExpectErrorContains(ValidateConfig(cfg), "duplicate fetch name");
}
|
|
|
|
// Feed names "conflict" and "conflict_data" are expected to be reported as
// conflicting with each other.
TEST(ValidateConfig, ConflictingFeedName) {
  tf2xla::Config cfg;
  auto add_named_feed = [&cfg](const char* node_name, const char* name) {
    tf2xla::Feed* feed = cfg.add_feed();
    feed->mutable_id()->set_node_name(node_name);
    feed->set_name(name);
  };
  add_named_feed("foo", "conflict");
  add_named_feed("bar", "conflict_data");
  ExpectErrorContains(ValidateConfig(cfg), "conflicting feed name");
}
|
|
|
|
// Fetch names "conflict" and "conflict_data" are expected to be reported as
// conflicting with each other.
TEST(ValidateConfig, ConflictingFetchName) {
  tf2xla::Config cfg;
  cfg.add_feed()->mutable_id()->set_node_name("foo");
  auto add_named_fetch = [&cfg](const char* node_name, const char* name) {
    tf2xla::Fetch* fetch = cfg.add_fetch();
    fetch->mutable_id()->set_node_name(node_name);
    fetch->set_name(name);
  };
  add_named_fetch("bar", "conflict");
  add_named_fetch("baz", "conflict_data");
  ExpectErrorContains(ValidateConfig(cfg), "conflicting fetch name");
}
|
|
|
|
// Builds a Config containing one fetch per entry of `fetches` and no feeds.
// Each fetch targets the given node name and is named "fetch_<node name>".
//
// Takes the list by const reference: it is only read, so the former by-value
// parameter made an unnecessary copy on every call.
static tf2xla::Config FetchesConfig(const std::vector<string>& fetches) {
  tf2xla::Config config;
  for (const auto& fetch_node_name : fetches) {
    auto* fetch = config.add_fetch();
    fetch->set_name(absl::StrCat("fetch_", fetch_node_name));
    fetch->mutable_id()->set_node_name(fetch_node_name);
  }
  return config;
}
|
|
|
|
// Exercises PruneGraphDefInto on a small graph:
//  * a missing fetch node, or a missing (data or control) input, is reported
//    by name;
//  * fetching "a" keeps exactly a's transitive inputs;
//  * fetching "a" and "e" keeps the whole graph.
TEST(PruneGraphDefInto, Basic) {
  GraphDef def;
  auto* n = def.add_node();
  n->set_name("a");
  n->add_input("b:0");
  n->add_input("^c");

  // Output parameter; passed by address (`&copy`) to every prune call.
  GraphDef copy;
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"missing"}), def, &copy),
                      "node missing needed");
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy),
                      "node b needed");

  n = def.add_node();
  n->set_name("b");
  // b now exists, so the next missing dependency of a is its control input c.
  ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy),
                      "node c needed");
  n->add_input("d:1");

  n = def.add_node();
  n->set_name("c");
  n->add_input("d:1");

  n = def.add_node();
  n->set_name("d");

  // Graph is full, no pruning done.
  // Graph right now has diamond from d:
  //   d --> b --> a
  //   d --> c --> a
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy));
  EXPECT_EQ(def.DebugString(), copy.DebugString());
  GraphDef pruned_a = copy;

  // Add some unrelated fields that use b and c, but are not needed for a.
  n = def.add_node();
  n->set_name("e");
  n->add_input("^d");
  n->add_input("b:2");
  copy.Clear();
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy));
  EXPECT_EQ(pruned_a.DebugString(), copy.DebugString());

  // Fetch "a" and "e" to get the original graph.
  copy.Clear();
  TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a", "e"}), def, &copy));
  EXPECT_EQ(def.DebugString(), copy.DebugString());
}
|
|
|
|
// Checks that SetNodeShardingFromNeighbors derives a node's sharding from the
// assigned TPU_REPLICATED_CORE devices of its inputs (out_edges=false) or
// consumers (out_edges=true).
TEST(SetNodeShardingFromNeighbors, Basic) {
  // Builds a graph that adds two Tensors.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto c = ops::Add(scope.WithOpName("C"), a, b);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  Node* a_node = nullptr;
  Node* b_node = nullptr;
  Node* c_node = nullptr;
  for (Node* n : graph->nodes()) {
    if (n->name() == "A") a_node = n;
    if (n->name() == "B") b_node = n;
    if (n->name() == "C") c_node = n;
  }
  // Fail the test cleanly instead of dereferencing a null pointer below if
  // any of the nodes was not found by name.
  ASSERT_NE(a_node, nullptr);
  ASSERT_NE(b_node, nullptr);
  ASSERT_NE(c_node, nullptr);

  const int num_cores_per_replica = 4;

  // A device name that is not a TPU_REPLICATED_CORE is expected to make the
  // call fail.
  a_node->set_assigned_device_name("foo");
  EXPECT_FALSE(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false).ok());

  // Test where one input to c_node has a device.
  a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2");
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
  auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0));

  // Test where two inputs to c_node have a device (cores 2 and 1); the test
  // expects core 1 to be chosen.
  b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1");
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
  parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));

  // Test setting based on out edges: a_node's consumer is c_node.
  TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true));
  parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica);
  TF_ASSERT_OK(parse_status.status());
  ASSERT_TRUE(parse_status.ValueOrDie().has_value());
  EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));
}
|
|
|
|
// Op registered for this test file only. It has no inputs and a single output
// "y" of type T, and it is the sole node of the "TestFunc" function defined in
// CachedFunctionHandles.Basic.
REGISTER_OP("One")
    .Output("y: T")
    .Attr("T: {float, double, int32, int64}")
    .Doc(R"doc(
Returns a tensor with a single element (1) of type T.

y: A scalar in type T.

)doc");
|
|
|
|
// Tests that CachedFunctionHandles class works.
|
|
TEST(CachedFunctionHandles, Basic) {
|
|
FunctionDef func = FunctionDefHelper::Define(
|
|
// Name
|
|
"TestFunc",
|
|
// Args
|
|
{},
|
|
// Return values
|
|
{"y:T"},
|
|
// Attr def
|
|
{"T:{float, double, int32, int64}"},
|
|
// Nodes
|
|
{
|
|
{{"y"}, "One", {}, {{"T", "$T"}}},
|
|
});
|
|
FunctionDefLibrary proto;
|
|
*proto.add_function() = func;
|
|
FunctionLibraryDefinition fld(OpRegistry::Global(), proto);
|
|
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
|
new ProcessFunctionLibraryRuntime(
|
|
/*device_mgr=*/nullptr, Env::Default(), /*config=*/nullptr,
|
|
TF_GRAPH_DEF_VERSION, &fld, OptimizerOptions()));
|
|
FunctionLibraryRuntime* flr =
|
|
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
|
|
|
CachedFunctionHandles cached_function_handles(flr);
|
|
|
|
// Tests that GetOrInstantiate() works.
|
|
FunctionLibraryRuntime::Handle first_handle;
|
|
AttrValue attr;
|
|
attr.set_type(DT_FLOAT);
|
|
AttrValueMap attrs;
|
|
attrs["T"] = attr;
|
|
TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
|
|
"TestFunc", AttrSlice(&attrs), &first_handle));
|
|
|
|
// Tests that we can get FunctionBody.
|
|
const FunctionBody* body = flr->GetFunctionBody(first_handle);
|
|
EXPECT_NE(body, nullptr);
|
|
|
|
// Tests that GetOrInstantiate() returns cached handle when called with same
|
|
// function name and attributes.
|
|
FunctionLibraryRuntime::Handle second_handle;
|
|
TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
|
|
"TestFunc", AttrSlice(&attrs), &second_handle));
|
|
EXPECT_EQ(first_handle, second_handle);
|
|
|
|
// Tests that GetOrInstantiate() returns new handle when called with same
|
|
// function name but different attributes.
|
|
attr.set_type(DT_INT32);
|
|
attrs["T"] = attr;
|
|
FunctionLibraryRuntime::Handle third_handle;
|
|
TF_ASSERT_OK(cached_function_handles.GetOrInstantiate(
|
|
"TestFunc", AttrSlice(&attrs), &third_handle));
|
|
EXPECT_NE(first_handle, third_handle);
|
|
|
|
// Tests that ReleaseAllHandles() works.
|
|
TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles());
|
|
}
|
|
|
|
// Runs PropagateConstIntoFunctionalNodes over a While whose cond/body
// functions take a DT_RESOURCE argument; the test only requires the pass to
// succeed.
TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
  {
    // One (pred: bool, input: resource) -> pred graph serves as both the loop
    // condition ("cond") and the loop body ("body").
    Scope fn_scope = Scope::NewRootScope().ExitOnError();
    auto pred_arg = ops::_Arg(fn_scope.WithOpName("pred"), DT_BOOL, 0);
    auto resource_arg = ops::_Arg(fn_scope.WithOpName("input"), DT_RESOURCE, 1);
    auto pred_ret = ops::_Retval(fn_scope.WithOpName("ret"), pred_arg, 0);
    Graph fn_graph(OpRegistry::Global());
    TF_ASSERT_OK(fn_scope.ToGraph(&fn_graph));
    for (const char* fn_name : {"cond", "body"}) {
      FunctionDef fdef;
      TF_ASSERT_OK(GraphToFunctionDef(fn_graph, fn_name, &fdef));
      TF_ASSERT_OK(fld.AddFunctionDef(fdef));
    }
  }

  // Outer graph: While(pred=false, input=0) driven by the functions above.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto init_pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({}));
  auto init_input = ops::Const(scope.WithOpName("input"), 0, TensorShape({}));
  NameAttrList cond_fn;
  cond_fn.set_name("cond");
  NameAttrList body_fn;
  body_fn.set_name("body");
  auto while_op = ops::While(scope.WithOpName("while"),
                             std::initializer_list<Input>{init_pred, init_input},
                             cond_fn, body_fn);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));
}
|
|
|
|
// When a Const is copied into a While body that already contains a node with
// the same name, the copy must receive a fresh, unique name.
TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
  {
    // One graph serves as both "cond" and "body". It contains a NoOp named
    // "duplicate_name" that the outer Const below will collide with.
    Scope fn_scope = Scope::NewRootScope().ExitOnError();
    auto first_arg = ops::_Arg(fn_scope.WithOpName("arg0"), DT_BOOL, 0);
    auto second_arg = ops::_Arg(fn_scope.WithOpName("arg1"), DT_BOOL, 1);
    auto clashing_noop = ops::NoOp(fn_scope.WithOpName("duplicate_name"));
    auto first_ret = ops::_Retval(fn_scope.WithOpName("ret"), first_arg, 0);
    Graph fn_graph(OpRegistry::Global());
    TF_ASSERT_OK(fn_scope.ToGraph(&fn_graph));
    for (const char* fn_name : {"cond", "body"}) {
      FunctionDef fdef;
      TF_ASSERT_OK(GraphToFunctionDef(fn_graph, fn_name, &fdef));
      TF_ASSERT_OK(fld.AddFunctionDef(fdef));
    }
  }

  Scope scope = Scope::NewRootScope().ExitOnError();
  auto pred =
      ops::Const(scope.WithOpName("duplicate_name"), false, TensorShape({}));
  auto input = ops::Const(scope.WithOpName("input"), false, TensorShape({}));
  NameAttrList cond_fn;
  cond_fn.set_name("cond");
  NameAttrList body_fn;
  body_fn.set_name("body");
  auto while_op =
      ops::While(scope.WithOpName("while"),
                 std::initializer_list<Input>{pred, input}, cond_fn, body_fn);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld));

  // Check that in rewritten body function, the NoOp node still has name
  // "duplicate_name", and the copied Const node has name "duplicate_name/_0".
  auto node_name_index = graph.BuildNodeNameIndex();
  auto while_it = node_name_index.find("while");
  ASSERT_NE(while_it, node_name_index.end());
  Node* while_node = while_it->second;
  ASSERT_NE(while_node, nullptr);
  TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn));
  const FunctionDef* rewritten_body = fld.Find(body_fn.name());
  ASSERT_NE(rewritten_body, nullptr);

  std::unordered_map<string, NodeDef> defs_by_name;
  for (const NodeDef& ndef : rewritten_body->node_def()) {
    defs_by_name[ndef.name()] = ndef;
  }

  auto noop_it = defs_by_name.find("duplicate_name");
  ASSERT_NE(noop_it, defs_by_name.end());
  EXPECT_EQ(noop_it->second.op(), "NoOp");

  auto const_it = defs_by_name.find("duplicate_name/_0");
  ASSERT_NE(const_it, defs_by_name.end());
  EXPECT_EQ(const_it->second.op(), "Const");
}
|
|
|
|
// A forward While pushes a Const onto a TensorList and a backward While pops
// it. After RewriteTensorListWithConstElement, the popped value in the
// rewritten backward body should be fed directly by a Const node.
TEST(PropagateConstIntoFunctionalNodes, RewriteTensorListWithConstMember) {
  FunctionLibraryDefinition fld(OpRegistry::Global(), {});
  {
    // Cond graph
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
    auto result =
        ops::Const(scope.WithOpName("result"), false, TensorShape({}));
    auto ret = ops::_Retval(scope.WithOpName("ret"), result, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    FunctionDef fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(fdef));
  }
  {
    // Forward body graph
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
    auto element = ops::Const(scope.WithOpName("element"), 0, TensorShape({}));
    auto push =
        ops::TensorListPushBack(scope.WithOpName("push"), input, element);
    auto ret = ops::_Retval(scope.WithOpName("ret"), push.output_handle, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    FunctionDef fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "fwd_body", &fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(fdef));
  }
  {
    // Backward body graph
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
    auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({}));
    auto pop =
        ops::TensorListPopBack(scope.WithOpName("pop"), input, shape, DT_INT32);
    auto identity = ops::Identity(scope.WithOpName("identity"), pop.tensor);
    auto ret = ops::_Retval(scope.WithOpName("ret"), pop.output_handle, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    FunctionDef fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "bwd_body", &fdef));
    TF_ASSERT_OK(fld.AddFunctionDef(fdef));
  }
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({}));
  auto max_num_elements =
      ops::Const(scope.WithOpName("max_num_elements"), 10, TensorShape({}));
  auto tl = ops::EmptyTensorList(scope.WithOpName("tl"), shape,
                                 max_num_elements, DT_INT32);
  NameAttrList cond_fn, fwd_body_fn, bwd_body_fn;
  cond_fn.set_name("cond");
  fwd_body_fn.set_name("fwd_body");
  bwd_body_fn.set_name("bwd_body");
  auto fwd_while_op =
      ops::While(scope.WithOpName("fwd_while"),
                 std::initializer_list<Input>{tl}, cond_fn, fwd_body_fn);
  auto bwd_while_op =
      ops::While(scope.WithOpName("bwd_while"),
                 std::initializer_list<Input>{fwd_while_op.output[0]}, cond_fn,
                 bwd_body_fn);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  TF_EXPECT_OK(RewriteTensorListWithConstElement(&graph, &fld));

  // Check that in rewritten backward While body function, the Identity node now
  // has Const node as input.
  const FunctionDef* bwd_body = fld.Find("bwd_body_tl_rewrite_0");
  ASSERT_NE(bwd_body, nullptr);
  std::unique_ptr<FunctionBody> bwd_fbody;
  // ASSERT (not CHECK) so a failure fails this test instead of aborting the
  // whole test binary.
  TF_ASSERT_OK(
      FunctionDefToBodyHelper(*bwd_body, AttrSlice(), &fld, &bwd_fbody));
  auto node_name_index = bwd_fbody->graph->BuildNodeNameIndex();
  // Look the node up with find(): at() would throw/abort on a missing key,
  // which made the old nullptr assertion unreachable.
  auto identity_it = node_name_index.find("identity");
  ASSERT_NE(identity_it, node_name_index.end());
  const Node* identity = identity_it->second;
  ASSERT_NE(identity, nullptr);
  const Node* input = nullptr;
  TF_ASSERT_OK(identity->input_node(0, &input));
  ASSERT_NE(input, nullptr);
  EXPECT_EQ(input->type_string(), "Const");
}
|
|
|
|
} // namespace
|
|
} // namespace tensorflow
|