Automated rollback of commit 0c9d4b5ea1

PiperOrigin-RevId: 257255123
This commit is contained in:
A. Unique TensorFlower 2019-07-09 13:07:55 -07:00 committed by TensorFlower Gardener
parent 8e7014bcac
commit a1661765dd
3 changed files with 92 additions and 245 deletions

View File

@@ -989,10 +989,11 @@ bool ConstantFolding::IsFoldable(const NodeDef& node,
}
}
// Don't fold nodes that have no outgoing edges except whitelisted nodes.
// Such nodes could be introduced by an earlier constant folding pass and are
// preserved in case users want to fetch their values; re-processing them
// would lead to an error of adding a duplicated node to graph.
// No need to (and don't) fold nodes that have no outgoing edges except
// whitelisted nodes. Such nodes could be introduced by an earlier constant
// folding pass and are preserved in case users want to fetch their values;
// re-processing them would lead to an error of adding a duplicated node
// to graph.
const auto& outputs = node_map_->GetOutputs(node.name());
if (outputs.empty() &&
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
@@ -1028,7 +1029,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node,
return false;
}
}
if (is_merge && !merge_has_constant_input) return false;
// If we know the output shapes, make sure that the outputs are small enough
// to materialize.
@@ -1050,7 +1050,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node,
}
}
return true;
return !is_merge || merge_has_constant_input;
}
namespace {
@@ -2827,162 +2827,83 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
// / \ / \
// X Y C Y = leaves
//
// where C is constant, X is non-constant, Y may be constant or non-constant,
// and '+' denotes an associative and commutative operator like addition or
// multiplication. This optimization pushes constants down in the tree to
// canonicalize it. Moreover, in cases where the child node has a second
// constant input Y we will create a leaf node that can be folded, e.g.
// where C is constant and X is non-constant, and '+' denotes an
// associative and commutative operator like addition or multiplication.
// This optimization pushes constants down in the tree to canonicalize it.
// Moreover, in cases where the child node has a second constant input Y
// we will create a leaf node that can be folded, e.g.
//
// Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
//
// We also handle the non-commutative cases of subtraction and division
// by rotating the tree locally, e.g.
// Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
// Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
//
// Note: Don't touch BiasAdd since they can't handle vectors as their first
// TODO(rmlarsen): Handle non-associative/non-commutative operators like
// subtraction and division, as well as mixed subtraction/addition,
// division/multiplication.
// Don't touch BiasAdd since they can't handle vectors as their first
// inputs.
if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
NumNonControlInputs(*node) == 2) {
NodeDef* left_child = node_map_->GetNode(node->input(0));
NodeDef* right_child = node_map_->GetNode(node->input(1));
// One child must be constant, and the other the same op as the parent.
if (node->op() != left_child->op() && node->op() != right_child->op()) {
return false;
}
const bool left_child_is_constant = IsReallyConstant(*left_child);
const bool right_child_is_constant = IsReallyConstant(*right_child);
if (!left_child_is_constant && !right_child_is_constant) {
return false;
}
if (node->device() != left_child->device() ||
node->device() != right_child->device()) {
return false;
}
NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
NodeDef* const_child_node =
left_child_is_constant ? left_child : right_child;
// Make sure that it is safe to change the value of the child node.
if (op_child_node->input_size() < 2 ||
nodes_to_preserve_.find(op_child_node->name()) !=
nodes_to_preserve_.end() ||
NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
return false;
}
// Get parent op type.
const bool is_add = IsAdd(*node);
const bool is_mul = IsMul(*node);
const bool is_sub = IsSub(*node);
const bool is_div = IsDiv(*node);
const bool is_symmetric = is_add || is_mul;
if (!has_fetch_ || !(is_add || is_sub || is_mul || is_div) ||
NumNonControlInputs(*node) != 2) {
return false;
}
// Identify the nodes to swap.
NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0));
NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1));
const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
if (left_leaf_is_constant && right_leaf_is_constant) {
// Child is already foldable, leave it alone.
return false;
}
const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
const int parent_const_input = left_child_is_constant ? 0 : 1;
const auto& child_output = node_map_->GetOutputs(op_child_node->name());
if (child_output.find(const_child_node) != child_output.end()) {
// If there is a control edge from the child op to C, the transformation
// would create a cycle in the graph. We know that it must be a control
// edge. We can replace such a control edge with a control edge from A
// to C.
CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
optimized_graph, node_map_.get()));
string other_leaf_input = left_leaf_is_constant ? op_child_node->input(0)
: op_child_node->input(1);
MaybeAddControlInput(other_leaf_input, const_child_node, optimized_graph,
node_map_.get());
}
NodeDef* left_child = node_map_->GetNode(node->input(0));
NodeDef* right_child = node_map_->GetNode(node->input(1));
const bool left_child_is_constant = IsReallyConstant(*left_child);
const bool right_child_is_constant = IsReallyConstant(*right_child);
if (!left_child_is_constant && !right_child_is_constant) {
return false;
// Swap the constant child with a non-constant leaf node.
node_map_->UpdateInput(node->name(), node->input(parent_const_input),
op_child_node->input(non_const_leaf_input));
node_map_->UpdateInput(op_child_node->name(),
op_child_node->input(non_const_leaf_input),
node->input(parent_const_input));
std::swap(*node->mutable_input(parent_const_input),
*op_child_node->mutable_input(non_const_leaf_input));
return true;
}
// Don't move nodes across devices.
if (node->device() != left_child->device() ||
node->device() != right_child->device()) {
return false;
}
NodeDef* op_child = left_child_is_constant ? right_child : left_child;
NodeDef* const_child = left_child_is_constant ? left_child : right_child;
// Don't rewrite the tree if it might create cycles.
// TODO(rmlarsen): Add back handling of control dependency from op to C.
const auto& child_output = node_map_->GetOutputs(op_child->name());
if (child_output.find(const_child) != child_output.end()) {
return false;
}
// Get child op type.
const bool is_child_add = IsAdd(*op_child);
const bool is_child_mul = IsMul(*op_child);
const bool is_child_sub = IsSub(*op_child);
const bool is_child_div = IsDiv(*op_child);
const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
if (!is_add_sub && !is_mul_div) {
return false;
}
const bool is_child_symmetric = is_child_add || is_child_mul;
// Make sure that it is safe to change the value of the child node result.
if (op_child->input_size() < 2 ||
nodes_to_preserve_.find(op_child->name()) != nodes_to_preserve_.end() ||
NumNonControlOutputs(*op_child, *node_map_) > 1) {
return false;
}
// Do not rewrite integer expressions with subtraction or division.
if (!CheckAttrExists(*node, "T").ok()) return false;
DataType dtype = node->attr().at("T").type();
if (!(is_symmetric && is_child_symmetric) &&
!(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
return false;
}
// Identify the nodes to swap.
NodeDef* left_leaf = node_map_->GetNode(op_child->input(0));
NodeDef* right_leaf = node_map_->GetNode(op_child->input(1));
const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
if (left_leaf_is_constant && right_leaf_is_constant) {
// Child is already foldable, leave it alone.
return false;
}
// Don't move nodes across devices.
if (node->device() != left_leaf->device() ||
node->device() != right_leaf->device()) {
return false;
}
// Get the node names corresponding to X, Y, and C.
const string input_x =
left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
const string input_y =
input_x == op_child->input(0) ? op_child->input(1) : op_child->input(0);
const string input_c =
left_child_is_constant ? node->input(0) : node->input(1);
const string input_op =
left_child_is_constant ? node->input(1) : node->input(0);
// Now we have identified the nodes to swap (non_const_leaf_input and
// const_child).
node_map_->UpdateInput(node->name(), input_c, input_x);
node_map_->AddOutput(input_c, op_child->name());
if (input_x != input_y) {
node_map_->RemoveOutput(input_x, op_child->name());
}
if (is_symmetric && is_child_symmetric) {
// Easy case (only commutative ops). We always write this as one of
// +
// / \
// X +
// / \
// C Y
node->set_input(0, input_x);
node->set_input(1, input_op);
op_child->set_input(0, input_c);
op_child->set_input(1, input_y);
} else {
// More complicated case: When there are non-commutative operations like
// subtractions or divisions involved, we may have to rotate the tree
// and/or change op types. There are 6 non-trivial cases depending on
// the effective generalized "sign" of each of the three terms C, Y, and X.
// Here are the final trees we want to generate for those 6 cases:
//
// (CYX signs): ++- +-- -+- --+ +-+ -++
// - - - - + +
// / \ / \ / \ / \ / \ / \
// + X - X - X X + X - X -
// / \ / \ / \ / \ / \ / \
// C Y C Y Y C Y C C Y Y C
//
NodeDef* non_const_leaf = left_leaf_is_constant ? right_leaf : left_leaf;
NodeDef* maybe_const_leaf =
non_const_leaf == right_leaf ? left_leaf : right_leaf;
// First, let's determine the effective sign of each term in the original
// expression
auto is_leaf_negated = [&](const NodeDef* node) -> bool {
bool leaf_negated = !is_child_symmetric && (node == right_leaf);
bool child_negated = !is_symmetric && (op_child == right_child);
return leaf_negated != child_negated;
};
const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
bool neg_c = !is_symmetric && (const_child == right_child);
bool neg_x = is_leaf_negated(non_const_leaf);
bool neg_y = is_leaf_negated(maybe_const_leaf);
// Rewrite the parent node.
node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
node->set_input(0, neg_x ? input_op : input_x);
node->set_input(1, neg_x ? input_x : input_op);
// Rewrite the child node.
op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
op_child->set_input(0, neg_c ? input_y : input_c);
op_child->set_input(1, neg_c ? input_c : input_y);
}
return true;
return false;
}
bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,

