[Grappler] Don't remove constant feed nodes in LoopOptimizer RemoveDeadBranches.

PiperOrigin-RevId: 233633382
Andy Ly 2019-02-12 10:34:06 -08:00 committed by TensorFlower Gardener
parent e7771a7894
commit 7141c42808
3 changed files with 280 additions and 69 deletions
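Why the change: in TF 1.x a caller may feed a value for any node in the graph, including a Const. The feed replaces the node's stored value at run time, so pruning a Switch branch based on a fed Const's graph-time value would bake in a result the caller can override. A minimal sketch of the scenario, using the GrapplerItem API from the test added below (the node name "Const_1" and its stored value of true are taken from that test):

GrapplerItem item;
item.fetch = {"Identity"};
// The graph stores Const_1 = true, but the caller feeds false at run time;
// RemoveDeadBranches therefore must not prune the "false" side of the Switch.
Tensor pred(DT_BOOL, {});
pred.scalar<bool>()() = false;
item.feed.push_back({"Const_1", pred});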

tensorflow/core/grappler/optimizers/loop_optimizer.cc

@@ -581,8 +581,19 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
return Status::OK();
}
// TODO(lyandy): Consolidate with ConstantFolding implementation.
bool IsReallyConstant(const NodeDef& node,
const absl::flat_hash_set<string>& feed_nodes) {
if (!IsConstant(node)) {
return false;
}
// If the node is fed it's not constant anymore.
return feed_nodes.find(node.name()) == feed_nodes.end();
}
Status CheckForDeadFanout(const MutableGraphView& view,
const NodeDef& switch_node, const NodeMap& node_map,
const absl::flat_hash_set<string>& feed_nodes,
DeviceBase* cpu_device, ResourceMgr* resource_mgr,
bool* has_dead_fanout, int* dead_fanout) {
*has_dead_fanout = false;
@@ -591,7 +602,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
view.GetRegularFanin(switch_loopcond_port).node;
// CASE 1: Control is a constant.
if (IsConstant(*switch_predicate)) {
if (IsReallyConstant(*switch_predicate, feed_nodes)) {
Tensor selector;
CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor()));
*has_dead_fanout = true;
@@ -630,7 +641,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
if (IsMerge(*node)) {
merge_node = node;
}
if (IsConstant(*node)) {
if (IsReallyConstant(*node, feed_nodes)) {
constant_ctrl_input = node;
constant_index = i;
}
@@ -646,7 +657,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
if (IsEnter(*node)) {
enter_node = node;
}
if (IsConstant(*node)) {
if (IsReallyConstant(*node, feed_nodes)) {
constant_init_node = node;
}
}
@@ -654,7 +665,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
if (constant_init_node != nullptr) return Status::OK();
for (const auto& input : enter_node->input()) {
NodeDef* node = node_map.GetNode(input);
if (IsConstant(*node)) {
if (IsReallyConstant(*node, feed_nodes)) {
constant_init_node = node;
}
}
@@ -710,8 +721,12 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// TODO(srjoglekar): Figure out if we can optimize NodeMap creations across
// optimizer passes.
NodeMap node_map(optimized_graph);
TF_RETURN_IF_ERROR(
RemoveDeadBranches(item.NodesToPreserve(), node_map, optimized_graph));
absl::flat_hash_set<string> feed_nodes;
for (const auto& feed : item.feed) {
feed_nodes.insert(NodeName(feed.first));
}
TF_RETURN_IF_ERROR(RemoveDeadBranches(item.NodesToPreserve(), node_map,
feed_nodes, optimized_graph));
}
return Status::OK();
@@ -719,7 +734,8 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
Status LoopOptimizer::RemoveDeadBranches(
const std::unordered_set<string>& nodes_to_preserve,
const NodeMap& node_map, GraphDef* optimized_graph) {
const NodeMap& node_map, const absl::flat_hash_set<string>& feed_nodes,
GraphDef* optimized_graph) {
std::unordered_set<const NodeDef*> dead_nodes;
std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
// TODO(bsteiner): also rewrite switches as identity. For now we just record
@@ -737,9 +753,9 @@ Status LoopOptimizer::RemoveDeadBranches(
int dead_fanout;
bool has_dead_fanout;
TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, cpu_device_,
resource_mgr_.get(), &has_dead_fanout,
&dead_fanout));
TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, feed_nodes,
cpu_device_, resource_mgr_.get(),
&has_dead_fanout, &dead_fanout));
if (!has_dead_fanout) {
continue;
}

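A note on the feed-set construction in Optimize() above: feeds on a GrapplerItem may be keyed by tensor name ("node:port") rather than plain node name, while IsReallyConstant looks nodes up by node name only. The code relies on Grappler's NodeName() helper (tensorflow/core/grappler/utils.h) to strip any output-port suffix; a small sketch of the assumed normalization:

absl::flat_hash_set<string> feed_nodes;
feed_nodes.insert(NodeName("Const_1"));    // stores "Const_1"
feed_nodes.insert(NodeName("Const_1:0"));  // also stores "Const_1"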
tensorflow/core/grappler/optimizers/loop_optimizer.h

@@ -60,7 +60,9 @@ class LoopOptimizer : public GraphOptimizer {
};
Status RemoveDeadBranches(const std::unordered_set<string>& nodes_to_preserve,
const NodeMap& node_map, GraphDef* optimized_graph);
const NodeMap& node_map,
const absl::flat_hash_set<string>& feed_nodes,
GraphDef* optimized_graph);
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;

tensorflow/core/grappler/optimizers/loop_optimizer_test.cc

@@ -504,11 +504,11 @@ void VerifyGraphsEqual(const GraphDef& original_graph,
for (int i = 0; i < original_graph.node_size(); ++i) {
const NodeDef& original = original_graph.node(i);
const NodeDef& optimized = optimized_graph.node(i);
EXPECT_EQ(original.name(), optimized.name()) << func;
EXPECT_EQ(original.op(), optimized.op()) << func;
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
EXPECT_EQ(optimized.name(), original.name()) << func;
EXPECT_EQ(optimized.op(), original.op()) << func;
ASSERT_EQ(optimized.input_size(), original.input_size()) << func;
for (int j = 0; j < original.input_size(); ++j) {
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
EXPECT_EQ(optimized.input(j), original.input(j)) << func;
}
}
}
@@ -528,7 +528,7 @@ TEST_F(LoopOptimizerTest, NoOp) {
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
TEST_F(LoopOptimizerTest, RemovePushNoOp) {
GrapplerItem item;
GraphDef& graph = item.graph;
AddSimpleNode("c", "Const", {}, &graph);
@@ -557,7 +557,7 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
TEST_F(LoopOptimizerTest, RemovePush_NoPopButStackLives) {
TEST_F(LoopOptimizerTest, RemovePushNoPopButStackLives) {
GrapplerItem item;
GraphDef& graph = item.graph;
AddSimpleNode("c", "Const", {}, &graph);
@@ -609,32 +609,32 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(13, output.node_size());
EXPECT_EQ(output.node_size(), 13);
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
if (node.name() == "push1") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("c", node.input(0));
EXPECT_EQ("^stack1", node.input(1));
EXPECT_EQ(node.op(), "Identity");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "c");
EXPECT_EQ(node.input(1), "^stack1");
} else if (node.name() == "push2") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("enter_c", node.input(0));
EXPECT_EQ("^enter_stack2", node.input(1));
EXPECT_EQ(node.op(), "Identity");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "enter_c");
EXPECT_EQ(node.input(1), "^enter_stack2");
} else if (node.name() == "push3") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("c", node.input(0));
EXPECT_EQ("^stack3", node.input(1));
EXPECT_EQ(node.op(), "Identity");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "c");
EXPECT_EQ(node.input(1), "^stack3");
} else {
const NodeDef& orig_node = item.graph.node(i);
EXPECT_EQ(orig_node.ShortDebugString(), node.ShortDebugString());
EXPECT_EQ(node.ShortDebugString(), orig_node.ShortDebugString());
}
}
}
TEST_F(LoopOptimizerTest, RemoveDeadBranches_ConstantCondition) {
TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
Scope scope = Scope::NewRootScope();
Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT);
@@ -691,57 +691,57 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranches_ConstantCondition) {
for (const NodeDef& node : output.node()) {
// These nodes should have been pruned
EXPECT_NE("Square1", node.name());
EXPECT_NE("Sqrt2", node.name());
EXPECT_NE("m5", node.name());
EXPECT_NE("m7", node.name());
EXPECT_NE(node.name(), "Square1");
EXPECT_NE(node.name(), "Sqrt2");
EXPECT_NE(node.name(), "m5");
EXPECT_NE(node.name(), "m7");
if (node.name() == "m1") {
// sqrt1 is dead
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("square1", node.input(0));
EXPECT_EQ(node.op(), "Identity");
ASSERT_EQ(node.input_size(), 1);
EXPECT_EQ(node.input(0), "square1");
} else if (node.name() == "m2") {
// both inputs are alive
EXPECT_EQ("Merge", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("v_in", node.input(0));
EXPECT_EQ("square1", node.input(1));
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "v_in");
EXPECT_EQ(node.input(1), "square1");
} else if (node.name() == "m3") {
// sqrt1 is dead
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("v_in", node.input(0));
EXPECT_EQ(node.op(), "Identity");
ASSERT_EQ(node.input_size(), 1);
EXPECT_EQ(node.input(0), "v_in");
} else if (node.name() == "m4") {
// both inputs are alive
EXPECT_EQ("Merge", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("square1", node.input(0));
EXPECT_EQ("sqrt2", node.input(1));
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "square1");
EXPECT_EQ(node.input(1), "sqrt2");
} else if (node.name() == "m6") {
// both inputs are alive and the control dependency can get triggered
EXPECT_EQ("Merge", node.op());
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("v_in", node.input(0));
EXPECT_EQ("square1", node.input(1));
EXPECT_EQ("^sqrt2", node.input(2));
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 3);
EXPECT_EQ(node.input(0), "v_in");
EXPECT_EQ(node.input(1), "square1");
EXPECT_EQ(node.input(2), "^sqrt2");
} else if (node.name() == "m8") {
// The node is to be preserved because of a fetch
EXPECT_EQ("Merge", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("id1", node.input(0));
EXPECT_EQ("id2", node.input(1));
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "id1");
EXPECT_EQ(node.input(1), "id2");
} else if (node.name() == "m9") {
// The node is to be preserved because of a fetch
EXPECT_EQ("Merge", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("id3", node.input(0));
EXPECT_EQ("id4", node.input(1));
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "id3");
EXPECT_EQ(node.input(1), "id4");
}
}
}
TEST_F(LoopOptimizerTest, RemoveDeadBranches_FullyRemoveDeadBranches) {
TEST_F(LoopOptimizerTest, RemoveDeadBranchesFullyRemoveDeadBranches) {
const string gdef_ascii = R"EOF(
node {
name: "episodicreplaybuffer_add_readvariableop_resource"
@@ -1153,7 +1153,7 @@ versions {
<< "Merge node was deleted, but it shouldn't have been.";
}
TEST_F(LoopOptimizerTest, RemoveDeadBranches_ZeroIterWhile) {
TEST_F(LoopOptimizerTest, RemoveDeadBranchesZeroIterWhile) {
const string gdef_ascii = R"EOF(
node {
name: "Const"
@@ -1358,15 +1358,15 @@ versions {
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
item.fetch = {"while/Exit"};
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
ASSERT_EQ(tensors_expected.size(), 1);
LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_CHECK_OK(status);
auto tensors_got = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors_got.size());
test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_got[0]);
ASSERT_EQ(tensors_got.size(), 1);
test::ExpectTensorEqual<int32>(tensors_got[0], tensors_expected[0]);
int nodes_present = 0;
for (const NodeDef& node : output.node()) {
@@ -1382,7 +1382,200 @@ versions {
}
++nodes_present;
}
EXPECT_EQ(8, nodes_present);
EXPECT_EQ(nodes_present, 8);
}
TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantFeed) {
const string gdef_ascii = R"EOF(
node {
name: "Const"
op: "Const"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: "I\'m a value!"
}
}
}
}
node {
name: "cond/Switch_1"
op: "Switch"
input: "Const"
input: "Const_1"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "T"
value {
type: DT_STRING
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Const"
}
}
}
}
node {
name: "Const_1"
op: "Const"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_BOOL
tensor_shape {
}
bool_val: true
}
}
}
}
node {
name: "cond/Switch"
op: "Switch"
input: "Const_1"
input: "Const_1"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/switch_t"
op: "Identity"
input: "cond/Switch:1"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/Const"
op: "Const"
input: "^cond/switch_t"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: ""
}
}
}
}
node {
name: "cond/Merge"
op: "Merge"
input: "cond/Switch_1"
input: "cond/Const"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "N"
value {
i: 2
}
}
attr {
key: "T"
value {
type: DT_STRING
}
}
}
node {
name: "Identity"
op: "Identity"
input: "cond/Merge"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "T"
value {
type: DT_STRING
}
}
}
library {
}
versions {
producer: 27
}
)EOF";
GrapplerItem item;
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
item.fetch = {"Identity"};
Tensor feed_tensor(DT_BOOL, {});
feed_tensor.flat<bool>()(0) = false;
item.feed.push_back({"Const_1", feed_tensor});
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
ASSERT_EQ(tensors_expected.size(), 1);
LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_CHECK_OK(status);
auto tensors_got = EvaluateNodes(output, item.fetch);
ASSERT_EQ(tensors_got.size(), 1);
test::ExpectTensorEqual<string>(tensors_got[0], tensors_expected[0]);
EXPECT_EQ(output.node_size(), 8);
// No rewrite because branch has a constant feed node.
bool found = false;
for (const NodeDef& node : output.node()) {
if (node.name() == "cond/Merge") {
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "cond/Switch_1");
EXPECT_EQ(node.input(1), "cond/Const");
found = true;
break;
}
}
EXPECT_TRUE(found);
}
} // namespace grappler
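To exercise the new case locally, the usual Bazel invocation for this file should work (target and filter names inferred from the file's package and the test name): bazel test //tensorflow/core/grappler/optimizers:loop_optimizer_test --test_filter=LoopOptimizerTest.RemoveDeadBranchesConstantFeed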