From 6e05909d9877aea696d15f3ef08409bbeca53688 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 18 Feb 2021 09:25:51 -0800 Subject: [PATCH] Add optimization to replace broadcasting using Mul by all ones with Tile. PiperOrigin-RevId: 358195304 Change-Id: I786f1344b665e4733936d9da428ce3dd57409941 --- .../optimizers/arithmetic_optimizer.cc | 124 +++++++++++++ .../optimizers/arithmetic_optimizer.h | 2 + .../optimizers/arithmetic_optimizer_test.cc | 174 ++++++++++++++++++ .../arithmetic_optimizer_test_utils.h | 6 + 4 files changed, 306 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index f0a15cbce49..9e16b3e2101 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/tensor_coding.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/strided_slice_op.h" @@ -2458,6 +2459,127 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage { } }; +// Replace a combination of Mul with broadcasting by Tile. E.g. 
replace +// +// input(1x22x1x48x1x64) -> Mul (1x22x2x48x2x64) -> output +// Ones (1x22x2x48x2x64) -^ +// +// with +// +// input -> Tile(1x22x2x48x2x64) -> output +class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage { + public: + explicit ReplaceMulWithBroadcastByTile( + const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ReplaceMulWithBroadcastByTile", ctx, + ctx_ext) {} + ~ReplaceMulWithBroadcastByTile() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMul(*node) && !IsInPreserveSet(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef *input, *ones; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones)); + if (IsInPreserveSet(*node) || IsInPreserveSet(*input) || + IsInPreserveSet(*ones)) { + return Status::OK(); + } + + // TODO(kkiningh): Generalize using IsOnes from constant_folding.cc + if (IsConstant(*input) || !IsOnes(*ones)) return Status::OK(); + + // Avoid optimizing the same node twice + const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name()); + const string tile_node_name = OptimizedNodeName(scope_and_name, "Tile"); + const string const_node_name = OptimizedNodeName(scope_and_name, "Const"); + if (ctx().node_map->NodeExists(tile_node_name) || + ctx().node_map->NodeExists(const_node_name)) { + return Status::OK(); + } + + const std::vector<OpInfo::TensorProperties>& props = + ctx().graph_properties->GetInputProperties(node->name()); + if (props.size() != 2) return Status::OK(); + + // Ignore ops where the shape doesn't change + const TensorShapeProto& input_shape = props[0].shape(); + const TensorShapeProto& ones_shape = props[1].shape(); + TensorShapeProto output_shape; + if (!ShapeAfterBroadcast(input_shape, ones_shape, &output_shape)) { + return Status::OK(); + } + if (ShapesSymbolicallyEqual(input_shape, output_shape)) { + return
Status::OK(); + } + + // All inputs must have same input/output dimensions + if (input_shape.dim_size() != output_shape.dim_size() || + ones_shape.dim_size() != output_shape.dim_size()) + return Status::OK(); + + // At this point all preconditions are met. Can proceed with rewrite. + VLOG(3) << "Simplify multiply with all ones input: node=" << node->name() + << "@" << output_shape << " ones=" << ones->name() << "@" + << ones_shape << " input=" << input->name() << "@" << input_shape; + + // 1. Create constant node with correct tile multiples + Tensor multiples(DT_INT32, TensorShape({output_shape.dim_size()})); + for (int i = 0; i < output_shape.dim_size(); ++i) { + int64 size = output_shape.dim(i).size() / input_shape.dim(i).size(); + if (TF_PREDICT_FALSE(size >= INT_MAX)) { + return Status(error::OUT_OF_RANGE, "int32 overflow"); + } + multiples.flat<int32>()(i) = static_cast<int32>(size); + } + + NodeDef* const_node = AddEmptyNode(const_node_name); + TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef( + const_node->name(), TensorValue(&multiples), const_node)); + const_node->set_device(node->device()); + ForwardControlDependencies(const_node, {ones}); + AddToOptimizationQueue(const_node); + + // 2.
Replace multiply node with Tile(Const, input); + const DataType type = GetDataTypeFromAttr(*node, "T"); + NodeDef* tile_node = AddEmptyNode(tile_node_name); + tile_node->set_op("Tile"); + tile_node->set_device(node->device()); + SetDataTypeToAttr(type, "T", tile_node); + SetDataTypeToAttr(DT_INT32, "Tmultiples", tile_node); + tile_node->add_input(input->name()); + tile_node->add_input(const_node->name()); + + ForwardControlDependencies(tile_node, {node}); + *simplified_node_name = tile_node->name(); + + return Status::OK(); + } + + protected: + bool IsOnes(const NodeDef& node) const { + if (!IsReallyConstant(node)) return false; + if (node.attr().at("dtype").type() != DT_FLOAT) return false; + + Tensor tensor; + if (!tensor.FromProto(node.attr().at("value").tensor())) { + return false; + } + + auto values = tensor.flat<float>(); + for (int i = 0; i < tensor.NumElements(); ++i) { + if (values(i) != 1.0f) { + return false; + } + } + + return true; + } +}; + // Simplify aggregation (e.g. AddN) nodes: // // 1. Discard aggregate nodes with a single input and no control dependencies. @@ -3704,6 +3826,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<RemoveRedundantReshapeOrBroadcastTo>(ctx, ctx_ext); if (options_.replace_mul_with_square) pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext); + if (options_.replace_mul_with_tile) + pipeline.AddStage<ReplaceMulWithBroadcastByTile>(ctx, ctx_ext); if (options_.remove_logical_not) pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext); if (options_.reorder_cast_like_and_value_preserving) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 044dc855244..d9f03ef38c1 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_ #include <unordered_set> + #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/utils.h" @@ -77,6 +78,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_cast = true; bool remove_redundant_reshape = true; bool reorder_cast_like_and_value_preserving = true; + bool replace_mul_with_tile = true; bool replace_mul_with_square = true; bool simplify_aggregation = true; bool convert_pow = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index ec40ade8248..154196c954f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -106,6 +106,180 @@ TEST_F(ArithmeticOptimizerTest, NoOp) { VerifyGraphsMatch(item.graph, output, __LINE__); } +TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTile) { + // Graph from b/176172427 + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape({1, 44, 1, 96, 1, 64})); + Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 1, 2, 1, 2, 1}); + Output multiply = ops::Mul(s.WithOpName("mul"), input, ones); + Output output = ops::Identity(s.WithOpName("output"), multiply); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensor = + GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 44, 1, 96, 1, 64})); + auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}}); + ASSERT_EQ(expected.size(), 1); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplaceMulWithBroadcastByTile(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &g); + EXPECT_EQ(g.node_size(), 4); + + ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
+ ASSERT_EQ(CountOpNodes(g, "Tile"), 1); + + NodeMap node_map(&g); + const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile"; + const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_mul")); + const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul")); + ASSERT_NE(t, nullptr); + ASSERT_NE(c, nullptr); + EXPECT_EQ(t->op(), "Tile"); + ASSERT_EQ(t->input_size(), 2); + EXPECT_EQ(t->input(0), "input"); + EXPECT_EQ(t->input(1), c->name()); + EXPECT_EQ(t->attr().at("T").type(), DT_FLOAT); + EXPECT_EQ(t->attr().at("Tmultiples").type(), c->attr().at("dtype").type()); + + auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}}); + ASSERT_EQ(result.size(), 1); + test::ExpectTensorNear<float>(result[0], expected[0], 1e-6); +} + +TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTilePreserveControl) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape({1, 1, 1})); + Output ones = ops::Const(s.WithOpName("ones").WithControlDependencies(input), + 1.0f, {1, 2, 1}); + Output multiply = ops::Mul(s.WithOpName("mul"), input, ones); + Output output = ops::Identity(s.WithOpName("output"), multiply); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1})); + auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}}); + ASSERT_EQ(expected.size(), 1); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplaceMulWithBroadcastByTile(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &g); + EXPECT_EQ(g.node_size(), 4); + + VLOG(0) << g.node_size(); + for (auto&& node : g.node()) { + VLOG(0) << node.name(); + } + + ASSERT_EQ(CountOpNodes(g, "Mul"), 0); + ASSERT_EQ(CountOpNodes(g, "Tile"), 1); + + NodeMap node_map(&g); + const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile"; + const NodeDef* c =
node_map.GetNode(absl::StrCat(p, "_", "Const_mul")); + ASSERT_NE(c, nullptr); + ASSERT_EQ(c->input_size(), 1); + EXPECT_TRUE(IsControlInput(c->input(0))); + EXPECT_EQ(c->input(0), "^input"); +} + +TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNoBroadcast) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 2, 1})); + Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 2, 1}); + Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones); + Output output = ops::Identity(s.WithOpName("output"), multiply); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1})); + auto expected = + EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}}); + ASSERT_EQ(expected.size(), 1); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplaceMulWithBroadcastByTile(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &g); + EXPECT_EQ(g.node_size(), 4); + + VerifyGraphsMatch(item.graph, g, __LINE__); + + auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}}); + ASSERT_EQ(result.size(), 1); + test::ExpectTensorNear<float>(result[0], expected[0], 1e-6); +} + +TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotConst) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input1 = ops::Placeholder(s.WithOpName("input1"), DT_FLOAT, + ops::Placeholder::Shape({1, 1, 1})); + Output input2 = ops::Placeholder(s.WithOpName("input2"), DT_FLOAT, + ops::Placeholder::Shape({1, 2, 1})); + Output multiply = ops::Mul(s.WithOpName("multiply"), input1, input2); + Output output = ops::Identity(s.WithOpName("output"), multiply); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensor1 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1})); + auto tensor2 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1})); +
auto expected = EvaluateNodes(item.graph, item.fetch, + {{"input1", tensor1}, {"input2", tensor2}}); + ASSERT_EQ(expected.size(), 1); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplaceMulWithBroadcastByTile(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &g); + EXPECT_EQ(g.node_size(), 4); + + VerifyGraphsMatch(item.graph, g, __LINE__); + + auto result = EvaluateNodes(item.graph, item.fetch, + {{"input1", tensor1}, {"input2", tensor2}}); + ASSERT_EQ(result.size(), 1); + test::ExpectTensorNear<float>(result[0], expected[0], 1e-6); +} + +TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotOnes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 1, 1})); + Output ones = ops::Const(s.WithOpName("ones"), 2.0f, {1, 2, 1}); + Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones); + Output output = ops::Identity(s.WithOpName("output"), multiply); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1})); + auto expected = + EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}}); + ASSERT_EQ(expected.size(), 1); + + GraphDef g; + ArithmeticOptimizer optimizer; + EnableOnlyReplaceMulWithBroadcastByTile(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &g); + EXPECT_EQ(g.node_size(), 4); + + VerifyGraphsMatch(item.graph, g, __LINE__); + + auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}}); + ASSERT_EQ(result.size(), 1); + test::ExpectTensorNear<float>(result[0], expected[0], 1e-6); +} + TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h index
9025635e668..71c7ef564ae 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h @@ -153,6 +153,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.reorder_cast_like_and_value_preserving = true; } + void EnableOnlyReplaceMulWithBroadcastByTile(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.replace_mul_with_tile = true; + } + void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.replace_mul_with_square = true; @@ -258,6 +263,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_negation = false; options.remove_logical_not = false; options.reorder_cast_like_and_value_preserving = false; + options.replace_mul_with_tile = false; options.replace_mul_with_square = false; options.simplify_aggregation = false; options.unary_ops_composition = false;