diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 4801f18619e..47d88276863 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -866,6 +866,25 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
 }
 #undef SET_TENSOR_VAL_CASE
 
+DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
+                                    const GraphProperties& properties) {
+  DataType dtype = DT_INVALID;
+  if (node.attr().count("T") == 1) {
+    dtype = node.attr().at("T").type();
+  } else if (node.attr().count("dtype") == 1) {
+    dtype = node.attr().at("dtype").type();
+  } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
+    dtype = DT_BOOL;
+  } else {
+    auto output_props = properties.GetOutputProperties(node.name());
+    if (!output_props.empty()) {
+      dtype = output_props[0].dtype();
+    }
+  }
+  return dtype;
+}
+
 }  // namespace
 
 // static
@@ -1412,6 +1431,7 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
+    IS_ONES_CASE(DT_BOOL);
     IS_ONES_CASE(DT_HALF);
     IS_ONES_CASE(DT_BFLOAT16);
     IS_ONES_CASE(DT_FLOAT);
@@ -1447,6 +1467,7 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
+    IS_ZEROS_CASE(DT_BOOL);
     IS_ZEROS_CASE(DT_HALF);
     IS_ZEROS_CASE(DT_BFLOAT16);
     IS_ZEROS_CASE(DT_FLOAT);
@@ -1466,14 +1487,15 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   return false;
 }
 
-void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
-                                                   NodeDef* node,
-                                                   GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithIdentity(
+    int input_to_forward, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
+
   node->set_op("Identity");
-  DataType dtype = node->attr().at("T").type();
   node->clear_attr();
   (*node->mutable_attr())["T"].set_type(dtype);
-
   // Propagate the designated input through the identity.
   node->mutable_input()->SwapElements(0, input_to_forward);
   // Add all other inputs as control dependencies.
@@ -1489,14 +1511,15 @@ void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
   graph_modified_ = true;
 }
 
-void ConstantFolding::ReplaceOperationWithSnapshot(int input_to_forward,
-                                                   NodeDef* node,
-                                                   GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithSnapshot(
+    int input_to_forward, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
+
   node->set_op("Snapshot");
-  DataType dtype = node->attr().at("T").type();
   node->clear_attr();
   (*node->mutable_attr())["T"].set_type(dtype);
-
   // Propagate the designated input through the Snapshot.
   node->mutable_input()->SwapElements(0, input_to_forward);
   // Add all other inputs as control dependencies.
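
Reviewer note (not part of the patch): GetDataTypeFromNodeOrProps resolves the output dtype in precedence order -- the "T" attr, then "dtype", then DT_BOOL for the logical ops, and only then the inferred output properties. The DT_BOOL branch matters because LogicalAnd/LogicalOr are registered with fixed bool inputs and outputs and carry neither attr. A minimal standalone sketch of the same lookup order, minus the GraphProperties fallback (ResolveDType is a hypothetical name, not in the patch):

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/framework/types.pb.h"

    // Mirrors the attr-precedence logic of GetDataTypeFromNodeOrProps.
    tensorflow::DataType ResolveDType(const tensorflow::NodeDef& node) {
      if (node.attr().count("T") == 1) return node.attr().at("T").type();
      if (node.attr().count("dtype") == 1) return node.attr().at("dtype").type();
      // Logical ops define no type attr at all, hence the explicit branch.
      if (node.op() == "LogicalAnd" || node.op() == "LogicalOr")
        return tensorflow::DT_BOOL;
      return tensorflow::DT_INVALID;  // callers bail out and skip the rewrite
    }
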
@@ -1535,15 +1558,18 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
 }
 
 Status ConstantFolding::ReplaceOperationWithConstant(
-    double value, const AttrValue& dtype_attr, const TensorShapeProto& shape,
-    NodeDef* node, GraphDef* graph) {
+    double value, const GraphProperties& properties,
+    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return Status::OK();
+
   AttrValue tensor_attr;
-  TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value,
-                                                   shape, &tensor_attr));
-  node->clear_attr();
-  node->mutable_attr()->insert({"dtype", dtype_attr});
-  node->mutable_attr()->insert({"value", tensor_attr});
+  TF_RETURN_IF_ERROR(
+      CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr));
   node->set_op("Const");
+  node->clear_attr();
+  (*node->mutable_attr())["dtype"].set_type(dtype);
+  node->mutable_attr()->insert({"value", tensor_attr});
   // Convert all inputs to control dependencies.
   for (int i = 0; i < node->input_size(); ++i) {
     if (IsControlInput(node->input(i))) {
@@ -1566,12 +1592,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
     NodeDef* node = optimized_graph->mutable_node(i);
 
     if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
-      ReplaceOperationWithIdentity(1, node, optimized_graph);
+      ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
       continue;
     }
 
     if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
-      ReplaceOperationWithIdentity(0, node, optimized_graph);
+      ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
       continue;
     }
@@ -1611,7 +1637,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
         }
         if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         }
       }
@@ -1626,7 +1652,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
       if (!shape.unknown_rank() &&
          (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
-        ReplaceOperationWithIdentity(0, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
         continue;
       }
     }
@@ -1651,11 +1677,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
         for (int j = 0; j < axis.NumElements(); ++j) {
           // value of axis can be negative.
          if (axis.dtype() == DT_INT64) {
-            target_axes.insert(
-                (axis.vec<int64>()(j) + shape.dim_size()) % shape.dim_size());
+            target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
+                               shape.dim_size());
           } else {
-            target_axes.insert(
-                (axis.vec<int>()(j) + shape.dim_size()) % shape.dim_size());
+            target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
+                               shape.dim_size());
           }
         }
@@ -1669,7 +1695,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
               target_axes.find(j) == target_axes.end();
         }
         if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         }
       }
@@ -1711,7 +1737,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           }
         }
         if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         }
       }
@@ -1740,7 +1766,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           }
         }
         if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         }
       }
@@ -1764,7 +1790,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           replaceable &= flatten(j) == 0;
         }
         if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         }
       }
@@ -1784,7 +1810,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           replaceable &= shape.dim(j).size() > 1;
         }
         if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         }
       }
@@ -1996,9 +2022,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       continue;
     }
 
-    const bool is_mul = IsMul(*node);
+    const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
     const bool is_matmul = IsMatMul(*node);
-    const bool is_add = IsAdd(*node) || IsBiasAdd(*node);
+    const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
     const bool is_sub = IsSub(*node);
     const bool is_any_div = IsAnyDiv(*node);
     // Simplify arithmetic operations with ones or zeros.
@@ -2025,7 +2051,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       if (y_matches_output_shape &&
          ((is_mul && x_is_one) || (is_add && x_is_zero))) {
         // 1 * y = y or 0 + y = y.
-        ReplaceOperationWithSnapshot(1, node, optimized_graph);
+        ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
         continue;
       }
@@ -2052,10 +2078,18 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                      ((is_add || is_sub) && y_is_zero))) {
         // x * 1 = x or x / 1 = x or x +/- 0 = x
-        ReplaceOperationWithSnapshot(0, node, optimized_graph);
+        ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
         continue;
       }
 
+      // x OR true = true OR y = true.
+      const PartialTensorShape shp(output_shape);
+      if (shp.IsFullyDefined() && IsLogicalOr(*node) &&
+          (y_is_one || x_is_one)) {
+        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+            1, *properties, output_shape, node, optimized_graph));
+      }
+
       // Simplify multiplication and matmul by zeros.
       // Also optimize zeros divided by a tensor, but only if we are in
       // aggressive mode, since we might get rid of divisions by zero.
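
Reviewer note (illustration, not part of the patch): folding LogicalAnd into is_mul and LogicalOr into is_add works because bool forms a semiring in which AND plays the role of * (neutral element true, absorbing element false) and OR plays the role of + (neutral element false). The one case the arithmetic rules do not cover is the absorbing element of OR, which is what the new "x OR true" block above handles. A self-contained check of those identities:

    #include <cassert>

    // Boolean counterparts of the arithmetic simplifications above.
    int main() {
      for (bool x : {false, true}) {
        assert((x && true) == x);       // x * 1 = x  -> forward x (Snapshot)
        assert((x || false) == x);      // x + 0 = x  -> forward x (Snapshot)
        assert((x && false) == false);  // x * 0 = 0  -> all-false Const
        assert((x || true) == true);    // new rule   -> all-true Const
      }
      return 0;
    }
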
@@ -2063,26 +2097,19 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           is_any_div && x_is_zero && is_aggressive;
       if ((x_is_zero || y_is_zero) &&
          (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
-        const PartialTensorShape shp(output_shape);
         if (shp.IsFullyDefined()) {
-          AttrValue dtype_attr;
-          if (node->op() == "SparseMatMul") {
-            dtype_attr.set_type(DT_FLOAT);
-          } else {
-            dtype_attr = node->attr().at("T");
-          }
           TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
-              0, dtype_attr, output_shape, node, optimized_graph));
+              0, *properties, output_shape, node, optimized_graph));
           continue;
         }
         // Even if an input shape is only partially known, we may know that it
         // matches the output shape and thus forward the corresponding zero
         // input.
         if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
           continue;
         } else if (is_mul && y_is_zero && y_matches_output_shape) {
-          ReplaceOperationWithIdentity(1, node, optimized_graph);
+          ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
           continue;
         }
       }
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index eb06cd081f7..a694f1721ad 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -78,12 +78,15 @@ class ConstantFolding : public GraphOptimizer {
   bool IsOnes(const NodeDef& node) const;
   bool IsZeros(const NodeDef& node) const;
 
-  void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node,
-                                    GraphDef* graph);
-  void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node,
-                                    GraphDef* graph);
+  void ReplaceOperationWithIdentity(int input_to_forward,
+                                    const GraphProperties& properties,
+                                    NodeDef* node, GraphDef* graph);
+  void ReplaceOperationWithSnapshot(int input_to_forward,
+                                    const GraphProperties& properties,
+                                    NodeDef* node, GraphDef* graph);
   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
-  Status ReplaceOperationWithConstant(double value, const AttrValue& dtype_attr,
+  Status ReplaceOperationWithConstant(double value,
+                                      const GraphProperties& properties,
                                       const TensorShapeProto& shape,
                                       NodeDef* node, GraphDef* graph);
   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 306ddd22d73..f018b217e66 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -47,18 +47,30 @@ class ConstantFoldingTest : public GrapplerTest {
     }
     Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
     Output ones = ops::Const(s.WithOpName("ones"), ones_t);
-    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
-    Output mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
-
+    Output mul1;
+    Output mul2;
+    Output add1;
+    Output add2;
+    if (DTYPE == DT_BOOL) {
+      mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
+      mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
+      add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
+      add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
+    } else {
+      mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+      mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+      add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+      add2 =
+          ops::Add(s.WithOpName("add2"), x, ones);
+    }
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"mul1", "mul2"};
+    item.fetch = {"mul1", "mul2", "add1", "add2"};
     ConstantFolding optimizer(nullptr /* cpu_device */);
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
-    LOG(INFO) << output.DebugString();
-    EXPECT_EQ(5, output.node_size());
+
+    EXPECT_EQ(7, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -70,14 +82,27 @@ class ConstantFoldingTest : public GrapplerTest {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ("^ones", node.input(1));
+      } else if (name == "add1") {
+        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("x", node.input(0));
+        EXPECT_EQ("^zeros", node.input(1));
+      } else if (name == "add2") {
+        if (DTYPE == DT_BOOL) {
+          EXPECT_EQ("Const", node.op());
+          EXPECT_EQ("^x", node.input(0));
+          EXPECT_EQ("^ones", node.input(1));
+        } else {
+          EXPECT_EQ("Add", node.op());
+          EXPECT_EQ("x", node.input(0));
+          EXPECT_EQ("ones", node.input(1));
+        }
       }
     }
-    auto tensors_expected =
-        EvaluateNodes(item.graph, {"mul1", "mul2"}, {{"x", x_t}});
-    auto tensors = EvaluateNodes(output, {"mul1", "mul2"}, {{"x", x_t}});
-    EXPECT_EQ(2, tensors_expected.size());
-    EXPECT_EQ(2, tensors.size());
-    for (int i = 0; i < 2; ++i) {
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+    auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+    EXPECT_EQ(4, tensors_expected.size());
+    EXPECT_EQ(4, tensors.size());
+    for (int i = 0; i < item.fetch.size(); ++i) {
       test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
     }
   }
@@ -393,6 +418,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
 }
 
 TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
+  SimpleNeutralElementTest<DT_BOOL>();
   SimpleNeutralElementTest<DT_HALF>();
   SimpleNeutralElementTest<DT_BFLOAT16>();
 }
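
Reviewer note (illustration, not part of the patch): SimpleNeutralElementTest maps its DTYPE parameter to a C++ element type via `typedef typename EnumToDataType<DTYPE>::Type T;` earlier in the fixture, so T(0)/T(1) produce the zeros/ones tensors and the DT_BOOL instantiation exercises false/true. Assuming tensorflow/core/framework/types.h, the three instantiations should resolve as follows:

    #include <type_traits>
    #include "tensorflow/core/framework/types.h"

    // Element types behind the three SimpleNeutralElementTest instantiations.
    static_assert(std::is_same<tensorflow::EnumToDataType<tensorflow::DT_BOOL>::Type,
                               bool>::value, "");
    static_assert(std::is_same<tensorflow::EnumToDataType<tensorflow::DT_HALF>::Type,
                               Eigen::half>::value, "");
    static_assert(std::is_same<tensorflow::EnumToDataType<tensorflow::DT_BFLOAT16>::Type,
                               tensorflow::bfloat16>::value, "");
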
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index d963dd9b551..b06540489ff 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -25,6 +25,8 @@ import numpy as np
 
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -54,8 +56,8 @@ from tensorflow.python.training import coordinator
 from tensorflow.python.training import queue_runner_impl
 
 
-def _initialized_session():
-  sess = session.Session()
+def _initialized_session(config=None):
+  sess = session.Session(config=config)
   sess.run(variables_lib.global_variables_initializer())
   sess.run(lookup_ops.tables_initializer())
   return sess
@@ -6191,7 +6193,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
             'values': ((.5,), (1.,))
         }, (column,),
         sparse_combiner='mean')
-    with _initialized_session():
+    # Disable the constant folding optimizer here, since with it enabled the
+    # error message differs between CPU and GPU.
+    config = config_pb2.ConfigProto()
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+    with _initialized_session(config):
       with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
         predictions.eval()
@@ -6284,7 +6291,12 @@ class WeightedCategoricalColumnTest(test.TestCase):
             'values': ((.5,), (1.,))
         }, (column,),
         sparse_combiner='mean')
-    with _initialized_session():
+    # Disable the constant folding optimizer here, since with it enabled the
+    # error message differs between CPU and GPU.
+    config = config_pb2.ConfigProto()
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+    with _initialized_session(config):
      with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
        predictions.eval()
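
Reviewer note (illustration, not part of the patch): the Python tests above turn constant folding off through ConfigProto; the same switch is available from C++ if a Session-based test ever needs the identical workaround (MakeNoConstantFoldingConfig is a hypothetical helper name):

    #include "tensorflow/core/protobuf/config.pb.h"
    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    // Builds a session config with the Grappler constant folding pass disabled.
    tensorflow::ConfigProto MakeNoConstantFoldingConfig() {
      tensorflow::ConfigProto config;
      config.mutable_graph_options()
          ->mutable_rewrite_options()
          ->set_constant_folding(tensorflow::RewriterConfig::OFF);
      return config;
    }
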