Optimize LogicalOr and LogicalAnd when one input is a constant all-true or all-false tensor:

LogicalOr(x, true) = true
LogicalOr(x, false) = x
LogicalAnd(x, true) = x
LogicalAnd(x, false) = false

and apply the symmetric rewrites when the first argument is the constant.

PiperOrigin-RevId: 195161140
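
Seen from the Python API, the intended effect of these rewrites is sketched below (a minimal sketch against the TF 1.x-era API; the graph, names, and shapes are illustrative and not part of this change):

    # Sketch only: assumes a TensorFlow build that includes this change.
    # Shapes are fully defined so Grappler can prove an operand is
    # all-true or all-false.
    import tensorflow as tf

    x = tf.placeholder(tf.bool, shape=(2, 2), name="x")
    all_true = tf.constant(True, shape=(2, 2))
    all_false = tf.constant(False, shape=(2, 2))

    or_true = tf.logical_or(x, all_true)      # folded to a Const of all-true
    or_false = tf.logical_or(x, all_false)    # rewritten to forward x
    and_true = tf.logical_and(x, all_true)    # rewritten to forward x
    and_false = tf.logical_and(x, all_false)  # folded to a Const of all-false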
Authored by A. Unique TensorFlower on 2018-05-02 15:21:17 -07:00; committed by TensorFlower Gardener.
parent 9180cc254d
commit 4704ae7af1
4 changed files with 133 additions and 65 deletions

tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -866,6 +866,25 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
 }
 #undef SET_TENSOR_VAL_CASE
 
+DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
+                                    const GraphProperties& properties) {
+  DataType dtype = DT_INVALID;
+  if (node.attr().count("T") == 1) {
+    dtype = node.attr().at("T").type();
+  } else if (node.attr().count("dtype") == 1) {
+    dtype = node.attr().at("dtype").type();
+  } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
+    dtype = DT_BOOL;
+  } else {
+    auto output_props = properties.GetOutputProperties(node.name());
+    if (!output_props.empty()) {
+      dtype = output_props[0].dtype();
+    }
+  }
+  return dtype;
+}
+
 }  // namespace
 
 // static
@@ -1412,6 +1431,7 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
+    IS_ONES_CASE(DT_BOOL);
     IS_ONES_CASE(DT_HALF);
     IS_ONES_CASE(DT_BFLOAT16);
     IS_ONES_CASE(DT_FLOAT);
@@ -1447,6 +1467,7 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
+    IS_ZEROS_CASE(DT_BOOL);
     IS_ZEROS_CASE(DT_HALF);
     IS_ZEROS_CASE(DT_BFLOAT16);
     IS_ZEROS_CASE(DT_FLOAT);
@@ -1466,14 +1487,15 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   return false;
 }
 
-void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
-                                                   NodeDef* node,
-                                                   GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithIdentity(
+    int input_to_forward, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
   node->set_op("Identity");
-  DataType dtype = node->attr().at("T").type();
   node->clear_attr();
   (*node->mutable_attr())["T"].set_type(dtype);
   // Propagate the designated input through the identity.
   node->mutable_input()->SwapElements(0, input_to_forward);
   // Add all other inputs as control dependencies.
@@ -1489,14 +1511,15 @@ void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
   graph_modified_ = true;
 }
 
-void ConstantFolding::ReplaceOperationWithSnapshot(int input_to_forward,
-                                                   NodeDef* node,
-                                                   GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithSnapshot(
+    int input_to_forward, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
   node->set_op("Snapshot");
-  DataType dtype = node->attr().at("T").type();
   node->clear_attr();
   (*node->mutable_attr())["T"].set_type(dtype);
   // Propagate the designated input through the Snapshot.
   node->mutable_input()->SwapElements(0, input_to_forward);
   // Add all other inputs as control dependencies.
@@ -1535,15 +1558,18 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
 }
 
 Status ConstantFolding::ReplaceOperationWithConstant(
-    double value, const AttrValue& dtype_attr, const TensorShapeProto& shape,
-    NodeDef* node, GraphDef* graph) {
+    double value, const GraphProperties& properties,
+    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return Status::OK();
   AttrValue tensor_attr;
-  TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value,
-                                                   shape, &tensor_attr));
-  node->clear_attr();
-  node->mutable_attr()->insert({"dtype", dtype_attr});
-  node->mutable_attr()->insert({"value", tensor_attr});
+  TF_RETURN_IF_ERROR(
+      CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr));
+  node->set_op("Const");
+  node->clear_attr();
+  (*node->mutable_attr())["dtype"].set_type(dtype);
+  node->mutable_attr()->insert({"value", tensor_attr});
   // Convert all inputs to control dependencies.
   for (int i = 0; i < node->input_size(); ++i) {
     if (IsControlInput(node->input(i))) {
@@ -1566,12 +1592,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
     NodeDef* node = optimized_graph->mutable_node(i);
 
     if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
-      ReplaceOperationWithIdentity(1, node, optimized_graph);
+      ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
       continue;
     }
 
    if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
-      ReplaceOperationWithIdentity(0, node, optimized_graph);
+      ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
       continue;
     }
@@ -1611,7 +1637,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
        continue;
       }
     }
@@ -1626,7 +1652,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
       if (!shape.unknown_rank() &&
           (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
        continue;
       }
     }
@@ -1651,11 +1677,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         for (int j = 0; j < axis.NumElements(); ++j) {
           // value of axis can be negative.
           if (axis.dtype() == DT_INT64) {
-            target_axes.insert(
-                (axis.vec<int64>()(j) + shape.dim_size()) % shape.dim_size());
+            target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
+                               shape.dim_size());
           } else {
-            target_axes.insert(
-                (axis.vec<int>()(j) + shape.dim_size()) % shape.dim_size());
+            target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
+                               shape.dim_size());
           }
         }
@@ -1669,7 +1695,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
               target_axes.find(j) == target_axes.end();
         }
         if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
          continue;
         }
       }
@@ -1711,7 +1737,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         }
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
        continue;
       }
     }
@@ -1740,7 +1766,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         }
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
        continue;
       }
     }
@@ -1764,7 +1790,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         replaceable &= flatten(j) == 0;
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
        continue;
       }
     }
@@ -1784,7 +1810,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         replaceable &= shape.dim(j).size() > 1;
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
        continue;
       }
     }
@@ -1996,9 +2022,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       continue;
     }
 
-    const bool is_mul = IsMul(*node);
+    const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
     const bool is_matmul = IsMatMul(*node);
-    const bool is_add = IsAdd(*node) || IsBiasAdd(*node);
+    const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
     const bool is_sub = IsSub(*node);
     const bool is_any_div = IsAnyDiv(*node);
     // Simplify arithmetic operations with ones or zeros.
@@ -2025,7 +2051,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       if (y_matches_output_shape &&
           ((is_mul && x_is_one) || (is_add && x_is_zero))) {
         // 1 * y = y or 0 + y = y.
-        ReplaceOperationWithSnapshot(1, node, optimized_graph);
+        ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
        continue;
       }
@@ -2052,10 +2078,18 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                      ((is_add || is_sub) && y_is_zero))) {
         // x * 1 = x or x / 1 = x or x +/- 0 = x
-        ReplaceOperationWithSnapshot(0, node, optimized_graph);
+        ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
        continue;
       }
 
+      // x OR true = true OR y = true.
+      const PartialTensorShape shp(output_shape);
+      if (shp.IsFullyDefined() && IsLogicalOr(*node) &&
+          (y_is_one || x_is_one)) {
+        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+            1, *properties, output_shape, node, optimized_graph));
+      }
+
       // Simplify multiplication and matmul by zeros.
       // Also optimize zeros divided by a tensor, but only if we are in
       // aggressive mode, since we might get rid of divisions by zero.
@@ -2063,26 +2097,19 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           is_any_div && x_is_zero && is_aggressive;
       if ((x_is_zero || y_is_zero) &&
           (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
-        const PartialTensorShape shp(output_shape);
         if (shp.IsFullyDefined()) {
-          AttrValue dtype_attr;
-          if (node->op() == "SparseMatMul") {
-            dtype_attr.set_type(DT_FLOAT);
-          } else {
-            dtype_attr = node->attr().at("T");
-          }
           TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
-              0, dtype_attr, output_shape, node, optimized_graph));
+              0, *properties, output_shape, node, optimized_graph));
          continue;
         }
         // Even if an input shape is only partially known, we may know that it
         // matches the output shape and thus forward the corresponding zero
         // input.
         if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
          continue;
         } else if (is_mul && y_is_zero && y_matches_output_shape) {
-          ReplaceOperationWithIdentity(1, node, optimized_graph);
+          ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
          continue;
         }
       }

tensorflow/core/grappler/optimizers/constant_folding.h

@@ -78,12 +78,15 @@ class ConstantFolding : public GraphOptimizer {
   bool IsOnes(const NodeDef& node) const;
   bool IsZeros(const NodeDef& node) const;
-  void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node,
-                                    GraphDef* graph);
-  void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node,
-                                    GraphDef* graph);
+  void ReplaceOperationWithIdentity(int input_to_forward,
+                                    const GraphProperties& properties,
+                                    NodeDef* node, GraphDef* graph);
+  void ReplaceOperationWithSnapshot(int input_to_forward,
+                                    const GraphProperties& properties,
+                                    NodeDef* node, GraphDef* graph);
   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
-  Status ReplaceOperationWithConstant(double value, const AttrValue& dtype_attr,
+  Status ReplaceOperationWithConstant(double value,
+                                      const GraphProperties& properties,
                                       const TensorShapeProto& shape,
                                       NodeDef* node, GraphDef* graph);
   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);

tensorflow/core/grappler/optimizers/constant_folding_test.cc

@@ -47,18 +47,30 @@ class ConstantFoldingTest : public GrapplerTest {
     }
     Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
     Output ones = ops::Const(s.WithOpName("ones"), ones_t);
-    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
-    Output mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+    Output mul1;
+    Output mul2;
+    Output add1;
+    Output add2;
+    if (DTYPE == DT_BOOL) {
+      mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
+      mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
+      add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
+      add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
+    } else {
+      mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+      mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+      add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+      add2 = ops::Add(s.WithOpName("add2"), x, ones);
+    }
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"mul1", "mul2"};
+    item.fetch = {"mul1", "mul2", "add1", "add2"};
     ConstantFolding optimizer(nullptr /* cpu_device */);
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
     LOG(INFO) << output.DebugString();
-    EXPECT_EQ(5, output.node_size());
+    EXPECT_EQ(7, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -70,14 +82,27 @@ class ConstantFoldingTest : public GrapplerTest {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ("^ones", node.input(1));
+      } else if (name == "add1") {
+        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("x", node.input(0));
+        EXPECT_EQ("^zeros", node.input(1));
+      } else if (name == "add2") {
+        if (DTYPE == DT_BOOL) {
+          EXPECT_EQ("Const", node.op());
+          EXPECT_EQ("^x", node.input(0));
+          EXPECT_EQ("^ones", node.input(1));
+        } else {
+          EXPECT_EQ("Add", node.op());
+          EXPECT_EQ("x", node.input(0));
+          EXPECT_EQ("ones", node.input(1));
+        }
       }
     }
-    auto tensors_expected =
-        EvaluateNodes(item.graph, {"mul1", "mul2"}, {{"x", x_t}});
-    auto tensors = EvaluateNodes(output, {"mul1", "mul2"}, {{"x", x_t}});
-    EXPECT_EQ(2, tensors_expected.size());
-    EXPECT_EQ(2, tensors.size());
-    for (int i = 0; i < 2; ++i) {
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+    auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+    EXPECT_EQ(4, tensors_expected.size());
+    EXPECT_EQ(4, tensors.size());
+    for (int i = 0; i < item.fetch.size(); ++i) {
       test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
     }
   }
@@ -393,6 +418,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
 }
 
 TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
+  SimpleNeutralElementTest<DT_BOOL>();
   SimpleNeutralElementTest<DT_HALF>();
   SimpleNeutralElementTest<DT_BFLOAT16>();
 }

tensorflow/python/feature_column/feature_column_test.py

@@ -25,6 +25,8 @@ import numpy as np
 
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -54,8 +56,8 @@ from tensorflow.python.training import coordinator
 from tensorflow.python.training import queue_runner_impl
 
 
-def _initialized_session():
-  sess = session.Session()
+def _initialized_session(config=None):
+  sess = session.Session(config=config)
   sess.run(variables_lib.global_variables_initializer())
   sess.run(lookup_ops.tables_initializer())
   return sess
@@ -6191,7 +6193,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
             'values': ((.5,), (1.,))
         }, (column,),
         sparse_combiner='mean')
-    with _initialized_session():
+    # Disable the constant folding optimizer here, since it causes the error
+    # message to differ between CPU and GPU.
+    config = config_pb2.ConfigProto()
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+    with _initialized_session(config):
       with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
         predictions.eval()
@@ -6284,7 +6291,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
             'values': ((.5,), (1.,))
         }, (column,),
         sparse_combiner='mean')
-    with _initialized_session():
+    # Disable the constant folding optimizer here, since it causes the error
+    # message to differ between CPU and GPU.
+    config = config_pb2.ConfigProto()
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+    with _initialized_session(config):
      with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
        predictions.eval()
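
For reference, the session-level opt-out used by the updated tests above can be passed anywhere a ConfigProto is accepted. A minimal sketch, reusing the same imports as the test file:

    # Sketch only: disables Grappler's constant folding for one session,
    # mirroring what the updated tests do.
    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2
    from tensorflow.python.client import session

    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)
    sess = session.Session(config=config)  # graph runs without folding rewrites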