parent
8e7014bcac
commit
a1661765dd
@ -989,10 +989,11 @@ bool ConstantFolding::IsFoldable(const NodeDef& node,
|
||||
}
|
||||
}
|
||||
|
||||
// Don't fold nodes that have no outgoing edges except whitelisted nodes.
|
||||
// Such nodes could be introduced by an earlier constant folding pass and are
|
||||
// preserved in case users want to fetch their values; re-processing them
|
||||
// would lead to an error of adding a duplicated node to graph.
|
||||
// No need to (and don't) fold nodes that have no outgoing edges except
|
||||
// whitelisted nodes. Such nodes could be introduced by an earlier constant
|
||||
// folding pass and are preserved in case users want to fetch their values;
|
||||
// re-processing them would lead to an error of adding a duplicated node
|
||||
// to graph.
|
||||
const auto& outputs = node_map_->GetOutputs(node.name());
|
||||
if (outputs.empty() &&
|
||||
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
|
||||
@ -1028,7 +1029,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node,
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (is_merge && !merge_has_constant_input) return false;
|
||||
|
||||
// If we know the output shapes, make sure that the outputs are small enough
|
||||
// to materialize.
|
||||
@ -1050,7 +1050,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node,
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return !is_merge || merge_has_constant_input;
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -2827,162 +2827,83 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
||||
// / \ / \
|
||||
// X Y C Y = leaves
|
||||
//
|
||||
// where C is constant, X is non-constant, Y may be constant or non-constant,
|
||||
// and '+' denotes an associative and commutative operator like addition or
|
||||
// multiplication. This optimization pushes constants down in the tree to
|
||||
// canonicalize it. Moreover, in cases where the child node has a second
|
||||
// constant input Y we will create a leaf node that can be folded, e.g.
|
||||
// where C is constant and X is non-constant, and '+' denotes an
|
||||
// associative and commutative operator like addition or multiplication.
|
||||
// This optimization pushes constants down in the tree to canonicalize it.
|
||||
// Moreover, in cases where the child node has a second constant input Y
|
||||
// we will create a leaf node that can be folded, e.g.
|
||||
//
|
||||
// Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
|
||||
//
|
||||
// We also handle the non-commutative cases of subtraction and division
|
||||
// by rotating the tree locally, e.g.
|
||||
// Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
|
||||
// Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
|
||||
//
|
||||
// Note: Don't touch BiasAdd since they can't handle vectors as their first
|
||||
// TODO(rmlarsen): Handle non-associative/non-commutative operators like
|
||||
// subtraction and division, as well as mixed subtraction/addition,
|
||||
// division/multiplication.
|
||||
// Don't touch BiasAdd since they can't handle vectors as their first
|
||||
// inputs.
|
||||
if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
|
||||
NumNonControlInputs(*node) == 2) {
|
||||
NodeDef* left_child = node_map_->GetNode(node->input(0));
|
||||
NodeDef* right_child = node_map_->GetNode(node->input(1));
|
||||
// One child must be constant, and the other the same op as the parent.
|
||||
if (node->op() != left_child->op() && node->op() != right_child->op()) {
|
||||
return false;
|
||||
}
|
||||
const bool left_child_is_constant = IsReallyConstant(*left_child);
|
||||
const bool right_child_is_constant = IsReallyConstant(*right_child);
|
||||
if (!left_child_is_constant && !right_child_is_constant) {
|
||||
return false;
|
||||
}
|
||||
if (node->device() != left_child->device() ||
|
||||
node->device() != right_child->device()) {
|
||||
return false;
|
||||
}
|
||||
NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
|
||||
NodeDef* const_child_node =
|
||||
left_child_is_constant ? left_child : right_child;
|
||||
// Make sure that it is safe to change the value of the child node.
|
||||
if (op_child_node->input_size() < 2 ||
|
||||
nodes_to_preserve_.find(op_child_node->name()) !=
|
||||
nodes_to_preserve_.end() ||
|
||||
NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get parent op type.
|
||||
const bool is_add = IsAdd(*node);
|
||||
const bool is_mul = IsMul(*node);
|
||||
const bool is_sub = IsSub(*node);
|
||||
const bool is_div = IsDiv(*node);
|
||||
const bool is_symmetric = is_add || is_mul;
|
||||
if (!has_fetch_ || !(is_add || is_sub || is_mul || is_div) ||
|
||||
NumNonControlInputs(*node) != 2) {
|
||||
return false;
|
||||
}
|
||||
// Identify the nodes to swap.
|
||||
NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0));
|
||||
NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1));
|
||||
const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
|
||||
const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
|
||||
if (left_leaf_is_constant && right_leaf_is_constant) {
|
||||
// Child is already foldable, leave it alone.
|
||||
return false;
|
||||
}
|
||||
const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
|
||||
const int parent_const_input = left_child_is_constant ? 0 : 1;
|
||||
const auto& child_output = node_map_->GetOutputs(op_child_node->name());
|
||||
if (child_output.find(const_child_node) != child_output.end()) {
|
||||
// If there is a control edge from the child op to C, the transformation
|
||||
// would create a cycle in the graph. We know that it must be a control
|
||||
// edge. We can replace such a control edge with a control edge from A
|
||||
// to C.
|
||||
CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
|
||||
optimized_graph, node_map_.get()));
|
||||
string other_leaf_input = left_leaf_is_constant ? op_child_node->input(0)
|
||||
: op_child_node->input(1);
|
||||
MaybeAddControlInput(other_leaf_input, const_child_node, optimized_graph,
|
||||
node_map_.get());
|
||||
}
|
||||
|
||||
NodeDef* left_child = node_map_->GetNode(node->input(0));
|
||||
NodeDef* right_child = node_map_->GetNode(node->input(1));
|
||||
|
||||
const bool left_child_is_constant = IsReallyConstant(*left_child);
|
||||
const bool right_child_is_constant = IsReallyConstant(*right_child);
|
||||
if (!left_child_is_constant && !right_child_is_constant) {
|
||||
return false;
|
||||
// Swap the constant child with a non-constant leaf node.
|
||||
node_map_->UpdateInput(node->name(), node->input(parent_const_input),
|
||||
op_child_node->input(non_const_leaf_input));
|
||||
node_map_->UpdateInput(op_child_node->name(),
|
||||
op_child_node->input(non_const_leaf_input),
|
||||
node->input(parent_const_input));
|
||||
std::swap(*node->mutable_input(parent_const_input),
|
||||
*op_child_node->mutable_input(non_const_leaf_input));
|
||||
return true;
|
||||
}
|
||||
// Don't move nodes across devices.
|
||||
if (node->device() != left_child->device() ||
|
||||
node->device() != right_child->device()) {
|
||||
return false;
|
||||
}
|
||||
NodeDef* op_child = left_child_is_constant ? right_child : left_child;
|
||||
NodeDef* const_child = left_child_is_constant ? left_child : right_child;
|
||||
// Don't rewrite the tree if it might create cycles.
|
||||
// TODO(rmlarsen): Add back handling of control dependency from op to C.
|
||||
const auto& child_output = node_map_->GetOutputs(op_child->name());
|
||||
if (child_output.find(const_child) != child_output.end()) {
|
||||
return false;
|
||||
}
|
||||
// Get child op type.
|
||||
const bool is_child_add = IsAdd(*op_child);
|
||||
const bool is_child_mul = IsMul(*op_child);
|
||||
const bool is_child_sub = IsSub(*op_child);
|
||||
const bool is_child_div = IsDiv(*op_child);
|
||||
const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
|
||||
const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
|
||||
if (!is_add_sub && !is_mul_div) {
|
||||
return false;
|
||||
}
|
||||
const bool is_child_symmetric = is_child_add || is_child_mul;
|
||||
// Make sure that it is safe to change the value of the child node result.
|
||||
if (op_child->input_size() < 2 ||
|
||||
nodes_to_preserve_.find(op_child->name()) != nodes_to_preserve_.end() ||
|
||||
NumNonControlOutputs(*op_child, *node_map_) > 1) {
|
||||
return false;
|
||||
}
|
||||
// Do not rewrite integer expressions with subtraction or division.
|
||||
if (!CheckAttrExists(*node, "T").ok()) return false;
|
||||
DataType dtype = node->attr().at("T").type();
|
||||
if (!(is_symmetric && is_child_symmetric) &&
|
||||
!(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Identify the nodes to swap.
|
||||
NodeDef* left_leaf = node_map_->GetNode(op_child->input(0));
|
||||
NodeDef* right_leaf = node_map_->GetNode(op_child->input(1));
|
||||
const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
|
||||
const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
|
||||
if (left_leaf_is_constant && right_leaf_is_constant) {
|
||||
// Child is already foldable, leave it alone.
|
||||
return false;
|
||||
}
|
||||
// Don't move nodes across devices.
|
||||
if (node->device() != left_leaf->device() ||
|
||||
node->device() != right_leaf->device()) {
|
||||
return false;
|
||||
}
|
||||
// Get the node names corresponding to X, Y, and C.
|
||||
const string input_x =
|
||||
left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
|
||||
const string input_y =
|
||||
input_x == op_child->input(0) ? op_child->input(1) : op_child->input(0);
|
||||
const string input_c =
|
||||
left_child_is_constant ? node->input(0) : node->input(1);
|
||||
const string input_op =
|
||||
left_child_is_constant ? node->input(1) : node->input(0);
|
||||
|
||||
// Now we have identified the nodes to swap (non_const_leaf_input and
|
||||
// const_child).
|
||||
node_map_->UpdateInput(node->name(), input_c, input_x);
|
||||
node_map_->AddOutput(input_c, op_child->name());
|
||||
if (input_x != input_y) {
|
||||
node_map_->RemoveOutput(input_x, op_child->name());
|
||||
}
|
||||
|
||||
if (is_symmetric && is_child_symmetric) {
|
||||
// Easy case (only commutative ops). We always write this as one of
|
||||
// +
|
||||
// / \
|
||||
// X +
|
||||
// / \
|
||||
// C Y
|
||||
node->set_input(0, input_x);
|
||||
node->set_input(1, input_op);
|
||||
op_child->set_input(0, input_c);
|
||||
op_child->set_input(1, input_y);
|
||||
} else {
|
||||
// More complicated case: When there are non-commutative operations like
|
||||
// subtractions or divisions involved, we may have to rotate the tree
|
||||
// and/or change op types. There are 6 non-trivial cases depending on
|
||||
// the effective generalized "sign" of each of the three terms C, Y, and X.
|
||||
// Here are the final trees we want to generate for those 6 cases:
|
||||
//
|
||||
// (CYX signs): ++- +-- -+- --+ +-+ -++
|
||||
// - - - - + +
|
||||
// / \ / \ / \ / \ / \ / \
|
||||
// + X - X - X X + X - X -
|
||||
// / \ / \ / \ / \ / \ / \
|
||||
// C Y C Y Y C Y C C Y Y C
|
||||
//
|
||||
NodeDef* non_const_leaf = left_leaf_is_constant ? right_leaf : left_leaf;
|
||||
NodeDef* maybe_const_leaf =
|
||||
non_const_leaf == right_leaf ? left_leaf : right_leaf;
|
||||
|
||||
// First, let's determine the effective sign of each term in the original
|
||||
// expression
|
||||
auto is_leaf_negated = [&](const NodeDef* node) -> bool {
|
||||
bool leaf_negated = !is_child_symmetric && (node == right_leaf);
|
||||
bool child_negated = !is_symmetric && (op_child == right_child);
|
||||
return leaf_negated != child_negated;
|
||||
};
|
||||
const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
|
||||
const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
|
||||
bool neg_c = !is_symmetric && (const_child == right_child);
|
||||
bool neg_x = is_leaf_negated(non_const_leaf);
|
||||
bool neg_y = is_leaf_negated(maybe_const_leaf);
|
||||
// Rewrite the parent node.
|
||||
node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
|
||||
node->set_input(0, neg_x ? input_op : input_x);
|
||||
node->set_input(1, neg_x ? input_x : input_op);
|
||||
// Rewrite the child node.
|
||||
op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
|
||||
op_child->set_input(0, neg_c ? input_y : input_c);
|
||||
op_child->set_input(1, neg_c ? input_c : input_y);
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
|
||||
|
@ -255,19 +255,20 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
|
||||
TEST_F(ConstantFoldingTest, AddTree) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {1});
|
||||
Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
|
||||
Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
|
||||
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
|
||||
ops::Placeholder::Shape(TensorShape({2, 2})));
|
||||
Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
|
||||
Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child),
|
||||
1.0f, {1});
|
||||
Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);
|
||||
|
||||
Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
|
||||
ops::Placeholder::Shape(TensorShape({2, 2})));
|
||||
Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
|
||||
Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
|
||||
Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
|
||||
Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
|
||||
ops::Placeholder::Shape(TensorShape({2, 2})));
|
||||
Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
|
||||
Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
|
||||
Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
|
||||
@ -297,16 +298,16 @@ TEST_F(ConstantFoldingTest, AddTree) {
|
||||
// / \ / \
|
||||
// 5.0 y 4.0 5.0
|
||||
|
||||
EXPECT_EQ(10, output.node_size());
|
||||
EXPECT_EQ(11, output.node_size());
|
||||
for (const auto& node : output.node()) {
|
||||
if (node.name() == "add_child") {
|
||||
EXPECT_EQ("Const", node.op());
|
||||
TensorProto t = node.attr().at("value").tensor();
|
||||
ASSERT_EQ(1, t.tensor_shape().dim_size());
|
||||
EXPECT_EQ(1, t.tensor_shape().dim_size());
|
||||
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
|
||||
} else if (node.name() == "add_parent") {
|
||||
EXPECT_EQ("Add", node.op());
|
||||
ASSERT_EQ(2, node.input_size());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("x", node.input(0));
|
||||
EXPECT_EQ("add_child", node.input(1));
|
||||
} else if (node.name() == "mul_child") {
|
||||
@ -316,106 +317,30 @@ TEST_F(ConstantFoldingTest, AddTree) {
|
||||
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
|
||||
} else if (node.name() == "mul_parent") {
|
||||
EXPECT_EQ("Mul", node.op());
|
||||
ASSERT_EQ(2, node.input_size());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("y", node.input(0));
|
||||
EXPECT_EQ("mul_child", node.input(1));
|
||||
} else if (node.name() == "addmul_child") {
|
||||
// Unchanged.
|
||||
EXPECT_EQ("Add", node.op());
|
||||
ASSERT_EQ(2, node.input_size());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("c4", node.input(0));
|
||||
EXPECT_EQ("x", node.input(1));
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the result nodes have the expected value.
|
||||
auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
|
||||
auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
|
||||
|
||||
std::vector<string> fetch = {"add_parent", "mul_parent"};
|
||||
auto tensor_expected =
|
||||
EvaluateNodes(item.graph, fetch, {{"x", x_t}, {"y", y_t}});
|
||||
ASSERT_EQ(fetch.size(), tensor_expected.size());
|
||||
fetch = {"add_parent", "mul_parent"};
|
||||
auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}, {"y", y_t}});
|
||||
ASSERT_EQ(fetch.size(), tensors.size());
|
||||
std::vector<string> fetch = {"c3", "c20"};
|
||||
auto tensor_expected = EvaluateNodes(item.graph, fetch);
|
||||
EXPECT_EQ(fetch.size(), tensor_expected.size());
|
||||
fetch = {"add_child", "mul_child"};
|
||||
auto tensors = EvaluateNodes(output, fetch);
|
||||
EXPECT_EQ(fetch.size(), tensors.size());
|
||||
for (int i = 0; i < fetch.size(); i++) {
|
||||
test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Builds a two-level expression tree parent(c3, child(c2, x)) for every
// combination of:
//   * operator family: Add/Sub (is_add) or Mul/Div (!is_add),
//   * parent operator commutative (Add/Mul) or not (Sub/Div),
//   * child operator commutative (Add/Mul) or not (Sub/Div),
//   * constant placed as the left or right argument of the parent,
//   * constant placed as the left or right argument of the child,
// runs the ConstantFolding optimizer on it, and checks that evaluating
// "parent" on the optimized graph yields the same tensor as evaluating it on
// the original graph for a randomly generated value of x.
TEST_F(ConstantFoldingTest, TreeCanonicalization) {
  for (const bool is_add : {true, false}) {
    for (const bool is_parent_commutative : {true, false}) {
      for (const bool is_child_commutative : {true, false}) {
        for (const bool is_left_child_const : {true, false}) {
          for (const bool is_left_leaf_const : {true, false}) {
            tensorflow::Scope s = tensorflow::Scope::NewRootScope();
            Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
            Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
            Output x =
                ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                 ops::Placeholder::Shape(TensorShape({2, 2})));

            // Creates a binary op named `name` from the current operator
            // family. The constant argument goes on the left when
            // `is_left_arg_const` is true, otherwise on the right.
            auto get_op = [&](bool is_commutative, bool is_left_arg_const,
                              const string& name, const Output& const_arg,
                              const Output& non_const_arg) -> Output {
              // Pick the operand order once instead of per op type.
              const Output& lhs = is_left_arg_const ? const_arg : non_const_arg;
              const Output& rhs = is_left_arg_const ? non_const_arg : const_arg;
              if (is_add) {
                if (is_commutative) {
                  return ops::Add(s.WithOpName(name), lhs, rhs);
                }
                return ops::Sub(s.WithOpName(name), lhs, rhs);
              }
              if (is_commutative) {
                return ops::Mul(s.WithOpName(name), lhs, rhs);
              }
              return ops::Div(s.WithOpName(name), lhs, rhs);
            };

            Output child = get_op(is_child_commutative, is_left_leaf_const,
                                  "child", c2, x);
            Output parent = get_op(is_parent_commutative, is_left_child_const,
                                   "parent", c3, child);
            GrapplerItem item;
            item.fetch = {"parent"};
            TF_CHECK_OK(s.ToGraphDef(&item.graph));

            ConstantFolding optimizer(/*cpu_device=*/nullptr);
            GraphDef output;
            Status status =
                optimizer.Optimize(/*cluster=*/nullptr, item, &output);
            TF_EXPECT_OK(status);

            // Check that the result nodes have the expected value: the
            // optimized graph must agree with the original one on "parent".
            auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
            const std::vector<string> fetch = {"parent"};
            auto tensor_expected =
                EvaluateNodes(item.graph, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensor_expected.size());
            auto tensors = EvaluateNodes(output, fetch, {{"x", x_t}});
            ASSERT_EQ(fetch.size(), tensors.size());
            for (size_t i = 0; i < fetch.size(); ++i) {
              test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
            }
          }
        }
      }
    }
  }
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
|
||||
for (string data_format : {
|
||||
"NHWC",
|
||||
|
@ -1634,6 +1634,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// Infer properties lazily in case they are not needed.
|
||||
if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i)) {
|
||||
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
|
||||
// TODO(rmlarsen): Get rid of tensor value copies.
|
||||
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
|
||||
assume_valid_feeds,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
|
Loading…
Reference in New Issue
Block a user