View File

@@ -255,19 +255,20 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
TEST_F(ConstantFoldingTest, AddTree) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child),
1.0f, {1});
Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);
Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
@@ -297,16 +298,16 @@ TEST_F(ConstantFoldingTest, AddTree) {
// / \ / \
// 5.0 y 4.0 5.0
EXPECT_EQ(10, output.node_size());
EXPECT_EQ(11, output.node_size());
for (const auto& node : output.node()) {
if (node.name() == "add_child") {
EXPECT_EQ("Const", node.op());
TensorProto t = node.attr().at("value").tensor();
ASSERT_EQ(1, t.tensor_shape().dim_size());
EXPECT_EQ(1, t.tensor_shape().dim_size());
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
} else if (node.name() == "add_parent") {
EXPECT_EQ("Add", node.op());
ASSERT_EQ(2, node.input_size());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("add_child", node.input(1));
} else if (node.name() == "mul_child") {
@@ -316,106 +317,30 @@ TEST_F(ConstantFoldingTest, AddTree) {
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
} else if (node.name() == "mul_parent") {
EXPECT_EQ("Mul", node.op());
ASSERT_EQ(2, node.input_size());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("mul_child", node.input(1));
} else if (node.name() == "addmul_child") {
// Unchanged.
EXPECT_EQ("Add", node.op());
ASSERT_EQ(2, node.input_size());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("c4", node.input(0));
EXPECT_EQ("x", node.input(1));
}
}
// Check that the result nodes have the expected value.
auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
std::vector<string> fetch = {"add_parent", "mul_parent"};
auto tensor_expected =
EvaluateNodes(item.graph, fetch, {{"x", x_t}, {"y", y_t}});
ASSERT_EQ(fetch.size(), tensor_expected.size());
fetch = {"add_parent", "mul_parent"};
auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}, {"y", y_t}});
ASSERT_EQ(fetch.size(), tensors.size());
std::vector<string> fetch = {"c3", "c20"};
auto tensor_expected = EvaluateNodes(item.graph, fetch);
EXPECT_EQ(fetch.size(), tensor_expected.size());
fetch = {"add_child", "mul_child"};
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
}
}
TEST_F(ConstantFoldingTest, TreeCanonicalization) {
for (int is_add : {true, false}) {
for (int is_parent_commutative : {true, false}) {
for (int is_child_commutative : {true, false}) {
for (int is_left_child_const : {true, false}) {
for (int is_left_leaf_const : {true, false}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
Output x =
ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
auto get_op = [&](bool is_commutative, bool is_left_arg_cont,
const string& name, const Output& const_arg,
const Output non_const_arg) -> Output {
if (is_add) {
if (is_commutative) {
return ops::Add(s.WithOpName(name),
is_left_arg_cont ? const_arg : non_const_arg,
is_left_arg_cont ? non_const_arg : const_arg);
} else {
return ops::Sub(s.WithOpName(name),
is_left_arg_cont ? const_arg : non_const_arg,
is_left_arg_cont ? non_const_arg : const_arg);
}
} else {
if (is_commutative) {
return ops::Mul(s.WithOpName(name),
is_left_arg_cont ? const_arg : non_const_arg,
is_left_arg_cont ? non_const_arg : const_arg);
} else {
return ops::Div(s.WithOpName(name),
is_left_arg_cont ? const_arg : non_const_arg,
is_left_arg_cont ? non_const_arg : const_arg);
}
}
};
Output child = get_op(is_child_commutative, is_left_leaf_const,
"child", c2, x);
Output parent = get_op(is_parent_commutative, is_left_child_const,
"parent", c3, child);
GrapplerItem item;
item.fetch = {"parent"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ConstantFolding optimizer(/*cpu_device=*/nullptr);
GraphDef output;
Status status =
optimizer.Optimize(/*cluster=*/nullptr, item, &output);
TF_EXPECT_OK(status);
// Check that the result nodes have the expected value.
auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
std::vector<string> fetch = {"parent"};
auto tensor_expected =
EvaluateNodes(item.graph, fetch, {{"x", x_t}});
ASSERT_EQ(fetch.size(), tensor_expected.size());
fetch = {"parent"};
auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
ASSERT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
}
}
}
}
}
}
}
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
for (string data_format : {
"NHWC",

View File

@@ -1634,6 +1634,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
// Infer properties lazily in case they are not needed.
if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i)) {
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
// TODO(rmlarsen): Get rid of tensor value copies.
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
assume_valid_feeds,
/*aggressive_shape_inference=*/false,