[Grappler] Don't remove constant feed nodes in LoopOptimizer RemoveDeadBranches.
PiperOrigin-RevId: 233633382
This commit is contained in:
parent
e7771a7894
commit
7141c42808
@ -581,8 +581,19 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO(lyandy): Consolidate with ConstantFolding implementation.
|
||||
bool IsReallyConstant(const NodeDef& node,
|
||||
const absl::flat_hash_set<string>& feed_nodes) {
|
||||
if (!IsConstant(node)) {
|
||||
return false;
|
||||
}
|
||||
// If the node is fed it's not constant anymore.
|
||||
return feed_nodes.find(node.name()) == feed_nodes.end();
|
||||
}
|
||||
|
||||
Status CheckForDeadFanout(const MutableGraphView& view,
|
||||
const NodeDef& switch_node, const NodeMap& node_map,
|
||||
const absl::flat_hash_set<string>& feed_nodes,
|
||||
DeviceBase* cpu_device, ResourceMgr* resource_mgr,
|
||||
bool* has_dead_fanout, int* dead_fanout) {
|
||||
*has_dead_fanout = false;
|
||||
@ -591,7 +602,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
|
||||
view.GetRegularFanin(switch_loopcond_port).node;
|
||||
|
||||
// CASE 1: Control is a constant.
|
||||
if (IsConstant(*switch_predicate)) {
|
||||
if (IsReallyConstant(*switch_predicate, feed_nodes)) {
|
||||
Tensor selector;
|
||||
CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor()));
|
||||
*has_dead_fanout = true;
|
||||
@ -630,7 +641,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
|
||||
if (IsMerge(*node)) {
|
||||
merge_node = node;
|
||||
}
|
||||
if (IsConstant(*node)) {
|
||||
if (IsReallyConstant(*node, feed_nodes)) {
|
||||
constant_ctrl_input = node;
|
||||
constant_index = i;
|
||||
}
|
||||
@ -646,7 +657,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
|
||||
if (IsEnter(*node)) {
|
||||
enter_node = node;
|
||||
}
|
||||
if (IsConstant(*node)) {
|
||||
if (IsReallyConstant(*node, feed_nodes)) {
|
||||
constant_init_node = node;
|
||||
}
|
||||
}
|
||||
@ -654,7 +665,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
|
||||
if (constant_init_node != nullptr) return Status::OK();
|
||||
for (const auto& input : enter_node->input()) {
|
||||
NodeDef* node = node_map.GetNode(input);
|
||||
if (IsConstant(*node)) {
|
||||
if (IsReallyConstant(*node, feed_nodes)) {
|
||||
constant_init_node = node;
|
||||
}
|
||||
}
|
||||
@ -710,8 +721,12 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// TODO(srjoglekar): Figure out if we can optimize NodeMap creations across
|
||||
// optimizer passes.
|
||||
NodeMap node_map(optimized_graph);
|
||||
TF_RETURN_IF_ERROR(
|
||||
RemoveDeadBranches(item.NodesToPreserve(), node_map, optimized_graph));
|
||||
absl::flat_hash_set<string> feed_nodes;
|
||||
for (const auto& feed : item.feed) {
|
||||
feed_nodes.insert(NodeName(feed.first));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RemoveDeadBranches(item.NodesToPreserve(), node_map,
|
||||
feed_nodes, optimized_graph));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -719,7 +734,8 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
Status LoopOptimizer::RemoveDeadBranches(
|
||||
const std::unordered_set<string>& nodes_to_preserve,
|
||||
const NodeMap& node_map, GraphDef* optimized_graph) {
|
||||
const NodeMap& node_map, const absl::flat_hash_set<string>& feed_nodes,
|
||||
GraphDef* optimized_graph) {
|
||||
std::unordered_set<const NodeDef*> dead_nodes;
|
||||
std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
|
||||
// TODO(bsteiner): also rewrite switches as identity. For now we just record
|
||||
@ -737,9 +753,9 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
|
||||
int dead_fanout;
|
||||
bool has_dead_fanout;
|
||||
TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, cpu_device_,
|
||||
resource_mgr_.get(), &has_dead_fanout,
|
||||
&dead_fanout));
|
||||
TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, feed_nodes,
|
||||
cpu_device_, resource_mgr_.get(),
|
||||
&has_dead_fanout, &dead_fanout));
|
||||
if (!has_dead_fanout) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -60,7 +60,9 @@ class LoopOptimizer : public GraphOptimizer {
|
||||
};
|
||||
|
||||
Status RemoveDeadBranches(const std::unordered_set<string>& nodes_to_preserve,
|
||||
const NodeMap& node_map, GraphDef* optimized_graph);
|
||||
const NodeMap& node_map,
|
||||
const absl::flat_hash_set<string>& feed_nodes,
|
||||
GraphDef* optimized_graph);
|
||||
|
||||
RewriterConfig::Toggle opt_level_;
|
||||
DeviceBase* cpu_device_;
|
||||
|
||||
@ -504,11 +504,11 @@ void VerifyGraphsEqual(const GraphDef& original_graph,
|
||||
for (int i = 0; i < original_graph.node_size(); ++i) {
|
||||
const NodeDef& original = original_graph.node(i);
|
||||
const NodeDef& optimized = optimized_graph.node(i);
|
||||
EXPECT_EQ(original.name(), optimized.name()) << func;
|
||||
EXPECT_EQ(original.op(), optimized.op()) << func;
|
||||
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
|
||||
EXPECT_EQ(optimized.name(), original.name()) << func;
|
||||
EXPECT_EQ(optimized.op(), original.op()) << func;
|
||||
ASSERT_EQ(optimized.input_size(), original.input_size()) << func;
|
||||
for (int j = 0; j < original.input_size(); ++j) {
|
||||
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
|
||||
EXPECT_EQ(optimized.input(j), original.input(j)) << func;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -528,7 +528,7 @@ TEST_F(LoopOptimizerTest, NoOp) {
|
||||
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
|
||||
}
|
||||
|
||||
TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
|
||||
TEST_F(LoopOptimizerTest, RemovePushNoOp) {
|
||||
GrapplerItem item;
|
||||
GraphDef& graph = item.graph;
|
||||
AddSimpleNode("c", "Const", {}, &graph);
|
||||
@ -557,7 +557,7 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
|
||||
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
|
||||
}
|
||||
|
||||
TEST_F(LoopOptimizerTest, RemovePush_NoPopButStackLives) {
|
||||
TEST_F(LoopOptimizerTest, RemovePushNoPopButStackLives) {
|
||||
GrapplerItem item;
|
||||
GraphDef& graph = item.graph;
|
||||
AddSimpleNode("c", "Const", {}, &graph);
|
||||
@ -609,32 +609,32 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
EXPECT_EQ(13, output.node_size());
|
||||
EXPECT_EQ(output.node_size(), 13);
|
||||
for (int i = 0; i < output.node_size(); ++i) {
|
||||
const NodeDef& node = output.node(i);
|
||||
if (node.name() == "push1") {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("c", node.input(0));
|
||||
EXPECT_EQ("^stack1", node.input(1));
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "c");
|
||||
EXPECT_EQ(node.input(1), "^stack1");
|
||||
} else if (node.name() == "push2") {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("enter_c", node.input(0));
|
||||
EXPECT_EQ("^enter_stack2", node.input(1));
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "enter_c");
|
||||
EXPECT_EQ(node.input(1), "^enter_stack2");
|
||||
} else if (node.name() == "push3") {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("c", node.input(0));
|
||||
EXPECT_EQ("^stack3", node.input(1));
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "c");
|
||||
EXPECT_EQ(node.input(1), "^stack3");
|
||||
} else {
|
||||
const NodeDef& orig_node = item.graph.node(i);
|
||||
EXPECT_EQ(orig_node.ShortDebugString(), node.ShortDebugString());
|
||||
EXPECT_EQ(node.ShortDebugString(), orig_node.ShortDebugString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(LoopOptimizerTest, RemoveDeadBranches_ConstantCondition) {
|
||||
TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
|
||||
Scope scope = Scope::NewRootScope();
|
||||
Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT);
|
||||
|
||||
@ -691,57 +691,57 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranches_ConstantCondition) {
|
||||
|
||||
for (const NodeDef& node : output.node()) {
|
||||
// These nodes should have been pruned
|
||||
EXPECT_NE("Square1", node.name());
|
||||
EXPECT_NE("Sqrt2", node.name());
|
||||
EXPECT_NE("m5", node.name());
|
||||
EXPECT_NE("m7", node.name());
|
||||
EXPECT_NE(node.name(), "Square1");
|
||||
EXPECT_NE(node.name(), "Sqrt2");
|
||||
EXPECT_NE(node.name(), "m5");
|
||||
EXPECT_NE(node.name(), "m7");
|
||||
|
||||
if (node.name() == "m1") {
|
||||
// sqrt1 is dead
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("square1", node.input(0));
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 1);
|
||||
EXPECT_EQ(node.input(0), "square1");
|
||||
} else if (node.name() == "m2") {
|
||||
// both inputs are alive
|
||||
EXPECT_EQ("Merge", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("v_in", node.input(0));
|
||||
EXPECT_EQ("square1", node.input(1));
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "v_in");
|
||||
EXPECT_EQ(node.input(1), "square1");
|
||||
} else if (node.name() == "m3") {
|
||||
// sqrt1 is dead
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("v_in", node.input(0));
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 1);
|
||||
EXPECT_EQ(node.input(0), "v_in");
|
||||
} else if (node.name() == "m4") {
|
||||
// both inputs are alive
|
||||
EXPECT_EQ("Merge", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("square1", node.input(0));
|
||||
EXPECT_EQ("sqrt2", node.input(1));
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "square1");
|
||||
EXPECT_EQ(node.input(1), "sqrt2");
|
||||
} else if (node.name() == "m6") {
|
||||
// both inputs are alive and the control dependency can get triggered
|
||||
EXPECT_EQ("Merge", node.op());
|
||||
EXPECT_EQ(3, node.input_size());
|
||||
EXPECT_EQ("v_in", node.input(0));
|
||||
EXPECT_EQ("square1", node.input(1));
|
||||
EXPECT_EQ("^sqrt2", node.input(2));
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
ASSERT_EQ(node.input_size(), 3);
|
||||
EXPECT_EQ(node.input(0), "v_in");
|
||||
EXPECT_EQ(node.input(1), "square1");
|
||||
EXPECT_EQ(node.input(2), "^sqrt2");
|
||||
} else if (node.name() == "m8") {
|
||||
// The node is to be preserved because of a fetch
|
||||
EXPECT_EQ("Merge", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("id1", node.input(0));
|
||||
EXPECT_EQ("id2", node.input(1));
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "id1");
|
||||
EXPECT_EQ(node.input(1), "id2");
|
||||
} else if (node.name() == "m9") {
|
||||
// The node is to be preserved because of a fetch
|
||||
EXPECT_EQ("Merge", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("id3", node.input(0));
|
||||
EXPECT_EQ("id4", node.input(1));
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
ASSERT_EQ(2, node.input_size());
|
||||
EXPECT_EQ(node.input(0), "id3");
|
||||
EXPECT_EQ(node.input(1), "id4");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(LoopOptimizerTest, RemoveDeadBranches_FullyRemoveDeadBranches) {
|
||||
TEST_F(LoopOptimizerTest, RemoveDeadBranchesFullyRemoveDeadBranches) {
|
||||
const string gdef_ascii = R"EOF(
|
||||
node {
|
||||
name: "episodicreplaybuffer_add_readvariableop_resource"
|
||||
@ -1153,7 +1153,7 @@ versions {
|
||||
<< "Merge node was deleted, but it shouldn't have been.";
|
||||
}
|
||||
|
||||
TEST_F(LoopOptimizerTest, RemoveDeadBranches_ZeroIterWhile) {
|
||||
TEST_F(LoopOptimizerTest, RemoveDeadBranchesZeroIterWhile) {
|
||||
const string gdef_ascii = R"EOF(
|
||||
node {
|
||||
name: "Const"
|
||||
@ -1358,15 +1358,15 @@ versions {
|
||||
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
|
||||
item.fetch = {"while/Exit"};
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
EXPECT_EQ(1, tensors_expected.size());
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
|
||||
LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_CHECK_OK(status);
|
||||
auto tensors_got = EvaluateNodes(output, item.fetch);
|
||||
EXPECT_EQ(1, tensors_got.size());
|
||||
test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_got[0]);
|
||||
ASSERT_EQ(tensors_got.size(), 1);
|
||||
test::ExpectTensorEqual<int32>(tensors_got[0], tensors_expected[0]);
|
||||
|
||||
int nodes_present = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
@ -1382,7 +1382,200 @@ versions {
|
||||
}
|
||||
++nodes_present;
|
||||
}
|
||||
EXPECT_EQ(8, nodes_present);
|
||||
EXPECT_EQ(nodes_present, 8);
|
||||
}
|
||||
|
||||
TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantFeed) {
|
||||
const string gdef_ascii = R"EOF(
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
string_val: "I\'m a value!"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Switch_1"
|
||||
op: "Switch"
|
||||
input: "Const"
|
||||
input: "Const_1"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Const"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const_1"
|
||||
op: "Const"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_BOOL
|
||||
tensor_shape {
|
||||
}
|
||||
bool_val: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Switch"
|
||||
op: "Switch"
|
||||
input: "Const_1"
|
||||
input: "Const_1"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/switch_t"
|
||||
op: "Identity"
|
||||
input: "cond/Switch:1"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Const"
|
||||
op: "Const"
|
||||
input: "^cond/switch_t"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
string_val: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Merge"
|
||||
op: "Merge"
|
||||
input: "cond/Switch_1"
|
||||
input: "cond/Const"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Identity"
|
||||
op: "Identity"
|
||||
input: "cond/Merge"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
||||
)EOF";
|
||||
|
||||
GrapplerItem item;
|
||||
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
|
||||
item.fetch = {"Identity"};
|
||||
Tensor feed_tensor(DT_BOOL, {});
|
||||
feed_tensor.flat<bool>()(1) = false;
|
||||
item.feed.push_back({"Const_1", feed_tensor});
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
|
||||
LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_CHECK_OK(status);
|
||||
auto tensors_got = EvaluateNodes(output, item.fetch);
|
||||
ASSERT_EQ(tensors_got.size(), 1);
|
||||
test::ExpectTensorEqual<string>(tensors_got[0], tensors_expected[0]);
|
||||
|
||||
EXPECT_EQ(output.node_size(), 8);
|
||||
|
||||
// No rewrite because branch has a constant feed node.
|
||||
bool found = false;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "cond/Merge") {
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "cond/Switch_1");
|
||||
EXPECT_EQ(node.input(1), "cond/Const");
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(found);
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user