Optimize LogicalOr and LogicalAnd when one of their inputs is a constant all-true or all-false tensor:

LogicalOr(x, true) = true
LogicalOr(x, false) = x
LogicalAnd(x, true) = x
LogicalAnd(x, false) = false

and similarly when the first argument is constant.
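
These are the Boolean identity and annihilator laws; a quick plain-Python check of all four cases (no TensorFlow required):

    # Each identity holds for both values of x.
    for x in (False, True):
        assert (x or True) is True      # LogicalOr(x, true)  -> true
        assert (x or False) is x        # LogicalOr(x, false) -> x
        assert (x and True) is x        # LogicalAnd(x, true) -> x
        assert (x and False) is False   # LogicalAnd(x, false) -> false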

PiperOrigin-RevId: 195161140
Author: A. Unique TensorFlower
Date: 2018-05-02 15:21:17 -07:00
Committed by: TensorFlower Gardener
parent 9180cc254d
commit 4704ae7af1
4 changed files with 133 additions and 65 deletions

tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -866,6 +866,25 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
 }
 #undef SET_TENSOR_CAL_CASE
 
+DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
+                                    const GraphProperties& properties) {
+  DataType dtype = DT_INVALID;
+  if (node.attr().count("T") == 1) {
+    dtype = node.attr().at("T").type();
+  } else if (node.attr().count("dtype") == 1) {
+    dtype = node.attr().at("dtype").type();
+  } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
+    dtype = DT_BOOL;
+  } else {
+    auto output_props = properties.GetOutputProperties(node.name());
+    if (!output_props.empty()) {
+      dtype = output_props[0].dtype();
+    }
+  }
+  return dtype;
+}
+
 }  // namespace
 
 // static
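
The new GetDataTypeFromNodeOrProps helper tries the "T" attr, then "dtype", then special-cases the logical ops (whose output is always DT_BOOL), and finally falls back to inferred output properties. A minimal Python model of that lookup order (the function and dtype names below are stand-ins, not TensorFlow symbols):

    DT_INVALID, DT_BOOL, DT_FLOAT = "DT_INVALID", "DT_BOOL", "DT_FLOAT"

    def get_dtype_from_node_or_props(op, attrs, output_props=()):
        # Explicit attrs win, then the op kind, then inferred properties.
        if "T" in attrs:
            return attrs["T"]
        if "dtype" in attrs:
            return attrs["dtype"]
        if op in ("LogicalOr", "LogicalAnd"):
            return DT_BOOL
        return output_props[0] if output_props else DT_INVALID

    assert get_dtype_from_node_or_props("LogicalOr", {}) == DT_BOOL
    assert get_dtype_from_node_or_props("Mul", {"T": DT_FLOAT}) == DT_FLOAT
    assert get_dtype_from_node_or_props("NoOp", {}) == DT_INVALID

LogicalOr and LogicalAnd need the special case because their op definitions carry neither a "T" nor a "dtype" attr.
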
@@ -1412,6 +1431,7 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
+    IS_ONES_CASE(DT_BOOL);
     IS_ONES_CASE(DT_HALF);
     IS_ONES_CASE(DT_BFLOAT16);
     IS_ONES_CASE(DT_FLOAT);
@@ -1447,6 +1467,7 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
+    IS_ZEROS_CASE(DT_BOOL);
     IS_ZEROS_CASE(DT_HALF);
     IS_ZEROS_CASE(DT_BFLOAT16);
     IS_ZEROS_CASE(DT_FLOAT);
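
With DT_BOOL added, a constant of all true values now counts as "ones" and all false as "zeros", so the existing neutral-element rewrites apply to the logical ops unchanged. A simplified model of the check, assuming the IS_ONES_CASE macro amounts to an element-wise comparison against 1:

    import numpy as np

    def all_ones(t):
        # True compares equal to 1, so the same test covers bool tensors.
        return bool(np.all(t == 1))

    assert all_ones(np.array([True, True]))
    assert not all_ones(np.array([True, False]))
    assert all_ones(np.array([1.0, 1.0], dtype=np.float32))
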
@@ -1466,14 +1487,15 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   return false;
 }
 
-void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
-                                                   NodeDef* node,
-                                                   GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithIdentity(
+    int input_to_forward, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
   node->set_op("Identity");
-  DataType dtype = node->attr().at("T").type();
   node->clear_attr();
   (*node->mutable_attr())["T"].set_type(dtype);
   // Propagate the designated input through the identity.
   node->mutable_input()->SwapElements(0, input_to_forward);
   // Add all other inputs as control dependencies.
@@ -1489,14 +1511,15 @@ void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
   graph_modified_ = true;
 }
 
-void ConstantFolding::ReplaceOperationWithSnapshot(int input_to_forward,
-                                                   NodeDef* node,
-                                                   GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithSnapshot(
+    int input_to_forward, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
   node->set_op("Snapshot");
-  DataType dtype = node->attr().at("T").type();
   node->clear_attr();
   (*node->mutable_attr())["T"].set_type(dtype);
   // Propagate the designated input through the Snapshot.
   node->mutable_input()->SwapElements(0, input_to_forward);
   // Add all other inputs as control dependencies.
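
Both replacement helpers now bail out early when no dtype can be determined, instead of crashing on a missing "T" attr, and both rewire inputs the same way: the forwarded input moves to slot 0 and every other input is demoted to a control dependency. A toy sketch of that rewiring (plain lists, not NodeDef):

    def forward_input(inputs, input_to_forward):
        # Swap the kept input into slot 0, then mark the rest as control
        # dependencies by prefixing '^', as AsControlDependency does.
        inputs = list(inputs)
        inputs[0], inputs[input_to_forward] = inputs[input_to_forward], inputs[0]
        return [inputs[0]] + ["^" + n.lstrip("^") for n in inputs[1:]]

    # mul2 = Mul(x, ones) keeps x and demotes ones, matching the test below,
    # which expects a Snapshot node with inputs ("x", "^ones").
    assert forward_input(["x", "ones"], 0) == ["x", "^ones"]
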
@@ -1535,15 +1558,18 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
 }
 
 Status ConstantFolding::ReplaceOperationWithConstant(
-    double value, const AttrValue& dtype_attr, const TensorShapeProto& shape,
-    NodeDef* node, GraphDef* graph) {
+    double value, const GraphProperties& properties,
+    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return Status::OK();
   AttrValue tensor_attr;
-  TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value,
-                                                   shape, &tensor_attr));
-  node->clear_attr();
-  node->mutable_attr()->insert({"dtype", dtype_attr});
-  node->mutable_attr()->insert({"value", tensor_attr});
+  TF_RETURN_IF_ERROR(
+      CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr));
   node->set_op("Const");
+  node->clear_attr();
+  (*node->mutable_attr())["dtype"].set_type(dtype);
+  node->mutable_attr()->insert({"value", tensor_attr});
   // Convert all inputs to control dependencies.
   for (int i = 0; i < node->input_size(); ++i) {
     if (IsControlInput(node->input(i))) {
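
ReplaceOperationWithConstant now derives the dtype itself and silently skips nodes whose dtype cannot be determined, rather than requiring callers to thread a dtype attr through. The folded result is a dense tensor filled with the requested value; in numpy terms (a sketch, not the CreateConstantTensorAttrValue implementation):

    import numpy as np

    def fold_to_constant(value, dtype, shape):
        # For LogicalOr(x, true) the call is value=1, dtype=bool.
        return np.full(shape, value, dtype=dtype)

    assert fold_to_constant(1, np.bool_, (2, 3)).all()    # every element True
    assert not fold_to_constant(0, np.bool_, (2, 3)).any()
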
@@ -1566,12 +1592,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
     NodeDef* node = optimized_graph->mutable_node(i);
     if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
-      ReplaceOperationWithIdentity(1, node, optimized_graph);
+      ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
       continue;
     }
 
     if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
-      ReplaceOperationWithIdentity(0, node, optimized_graph);
+      ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
       continue;
     }
@@ -1611,7 +1637,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
         }
         if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         }
       }
@@ -1626,7 +1652,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
       if (!shape.unknown_rank() &&
           (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1651,11 +1677,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       for (int j = 0; j < axis.NumElements(); ++j) {
         // value of axis can be negative.
         if (axis.dtype() == DT_INT64) {
-          target_axes.insert(
-              (axis.vec<int64>()(j) + shape.dim_size()) % shape.dim_size());
+          target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
+                             shape.dim_size());
         } else {
-          target_axes.insert(
-              (axis.vec<int>()(j) + shape.dim_size()) % shape.dim_size());
+          target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
+                             shape.dim_size());
         }
       }
@@ -1669,7 +1695,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
             target_axes.find(j) == target_axes.end();
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1711,7 +1737,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         }
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1740,7 +1766,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         }
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1764,7 +1790,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         replaceable &= flatten(j) == 0;
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1784,7 +1810,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         replaceable &= shape.dim(j).size() > 1;
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1996,9 +2022,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         continue;
       }
 
-      const bool is_mul = IsMul(*node);
+      const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
       const bool is_matmul = IsMatMul(*node);
-      const bool is_add = IsAdd(*node) || IsBiasAdd(*node);
+      const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
       const bool is_sub = IsSub(*node);
       const bool is_any_div = IsAnyDiv(*node);
       // Simplify arithmetic operations with ones or zeros.
@@ -2025,7 +2051,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       if (y_matches_output_shape &&
           ((is_mul && x_is_one) || (is_add && x_is_zero))) {
         // 1 * y = y or 0 + y = y.
-        ReplaceOperationWithSnapshot(1, node, optimized_graph);
+        ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
         continue;
       }
@@ -2052,10 +2078,18 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                      ((is_add || is_sub) && y_is_zero))) {
         // x * 1 = x or x / 1 = x or x +/- 0 = x
-        ReplaceOperationWithSnapshot(0, node, optimized_graph);
+        ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
         continue;
       }
 
+      // x OR true = true OR y = true.
+      const PartialTensorShape shp(output_shape);
+      if (shp.IsFullyDefined() && IsLogicalOr(*node) &&
+          (y_is_one || x_is_one)) {
+        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+            1, *properties, output_shape, node, optimized_graph));
+      }
+
       // Simplify multiplication and matmul by zeros.
       // Also optimize zeros divided by a tensor, but only if we are in
       // aggressive mode, since we might get rid of divisions by zero.
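
The new fold only fires when the output shape is fully defined, because the replacement Const must materialize a concrete tensor; with any unknown dimension there is nothing to materialize. A sketch of the guard, using -1 for an unknown dimension as TensorShapeProto does:

    def can_fold_or_with_true(output_shape, x_is_one, y_is_one):
        # Mirrors shp.IsFullyDefined() && IsLogicalOr && (x_is_one || y_is_one).
        fully_defined = all(dim >= 0 for dim in output_shape)
        return fully_defined and (x_is_one or y_is_one)

    assert can_fold_or_with_true([2, 2], x_is_one=False, y_is_one=True)
    assert not can_fold_or_with_true([2, -1], x_is_one=False, y_is_one=True)
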
@@ -2063,26 +2097,19 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           is_any_div && x_is_zero && is_aggressive;
       if ((x_is_zero || y_is_zero) &&
           (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
-        const PartialTensorShape shp(output_shape);
         if (shp.IsFullyDefined()) {
-          AttrValue dtype_attr;
-          if (node->op() == "SparseMatMul") {
-            dtype_attr.set_type(DT_FLOAT);
-          } else {
-            dtype_attr = node->attr().at("T");
-          }
           TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
-              0, dtype_attr, output_shape, node, optimized_graph));
+              0, *properties, output_shape, node, optimized_graph));
           continue;
         }
         // Even if an input shape is only partially known, we may know that it
         // matches the output shape and thus forward the corresponding zero
        // input.
         if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         } else if (is_mul && y_is_zero && y_matches_output_shape) {
-          ReplaceOperationWithIdentity(1, node, optimized_graph);
+          ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
           continue;
         }
       }

tensorflow/core/grappler/optimizers/constant_folding.h

@@ -78,12 +78,15 @@ class ConstantFolding : public GraphOptimizer {
   bool IsOnes(const NodeDef& node) const;
   bool IsZeros(const NodeDef& node) const;
-  void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node,
-                                    GraphDef* graph);
-  void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node,
-                                    GraphDef* graph);
+  void ReplaceOperationWithIdentity(int input_to_forward,
+                                    const GraphProperties& properties,
+                                    NodeDef* node, GraphDef* graph);
+  void ReplaceOperationWithSnapshot(int input_to_forward,
+                                    const GraphProperties& properties,
+                                    NodeDef* node, GraphDef* graph);
   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
-  Status ReplaceOperationWithConstant(double value, const AttrValue& dtype_attr,
+  Status ReplaceOperationWithConstant(double value,
+                                      const GraphProperties& properties,
                                       const TensorShapeProto& shape,
                                       NodeDef* node, GraphDef* graph);
   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);

tensorflow/core/grappler/optimizers/constant_folding_test.cc

@@ -47,18 +47,30 @@ class ConstantFoldingTest : public GrapplerTest {
     }
     Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
     Output ones = ops::Const(s.WithOpName("ones"), ones_t);
-    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
-    Output mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+    Output mul1;
+    Output mul2;
+    Output add1;
+    Output add2;
+    if (DTYPE == DT_BOOL) {
+      mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
+      mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
+      add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
+      add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
+    } else {
+      mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+      mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+      add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+      add2 = ops::Add(s.WithOpName("add2"), x, ones);
+    }
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"mul1", "mul2"};
+    item.fetch = {"mul1", "mul2", "add1", "add2"};
     ConstantFolding optimizer(nullptr /* cpu_device */);
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
+    LOG(INFO) << output.DebugString();
-    EXPECT_EQ(5, output.node_size());
+    EXPECT_EQ(7, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -70,14 +82,27 @@ class ConstantFoldingTest : public GrapplerTest {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ("^ones", node.input(1));
+      } else if (name == "add1") {
+        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("x", node.input(0));
+        EXPECT_EQ("^zeros", node.input(1));
+      } else if (name == "add2") {
+        if (DTYPE == DT_BOOL) {
+          EXPECT_EQ("Const", node.op());
+          EXPECT_EQ("^x", node.input(0));
+          EXPECT_EQ("^ones", node.input(1));
+        } else {
+          EXPECT_EQ("Add", node.op());
+          EXPECT_EQ("x", node.input(0));
+          EXPECT_EQ("ones", node.input(1));
+        }
       }
     }
-    auto tensors_expected =
-        EvaluateNodes(item.graph, {"mul1", "mul2"}, {{"x", x_t}});
-    auto tensors = EvaluateNodes(output, {"mul1", "mul2"}, {{"x", x_t}});
-    EXPECT_EQ(2, tensors_expected.size());
-    EXPECT_EQ(2, tensors.size());
-    for (int i = 0; i < 2; ++i) {
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+    auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+    EXPECT_EQ(4, tensors_expected.size());
+    EXPECT_EQ(4, tensors.size());
+    for (int i = 0; i < item.fetch.size(); ++i) {
       test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
     }
   }
@@ -393,6 +418,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
 }
 
 TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
+  SimpleNeutralElementTest<DT_BOOL>();
   SimpleNeutralElementTest<DT_HALF>();
   SimpleNeutralElementTest<DT_BFLOAT16>();
 }

tensorflow/python/feature_column/feature_column_test.py

@@ -25,6 +25,8 @@ import numpy as np
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -54,8 +56,8 @@ from tensorflow.python.training import coordinator
 from tensorflow.python.training import queue_runner_impl
 
 
-def _initialized_session():
-  sess = session.Session()
+def _initialized_session(config=None):
+  sess = session.Session(config=config)
   sess.run(variables_lib.global_variables_initializer())
   sess.run(lookup_ops.tables_initializer())
   return sess
@@ -6191,7 +6193,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
             'values': ((.5,), (1.,))
         }, (column,),
         sparse_combiner='mean')
-    with _initialized_session():
+    # Disable the constant folding optimizer here, since it produces
+    # different error messages on CPU and GPU.
+    config = config_pb2.ConfigProto()
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+    with _initialized_session(config):
       with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
         predictions.eval()
@@ -6284,7 +6291,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
             'values': ((.5,), (1.,))
         }, (column,),
         sparse_combiner='mean')
-    with _initialized_session():
+    # Disable the constant folding optimizer here, since it produces
+    # different error messages on CPU and GPU.
+    config = config_pb2.ConfigProto()
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+    with _initialized_session(config):
       with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
         predictions.eval()
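
For reference, the full pattern the tests now use, disabling only Grappler's constant-folding pass while leaving the rest of the rewriter on (a sketch that reuses the test file's own imports):

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2
    from tensorflow.python.client import session

    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)
    with session.Session(config=config) as sess:
        # Ops run here see the raw, unfolded graph, so error messages are
        # stable across CPU and GPU.
        pass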