Add an optimization that replaces broadcasting via Mul with an all-ones constant by a single Tile op.

PiperOrigin-RevId: 358195304
Change-Id: I786f1344b665e4733936d9da428ce3dd57409941
This commit is contained in:
A. Unique TensorFlower 2021-02-18 09:25:51 -08:00 committed by TensorFlower Gardener
parent c92320d7a4
commit 6e05909d98
4 changed files with 306 additions and 0 deletions

View File

@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/strided_slice_op.h"
@ -2458,6 +2459,127 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
}
};
// Replace a combination of Mul with broadcasting by Tile. E.g. replace
//
//   input(1x22x1x48x1x64) -> Mul (1x22x2x48x2x64) -> output
//   Ones (1x22x2x48x2x64) -^
//
// with
//
//   input -> Tile(1x22x2x48x2x64) -> output
class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage {
 public:
  explicit ReplaceMulWithBroadcastByTile(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplaceMulWithBroadcastByTile", ctx,
                                 ctx_ext) {}
  ~ReplaceMulWithBroadcastByTile() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsMul(*node) && !IsInPreserveSet(*node);
  }

  // Rewrites Mul(input, ones) into Tile(input, multiples) when `ones` is an
  // all-ones float constant whose only effect is broadcasting `input` to a
  // larger shape. Returns OK without modification when preconditions fail.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef *input, *ones;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
    if (IsInPreserveSet(*node) || IsInPreserveSet(*input) ||
        IsInPreserveSet(*ones)) {
      return Status::OK();
    }

    // Only the Mul(input, ones) operand order is rewritten; Mul(ones, input)
    // is left untouched.
    // TODO(kkiningh): Generalize using IsOnes from constant_folding.cc
    if (IsConstant(*input) || !IsOnes(*ones)) return Status::OK();

    // Avoid optimizing the same node twice.
    const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name());
    const string tile_node_name = OptimizedNodeName(scope_and_name, "Tile");
    const string const_node_name = OptimizedNodeName(scope_and_name, "Const");
    if (ctx().node_map->NodeExists(tile_node_name) ||
        ctx().node_map->NodeExists(const_node_name)) {
      return Status::OK();
    }

    const std::vector<OpInfo::TensorProperties>& props =
        ctx().graph_properties->GetInputProperties(node->name());
    if (props.size() != 2) return Status::OK();

    // Ignore ops where the shape doesn't change.
    const TensorShapeProto& input_shape = props[0].shape();
    const TensorShapeProto& ones_shape = props[1].shape();
    TensorShapeProto output_shape;
    if (!ShapeAfterBroadcast(input_shape, ones_shape, &output_shape)) {
      return Status::OK();
    }
    if (ShapesSymbolicallyEqual(input_shape, output_shape)) {
      return Status::OK();
    }

    // All inputs must have same input/output dimensions.
    if (input_shape.dim_size() != output_shape.dim_size() ||
        ones_shape.dim_size() != output_shape.dim_size())
      return Status::OK();

    // 1. Compute the tile multiples: output_dim / input_dim per axis. Bail
    // out on unknown (-1) or zero-sized dimensions — the quotient would be
    // meaningless or a division by zero — and on non-divisible dimensions.
    Tensor multiples(DT_INT32, TensorShape({output_shape.dim_size()}));
    for (int i = 0; i < output_shape.dim_size(); ++i) {
      const int64 input_dim = input_shape.dim(i).size();
      const int64 output_dim = output_shape.dim(i).size();
      if (input_dim <= 0 || output_dim <= 0 || output_dim % input_dim != 0) {
        return Status::OK();
      }
      const int64 size = output_dim / input_dim;
      if (TF_PREDICT_FALSE(size >= INT_MAX)) {
        return Status(error::OUT_OF_RANGE, "int32 overflow");
      }
      multiples.flat<int32>()(i) = static_cast<int32>(size);
    }

    // At this point all preconditions are met. Can proceed with rewrite.
    VLOG(3) << "Simplify multiply with all ones input: node=" << node->name()
            << "@" << output_shape << " ones=" << ones->name() << "@"
            << ones_shape << " input=" << input->name() << "@" << input_shape;

    NodeDef* const_node = AddEmptyNode(const_node_name);
    TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
        const_node->name(), TensorValue(&multiples), const_node));
    const_node->set_device(node->device());
    // Keep any control dependencies that were attached to the ones input.
    ForwardControlDependencies(const_node, {ones});
    AddToOptimizationQueue(const_node);

    // 2. Replace multiply node with Tile(input, Const).
    const DataType type = GetDataTypeFromAttr(*node, "T");
    NodeDef* tile_node = AddEmptyNode(tile_node_name);
    tile_node->set_op("Tile");
    tile_node->set_device(node->device());
    SetDataTypeToAttr(type, "T", tile_node);
    SetDataTypeToAttr(DT_INT32, "Tmultiples", tile_node);
    tile_node->add_input(input->name());
    tile_node->add_input(const_node->name());
    ForwardControlDependencies(tile_node, {node});

    *simplified_node_name = tile_node->name();
    return Status::OK();
  }

 protected:
  // Returns true iff `node` is a constant float tensor whose elements are all
  // exactly 1.0f. Non-float dtypes are conservatively rejected.
  bool IsOnes(const NodeDef& node) const {
    if (!IsReallyConstant(node)) return false;
    if (node.attr().at("dtype").type() != DT_FLOAT) return false;

    Tensor tensor;
    if (!tensor.FromProto(node.attr().at("value").tensor())) {
      return false;
    }

    auto values = tensor.flat<float>();
    for (int i = 0; i < tensor.NumElements(); ++i) {
      if (values(i) != 1.0f) {
        return false;
      }
    }
    return true;
  }
};
// Simplify aggregation (e.g. AddN) nodes:
//
// 1. Discard aggregate nodes with a single input and no control dependencies.
@ -3704,6 +3826,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
if (options_.replace_mul_with_square)
pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
if (options_.replace_mul_with_tile)
pipeline.AddStage<ReplaceMulWithBroadcastByTile>(ctx, ctx_ext);
if (options_.remove_logical_not)
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
if (options_.reorder_cast_like_and_value_preserving)

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_
#include <unordered_set>
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
@ -77,6 +78,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool remove_redundant_cast = true;
bool remove_redundant_reshape = true;
bool reorder_cast_like_and_value_preserving = true;
bool replace_mul_with_tile = true;
bool replace_mul_with_square = true;
bool simplify_aggregation = true;
bool convert_pow = true;

