Optimize LogicalOr and LogicalAnd with all true or false inputs:

    LogicalOr(x, true)   = true
    LogicalOr(x, false)  = x
    LogicalAnd(x, true)  = x
    LogicalAnd(x, false) = false

and similarly if the first argument is constant.

PiperOrigin-RevId: 195161140
parent 9180cc254d
commit 4704ae7af1
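As a sketch of the intended rewrites (node names follow the updated
constant_folding_test.cc hunks below; "^name" marks a control dependency;
mul1's folding to a false constant is inferred by symmetry with add2):

    Before constant folding:          After:
      mul1 = LogicalAnd(x, zeros)       mul1 = Const(false)  ^x, ^zeros
      mul2 = LogicalAnd(x, ones)        mul2 = Snapshot(x)   ^ones
      add1 = LogicalOr(x, zeros)        add1 = Snapshot(x)   ^zeros
      add2 = LogicalOr(x, ones)         add2 = Const(true)   ^x, ^ones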
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -866,6 +866,25 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
 }
 
 #undef SET_TENSOR_CAL_CASE
 
+DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
+                                    const GraphProperties& properties) {
+  DataType dtype = DT_INVALID;
+  if (node.attr().count("T") == 1) {
+    dtype = node.attr().at("T").type();
+  } else if (node.attr().count("dtype") == 1) {
+    dtype = node.attr().at("dtype").type();
+  } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
+    dtype = DT_BOOL;
+  } else {
+    auto output_props = properties.GetOutputProperties(node.name());
+    if (!output_props.empty()) {
+      dtype = output_props[0].dtype();
+    }
+  }
+  return dtype;
+}
+
 }  // namespace
 
 // static
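The new GetDataTypeFromNodeOrProps helper centralizes how the dtype of a
replacement node is resolved. A minimal standalone sketch of the same lookup
order (hypothetical SketchNode/SketchDataType types, not the Grappler API):

    #include <map>
    #include <string>

    enum SketchDataType { kInvalid, kBool, kFloat };

    struct SketchNode {
      std::string op;
      std::map<std::string, SketchDataType> attr;
    };

    // Same precedence as above: an explicit "T" attribute wins, then
    // "dtype", then the logical ops default to bool, and finally whatever
    // shape inference recorded for output 0 (kInvalid if nothing did).
    SketchDataType ResolveDtype(const SketchNode& node,
                                SketchDataType inferred_output0) {
      if (node.attr.count("T") == 1) return node.attr.at("T");
      if (node.attr.count("dtype") == 1) return node.attr.at("dtype");
      if (node.op == "LogicalOr" || node.op == "LogicalAnd") return kBool;
      return inferred_output0;
    }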
@@ -1412,6 +1431,7 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
+    IS_ONES_CASE(DT_BOOL);
     IS_ONES_CASE(DT_HALF);
     IS_ONES_CASE(DT_BFLOAT16);
     IS_ONES_CASE(DT_FLOAT);
@@ -1447,6 +1467,7 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
+    IS_ZEROS_CASE(DT_BOOL);
     IS_ZEROS_CASE(DT_HALF);
     IS_ZEROS_CASE(DT_BFLOAT16);
     IS_ZEROS_CASE(DT_FLOAT);
@@ -1466,14 +1487,15 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   return false;
 }
 
-void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
-                                                   NodeDef* node,
-                                                   GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithIdentity(
+    int input_to_forward, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
+
   node->set_op("Identity");
-  DataType dtype = node->attr().at("T").type();
   node->clear_attr();
   (*node->mutable_attr())["T"].set_type(dtype);
-
   // Propagate the designated input through the identity.
   node->mutable_input()->SwapElements(0, input_to_forward);
   // Add all other inputs as control dependencies.
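The rewiring done by ReplaceOperationWithIdentity (and by
ReplaceOperationWithSnapshot below) can be sketched in isolation. Assumed
convention, as in the TF graph format: a regular input is "name", a control
input is "^name". This is an illustrative sketch, not the actual helper:

    #include <string>
    #include <utility>
    #include <vector>

    // Move the forwarded input into slot 0 and demote every other regular
    // input to a control dependency, so producers still execute in the
    // same order even though their values are no longer consumed.
    std::vector<std::string> ForwardInput(std::vector<std::string> inputs,
                                          int input_to_forward) {
      std::swap(inputs[0], inputs[input_to_forward]);
      for (size_t i = 1; i < inputs.size(); ++i) {
        if (!inputs[i].empty() && inputs[i][0] != '^') {
          inputs[i] = "^" + inputs[i];
        }
      }
      return inputs;
    }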
@@ -1489,14 +1511,15 @@ void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
   graph_modified_ = true;
 }
 
-void ConstantFolding::ReplaceOperationWithSnapshot(int input_to_forward,
-                                                   NodeDef* node,
-                                                   GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithSnapshot(
+    int input_to_forward, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
+
   node->set_op("Snapshot");
-  DataType dtype = node->attr().at("T").type();
   node->clear_attr();
   (*node->mutable_attr())["T"].set_type(dtype);
-
   // Propagate the designated input through the Snapshot.
   node->mutable_input()->SwapElements(0, input_to_forward);
   // Add all other inputs as control dependencies.
@@ -1535,15 +1558,18 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
 }
 
 Status ConstantFolding::ReplaceOperationWithConstant(
-    double value, const AttrValue& dtype_attr, const TensorShapeProto& shape,
-    NodeDef* node, GraphDef* graph) {
+    double value, const GraphProperties& properties,
+    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return Status::OK();
+
   AttrValue tensor_attr;
-  TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value,
-                                                   shape, &tensor_attr));
-  node->clear_attr();
-  node->mutable_attr()->insert({"dtype", dtype_attr});
-  node->mutable_attr()->insert({"value", tensor_attr});
+  TF_RETURN_IF_ERROR(
+      CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr));
   node->set_op("Const");
+  node->clear_attr();
+  (*node->mutable_attr())["dtype"].set_type(dtype);
+  node->mutable_attr()->insert({"value", tensor_attr});
   // Convert all inputs to control dependencies.
   for (int i = 0; i < node->input_size(); ++i) {
     if (IsControlInput(node->input(i))) {
@@ -1566,12 +1592,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
     NodeDef* node = optimized_graph->mutable_node(i);
 
     if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
-      ReplaceOperationWithIdentity(1, node, optimized_graph);
+      ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
       continue;
     }
 
     if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
-      ReplaceOperationWithIdentity(0, node, optimized_graph);
+      ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
       continue;
     }
 
@@ -1611,7 +1637,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1626,7 +1652,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
       if (!shape.unknown_rank() &&
           (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1651,11 +1677,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       for (int j = 0; j < axis.NumElements(); ++j) {
         // value of axis can be negative.
         if (axis.dtype() == DT_INT64) {
-          target_axes.insert(
-              (axis.vec<int64>()(j) + shape.dim_size()) % shape.dim_size());
+          target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
+                             shape.dim_size());
         } else {
-          target_axes.insert(
-              (axis.vec<int>()(j) + shape.dim_size()) % shape.dim_size());
+          target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
+                             shape.dim_size());
         }
       }
 
@@ -1669,7 +1695,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
             target_axes.find(j) == target_axes.end();
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1711,7 +1737,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         }
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1740,7 +1766,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
        }
      }
      if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
        continue;
      }
    }
@@ -1764,7 +1790,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         replaceable &= flatten(j) == 0;
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1784,7 +1810,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         replaceable &= shape.dim(j).size() > 1;
       }
       if (replaceable) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1996,9 +2022,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       continue;
     }
 
-    const bool is_mul = IsMul(*node);
+    const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
     const bool is_matmul = IsMatMul(*node);
-    const bool is_add = IsAdd(*node) || IsBiasAdd(*node);
+    const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
     const bool is_sub = IsSub(*node);
     const bool is_any_div = IsAnyDiv(*node);
     // Simplify arithmetic operations with ones or zeros.
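Reusing the is_mul/is_add paths for the logical ops is sound because, with
false read as 0 and true as 1, AND behaves like multiplication and OR like
saturating addition, so the existing neutral- and absorbing-element
simplifications carry over. A quick standalone check of the identities from
the commit message:

    #include <cassert>

    int main() {
      for (bool x : {false, true}) {
        assert((x && true) == x);       // LogicalAnd(x, true)  = x
        assert((x && false) == false);  // LogicalAnd(x, false) = false
        assert((x || false) == x);      // LogicalOr(x, false)  = x
        assert((x || true) == true);    // LogicalOr(x, true)   = true
      }
      return 0;
    }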
@@ -2025,7 +2051,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       if (y_matches_output_shape &&
           ((is_mul && x_is_one) || (is_add && x_is_zero))) {
         // 1 * y = y or 0 + y = y.
-        ReplaceOperationWithSnapshot(1, node, optimized_graph);
+        ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
         continue;
       }
 
@@ -2052,10 +2078,18 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                      ((is_add || is_sub) && y_is_zero))) {
         // x * 1 = x or x / 1 = x or x +/- 0 = x
-        ReplaceOperationWithSnapshot(0, node, optimized_graph);
+        ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
         continue;
       }
 
+      // x OR true = true OR y = true.
+      const PartialTensorShape shp(output_shape);
+      if (shp.IsFullyDefined() && IsLogicalOr(*node) &&
+          (y_is_one || x_is_one)) {
+        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+            1, *properties, output_shape, node, optimized_graph));
+      }
+
       // Simplify multiplication and matmul by zeros.
       // Also optimize zeros divided by a tensor, but only if we are in
       // aggressive mode, since we might get rid of divisions by zero.
@@ -2063,26 +2097,19 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           is_any_div && x_is_zero && is_aggressive;
       if ((x_is_zero || y_is_zero) &&
           (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
-        const PartialTensorShape shp(output_shape);
         if (shp.IsFullyDefined()) {
-          AttrValue dtype_attr;
-          if (node->op() == "SparseMatMul") {
-            dtype_attr.set_type(DT_FLOAT);
-          } else {
-            dtype_attr = node->attr().at("T");
-          }
           TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
-              0, dtype_attr, output_shape, node, optimized_graph));
+              0, *properties, output_shape, node, optimized_graph));
           continue;
         }
         // Even if an input shape is only partially known, we may know that it
         // matches the output shape and thus forward the corresponding zero
         // input.
         if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         } else if (is_mul && y_is_zero && y_matches_output_shape) {
-          ReplaceOperationWithIdentity(1, node, optimized_graph);
+          ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
           continue;
         }
       }
 
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -78,12 +78,15 @@ class ConstantFolding : public GraphOptimizer {
 
   bool IsOnes(const NodeDef& node) const;
   bool IsZeros(const NodeDef& node) const;
-  void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node,
-                                    GraphDef* graph);
-  void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node,
-                                    GraphDef* graph);
+  void ReplaceOperationWithIdentity(int input_to_forward,
+                                    const GraphProperties& properties,
+                                    NodeDef* node, GraphDef* graph);
+  void ReplaceOperationWithSnapshot(int input_to_forward,
+                                    const GraphProperties& properties,
+                                    NodeDef* node, GraphDef* graph);
   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
-  Status ReplaceOperationWithConstant(double value, const AttrValue& dtype_attr,
+  Status ReplaceOperationWithConstant(double value,
+                                      const GraphProperties& properties,
                                       const TensorShapeProto& shape,
                                       NodeDef* node, GraphDef* graph);
   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -47,18 +47,30 @@ class ConstantFoldingTest : public GrapplerTest {
     }
     Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
     Output ones = ops::Const(s.WithOpName("ones"), ones_t);
-    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
-    Output mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
-
+    Output mul1;
+    Output mul2;
+    Output add1;
+    Output add2;
+    if (DTYPE == DT_BOOL) {
+      mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
+      mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
+      add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
+      add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
+    } else {
+      mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+      mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+      add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+      add2 = ops::Add(s.WithOpName("add2"), x, ones);
+    }
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"mul1", "mul2"};
+    item.fetch = {"mul1", "mul2", "add1", "add2"};
     ConstantFolding optimizer(nullptr /* cpu_device */);
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
     LOG(INFO) << output.DebugString();
-    EXPECT_EQ(5, output.node_size());
+
+    EXPECT_EQ(7, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -70,14 +82,27 @@ class ConstantFoldingTest : public GrapplerTest {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ("^ones", node.input(1));
+      } else if (name == "add1") {
+        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("x", node.input(0));
+        EXPECT_EQ("^zeros", node.input(1));
+      } else if (name == "add2") {
+        if (DTYPE == DT_BOOL) {
+          EXPECT_EQ("Const", node.op());
+          EXPECT_EQ("^x", node.input(0));
+          EXPECT_EQ("^ones", node.input(1));
+        } else {
+          EXPECT_EQ("Add", node.op());
+          EXPECT_EQ("x", node.input(0));
+          EXPECT_EQ("ones", node.input(1));
+        }
       }
     }
-    auto tensors_expected =
-        EvaluateNodes(item.graph, {"mul1", "mul2"}, {{"x", x_t}});
-    auto tensors = EvaluateNodes(output, {"mul1", "mul2"}, {{"x", x_t}});
-    EXPECT_EQ(2, tensors_expected.size());
-    EXPECT_EQ(2, tensors.size());
-    for (int i = 0; i < 2; ++i) {
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+    auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+    EXPECT_EQ(4, tensors_expected.size());
+    EXPECT_EQ(4, tensors.size());
+    for (int i = 0; i < item.fetch.size(); ++i) {
       test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
     }
   }
@@ -393,6 +418,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
 }
 
 TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
+  SimpleNeutralElementTest<DT_BOOL>();
   SimpleNeutralElementTest<DT_HALF>();
   SimpleNeutralElementTest<DT_BFLOAT16>();
 }
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -25,6 +25,8 @@ import numpy as np
 
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -54,8 +56,8 @@ from tensorflow.python.training import coordinator
 from tensorflow.python.training import queue_runner_impl
 
 
-def _initialized_session():
-  sess = session.Session()
+def _initialized_session(config=None):
+  sess = session.Session(config=config)
   sess.run(variables_lib.global_variables_initializer())
   sess.run(lookup_ops.tables_initializer())
   return sess
@@ -6191,7 +6193,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
             'values': ((.5,), (1.,))
         }, (column,),
         sparse_combiner='mean')
-    with _initialized_session():
+    # Disabling the constant folding optimizer here since it changes the
+    # error message differently on CPU and GPU.
+    config = config_pb2.ConfigProto()
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+    with _initialized_session(config):
       with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
         predictions.eval()
 
@@ -6284,7 +6291,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
             'values': ((.5,), (1.,))
         }, (column,),
         sparse_combiner='mean')
-    with _initialized_session():
+    # Disabling the constant folding optimizer here since it changes the
+    # error message differently on CPU and GPU.
+    config = config_pb2.ConfigProto()
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+    with _initialized_session(config):
      with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
        predictions.eval()