View File

@ -106,6 +106,180 @@ TEST_F(ArithmeticOptimizerTest, NoOp) {
VerifyGraphsMatch(item.graph, output, __LINE__);
}
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTile) {
  // Graph from b/176172427
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(scope.WithOpName("input"), DT_FLOAT,
                       ops::Placeholder::Shape({1, 44, 1, 96, 1, 64}));
  Output ones = ops::Const(scope.WithOpName("ones"), 1.0f, {1, 1, 2, 1, 2, 1});
  Output multiply = ops::Mul(scope.WithOpName("mul"), input, ones);
  Output output = ops::Identity(scope.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto in_tensor =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 44, 1, 96, 1, 64}));
  auto baseline = EvaluateNodes(item.graph, item.fetch, {{"input", in_tensor}});
  ASSERT_EQ(baseline.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  // The Mul and the all-ones constant must be replaced by a single Tile fed
  // by a new multiples constant.
  EXPECT_EQ(optimized.node_size(), 4);
  ASSERT_EQ(CountOpNodes(optimized, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(optimized, "Tile"), 1);
  NodeMap node_map(&optimized);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* tile_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_mul"));
  const NodeDef* const_node =
      node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(tile_node, nullptr);
  ASSERT_NE(const_node, nullptr);
  EXPECT_EQ(tile_node->op(), "Tile");
  ASSERT_EQ(tile_node->input_size(), 2);
  EXPECT_EQ(tile_node->input(0), "input");
  EXPECT_EQ(tile_node->input(1), const_node->name());
  EXPECT_EQ(tile_node->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(tile_node->attr().at("Tmultiples").type(),
            const_node->attr().at("dtype").type());

  // The rewritten graph must compute the same values as the original one.
  auto result = EvaluateNodes(optimized, item.fetch, {{"input", in_tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], baseline[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTilePreserveControl) {
  // The control dependency attached to the all-ones constant must be
  // forwarded to the multiples constant created by the rewrite.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                  ops::Placeholder::Shape({1, 1, 1}));
  Output ones = ops::Const(s.WithOpName("ones").WithControlDependencies(input),
                           1.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 4);
  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(c, nullptr);
  // The new constant's only input is the control edge inherited from "ones".
  ASSERT_EQ(c->input_size(), 1);
  EXPECT_TRUE(IsControlInput(c->input(0)));
  EXPECT_EQ(c->input(0), "^input");
}
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNoBroadcast) {
  // Shapes already match, so no broadcasting occurs and the Mul is kept.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({1, 2, 1}));
  Output ones = ops::Const(scope.WithOpName("ones"), 1.0f, {1, 2, 1});
  Output product = ops::Mul(scope.WithOpName("multiply"), input, ones);
  Output output = ops::Identity(scope.WithOpName("output"), product);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto value = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto want = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", value}});
  ASSERT_EQ(want.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  // The graph must be unchanged and still produce the same values.
  EXPECT_EQ(optimized.node_size(), 4);
  VerifyGraphsMatch(item.graph, optimized, __LINE__);
  auto got = EvaluateNodes(optimized, item.fetch, {{"Placeholder", value}});
  ASSERT_EQ(got.size(), 1);
  test::ExpectTensorNear<float>(got[0], want[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotConst) {
  // Neither Mul operand is a constant, so the rewrite must not fire.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input1 = ops::Placeholder(s.WithOpName("input1"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 1, 1}));
  Output input2 = ops::Placeholder(s.WithOpName("input2"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 2, 1}));
  Output multiply = ops::Mul(s.WithOpName("multiply"), input1, input2);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor1 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto tensor2 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch,
                                {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 4);
  VerifyGraphsMatch(item.graph, g, __LINE__);
  // Evaluate the optimized graph (not item.graph, which would trivially
  // match `expected`) to confirm numerics are unchanged.
  auto result = EvaluateNodes(g, item.fetch,
                              {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotOnes) {
  // The constant is all 2.0f, not all ones, so the rewrite must not fire.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({1, 1, 1}));
  Output not_ones = ops::Const(scope.WithOpName("ones"), 2.0f, {1, 2, 1});
  Output product = ops::Mul(scope.WithOpName("multiply"), input, not_ones);
  Output output = ops::Identity(scope.WithOpName("output"), product);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto value = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto want = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", value}});
  ASSERT_EQ(want.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  // The graph must be unchanged and still produce the same values.
  EXPECT_EQ(optimized.node_size(), 4);
  VerifyGraphsMatch(item.graph, optimized, __LINE__);
  auto got = EvaluateNodes(optimized, item.fetch, {{"Placeholder", value}});
  ASSERT_EQ(got.size(), 1);
  test::ExpectTensorNear<float>(got[0], want[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});

View File

@ -153,6 +153,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.reorder_cast_like_and_value_preserving = true;
}
// Disables every optimizer stage, then enables only the
// ReplaceMulWithBroadcastByTile stage so tests can exercise it in isolation.
void EnableOnlyReplaceMulWithBroadcastByTile(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.replace_mul_with_tile = true;
}
void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.replace_mul_with_square = true;
@ -258,6 +263,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.remove_negation = false;
options.remove_logical_not = false;
options.reorder_cast_like_and_value_preserving = false;
options.replace_mul_with_tile = false;
options.replace_mul_with_square = false;
options.simplify_aggregation = false;
options.unary_ops_composition = false;