Add optimization to replace broadcasting using Mul by all ones with Tile.
PiperOrigin-RevId: 358195304 Change-Id: I786f1344b665e4733936d9da428ce3dd57409941
This commit is contained in:
parent
c92320d7a4
commit
6e05909d98
@ -51,6 +51,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/tensor_coding.h"
|
#include "tensorflow/core/platform/tensor_coding.h"
|
||||||
|
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
#include "tensorflow/core/util/saved_tensor_slice_util.h"
|
#include "tensorflow/core/util/saved_tensor_slice_util.h"
|
||||||
#include "tensorflow/core/util/strided_slice_op.h"
|
#include "tensorflow/core/util/strided_slice_op.h"
|
||||||
@ -2458,6 +2459,127 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Replace a combination of Mul with broadcasting by Tile. E.g. replace
|
||||||
|
//
|
||||||
|
// input(1x22x1x48x1x64) -> Mul (1x22x2x48x2x64) -> output
|
||||||
|
// Ones (1x22x2x48x2x64) -^
|
||||||
|
//
|
||||||
|
// with
|
||||||
|
//
|
||||||
|
// input -> Tile(1x22x2x48x2x64) -> output
|
||||||
|
class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage {
|
||||||
|
public:
|
||||||
|
explicit ReplaceMulWithBroadcastByTile(
|
||||||
|
const GraphOptimizerContext& ctx,
|
||||||
|
const ArithmeticOptimizerContext& ctx_ext)
|
||||||
|
: ArithmeticOptimizerStage("ReplaceMulWithBroadcastByTile", ctx,
|
||||||
|
ctx_ext) {}
|
||||||
|
~ReplaceMulWithBroadcastByTile() override = default;
|
||||||
|
|
||||||
|
bool IsSupported(const NodeDef* node) const override {
|
||||||
|
return IsMul(*node) && !IsInPreserveSet(*node);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
|
||||||
|
NodeDef *input, *ones;
|
||||||
|
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
|
||||||
|
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
|
||||||
|
if (IsInPreserveSet(*node) || IsInPreserveSet(*input) ||
|
||||||
|
IsInPreserveSet(*ones)) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kkiningh): Generalize using IsOnes from constant_folding.cc
|
||||||
|
if (IsConstant(*input) || !IsOnes(*ones)) return Status::OK();
|
||||||
|
|
||||||
|
// Avoid optimizing the same node twice
|
||||||
|
const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name());
|
||||||
|
const string tile_node_name = OptimizedNodeName(scope_and_name, "Tile");
|
||||||
|
const string const_node_name = OptimizedNodeName(scope_and_name, "Const");
|
||||||
|
if (ctx().node_map->NodeExists(tile_node_name) ||
|
||||||
|
ctx().node_map->NodeExists(const_node_name)) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<OpInfo::TensorProperties>& props =
|
||||||
|
ctx().graph_properties->GetInputProperties(node->name());
|
||||||
|
if (props.size() != 2) return Status::OK();
|
||||||
|
|
||||||
|
// Ignore ops where the shape doesn't change
|
||||||
|
const TensorShapeProto& input_shape = props[0].shape();
|
||||||
|
const TensorShapeProto& ones_shape = props[1].shape();
|
||||||
|
TensorShapeProto output_shape;
|
||||||
|
if (!ShapeAfterBroadcast(input_shape, ones_shape, &output_shape)) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
if (ShapesSymbolicallyEqual(input_shape, output_shape)) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// All inputs must have same input/output dimensions
|
||||||
|
if (input_shape.dim_size() != output_shape.dim_size() ||
|
||||||
|
ones_shape.dim_size() != output_shape.dim_size())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
// At this point all preconditions are met. Can proceed with rewrite.
|
||||||
|
VLOG(3) << "Simplify multiply with all ones input: node=" << node->name()
|
||||||
|
<< "@" << output_shape << " ones=" << ones->name() << "@"
|
||||||
|
<< ones_shape << " input=" << input->name() << "@" << input_shape;
|
||||||
|
|
||||||
|
// 1. Create constant node with correct tile multiples
|
||||||
|
Tensor multiples(DT_INT32, TensorShape({output_shape.dim_size()}));
|
||||||
|
for (int i = 0; i < output_shape.dim_size(); ++i) {
|
||||||
|
int64 size = output_shape.dim(i).size() / input_shape.dim(i).size();
|
||||||
|
if (TF_PREDICT_FALSE(size >= INT_MAX)) {
|
||||||
|
return Status(error::OUT_OF_RANGE, "int32 overflow");
|
||||||
|
}
|
||||||
|
multiples.flat<int32>()(i) = static_cast<int32>(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
NodeDef* const_node = AddEmptyNode(const_node_name);
|
||||||
|
TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
|
||||||
|
const_node->name(), TensorValue(&multiples), const_node));
|
||||||
|
const_node->set_device(node->device());
|
||||||
|
ForwardControlDependencies(const_node, {ones});
|
||||||
|
AddToOptimizationQueue(const_node);
|
||||||
|
|
||||||
|
// 2. Replace multiply node with Tile(Const, input);
|
||||||
|
const DataType type = GetDataTypeFromAttr(*node, "T");
|
||||||
|
NodeDef* tile_node = AddEmptyNode(tile_node_name);
|
||||||
|
tile_node->set_op("Tile");
|
||||||
|
tile_node->set_device(node->device());
|
||||||
|
SetDataTypeToAttr(type, "T", tile_node);
|
||||||
|
SetDataTypeToAttr(DT_INT32, "Tmultiples", tile_node);
|
||||||
|
tile_node->add_input(input->name());
|
||||||
|
tile_node->add_input(const_node->name());
|
||||||
|
|
||||||
|
ForwardControlDependencies(tile_node, {node});
|
||||||
|
*simplified_node_name = tile_node->name();
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool IsOnes(const NodeDef& node) const {
|
||||||
|
if (!IsReallyConstant(node)) return false;
|
||||||
|
if (node.attr().at("dtype").type() != DT_FLOAT) return false;
|
||||||
|
|
||||||
|
Tensor tensor;
|
||||||
|
if (!tensor.FromProto(node.attr().at("value").tensor())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto values = tensor.flat<float>();
|
||||||
|
for (int i = 0; i < tensor.NumElements(); ++i) {
|
||||||
|
if (values(i) != 1.0f) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Simplify aggregation (e.g. AddN) nodes:
|
// Simplify aggregation (e.g. AddN) nodes:
|
||||||
//
|
//
|
||||||
// 1. Discard aggregate nodes with a single input and no control dependencies.
|
// 1. Discard aggregate nodes with a single input and no control dependencies.
|
||||||
@ -3704,6 +3826,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
|||||||
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
|
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
|
||||||
if (options_.replace_mul_with_square)
|
if (options_.replace_mul_with_square)
|
||||||
pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
|
pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
|
||||||
|
if (options_.replace_mul_with_tile)
|
||||||
|
pipeline.AddStage<ReplaceMulWithBroadcastByTile>(ctx, ctx_ext);
|
||||||
if (options_.remove_logical_not)
|
if (options_.remove_logical_not)
|
||||||
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
|
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
|
||||||
if (options_.reorder_cast_like_and_value_preserving)
|
if (options_.reorder_cast_like_and_value_preserving)
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_
|
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_
|
||||||
|
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
@ -77,6 +78,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
|
|||||||
bool remove_redundant_cast = true;
|
bool remove_redundant_cast = true;
|
||||||
bool remove_redundant_reshape = true;
|
bool remove_redundant_reshape = true;
|
||||||
bool reorder_cast_like_and_value_preserving = true;
|
bool reorder_cast_like_and_value_preserving = true;
|
||||||
|
bool replace_mul_with_tile = true;
|
||||||
bool replace_mul_with_square = true;
|
bool replace_mul_with_square = true;
|
||||||
bool simplify_aggregation = true;
|
bool simplify_aggregation = true;
|
||||||
bool convert_pow = true;
|
bool convert_pow = true;
|
||||||
|
@ -106,6 +106,180 @@ TEST_F(ArithmeticOptimizerTest, NoOp) {
|
|||||||
VerifyGraphsMatch(item.graph, output, __LINE__);
|
VerifyGraphsMatch(item.graph, output, __LINE__);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTile) {
|
||||||
|
// Graph from b/176172427
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output input =
|
||||||
|
ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape({1, 44, 1, 96, 1, 64}));
|
||||||
|
Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 1, 2, 1, 2, 1});
|
||||||
|
Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
|
||||||
|
Output output = ops::Identity(s.WithOpName("output"), multiply);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
item.fetch = {"output"};
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
auto tensor =
|
||||||
|
GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 44, 1, 96, 1, 64}));
|
||||||
|
auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
|
||||||
|
ASSERT_EQ(expected.size(), 1);
|
||||||
|
|
||||||
|
GraphDef g;
|
||||||
|
ArithmeticOptimizer optimizer;
|
||||||
|
EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
|
||||||
|
OptimizeTwiceAndPrune(&optimizer, &item, &g);
|
||||||
|
EXPECT_EQ(g.node_size(), 4);
|
||||||
|
|
||||||
|
ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
|
||||||
|
ASSERT_EQ(CountOpNodes(g, "Tile"), 1);
|
||||||
|
|
||||||
|
NodeMap node_map(&g);
|
||||||
|
const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
|
||||||
|
const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_mul"));
|
||||||
|
const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
|
||||||
|
ASSERT_NE(t, nullptr);
|
||||||
|
ASSERT_NE(c, nullptr);
|
||||||
|
EXPECT_EQ(t->op(), "Tile");
|
||||||
|
ASSERT_EQ(t->input_size(), 2);
|
||||||
|
EXPECT_EQ(t->input(0), "input");
|
||||||
|
EXPECT_EQ(t->input(1), c->name());
|
||||||
|
EXPECT_EQ(t->attr().at("T").type(), DT_FLOAT);
|
||||||
|
EXPECT_EQ(t->attr().at("Tmultiples").type(), c->attr().at("dtype").type());
|
||||||
|
|
||||||
|
auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
|
||||||
|
ASSERT_EQ(result.size(), 1);
|
||||||
|
test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTilePreserveControl) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape({1, 1, 1}));
|
||||||
|
Output ones = ops::Const(s.WithOpName("ones").WithControlDependencies(input),
|
||||||
|
1.0f, {1, 2, 1});
|
||||||
|
Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
|
||||||
|
Output output = ops::Identity(s.WithOpName("output"), multiply);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
item.fetch = {"output"};
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
|
||||||
|
auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
|
||||||
|
ASSERT_EQ(expected.size(), 1);
|
||||||
|
|
||||||
|
GraphDef g;
|
||||||
|
ArithmeticOptimizer optimizer;
|
||||||
|
EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
|
||||||
|
OptimizeTwiceAndPrune(&optimizer, &item, &g);
|
||||||
|
EXPECT_EQ(g.node_size(), 4);
|
||||||
|
|
||||||
|
VLOG(0) << g.node_size();
|
||||||
|
for (auto&& node : g.node()) {
|
||||||
|
VLOG(0) << node.name();
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
|
||||||
|
ASSERT_EQ(CountOpNodes(g, "Tile"), 1);
|
||||||
|
|
||||||
|
NodeMap node_map(&g);
|
||||||
|
const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
|
||||||
|
const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
|
||||||
|
ASSERT_NE(c, nullptr);
|
||||||
|
ASSERT_EQ(c->input_size(), 1);
|
||||||
|
EXPECT_TRUE(IsControlInput(c->input(0)));
|
||||||
|
EXPECT_EQ(c->input(0), "^input");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNoBroadcast) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output input =
|
||||||
|
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 2, 1}));
|
||||||
|
Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 2, 1});
|
||||||
|
Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones);
|
||||||
|
Output output = ops::Identity(s.WithOpName("output"), multiply);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
item.fetch = {"output"};
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
|
||||||
|
auto expected =
|
||||||
|
EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}});
|
||||||
|
ASSERT_EQ(expected.size(), 1);
|
||||||
|
|
||||||
|
GraphDef g;
|
||||||
|
ArithmeticOptimizer optimizer;
|
||||||
|
EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
|
||||||
|
OptimizeTwiceAndPrune(&optimizer, &item, &g);
|
||||||
|
EXPECT_EQ(g.node_size(), 4);
|
||||||
|
|
||||||
|
VerifyGraphsMatch(item.graph, g, __LINE__);
|
||||||
|
|
||||||
|
auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}});
|
||||||
|
ASSERT_EQ(result.size(), 1);
|
||||||
|
test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotConst) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output input1 = ops::Placeholder(s.WithOpName("input1"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape({1, 1, 1}));
|
||||||
|
Output input2 = ops::Placeholder(s.WithOpName("input2"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape({1, 2, 1}));
|
||||||
|
Output multiply = ops::Mul(s.WithOpName("multiply"), input1, input2);
|
||||||
|
Output output = ops::Identity(s.WithOpName("output"), multiply);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
item.fetch = {"output"};
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
auto tensor1 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
|
||||||
|
auto tensor2 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
|
||||||
|
auto expected = EvaluateNodes(item.graph, item.fetch,
|
||||||
|
{{"input1", tensor1}, {"input2", tensor2}});
|
||||||
|
ASSERT_EQ(expected.size(), 1);
|
||||||
|
|
||||||
|
GraphDef g;
|
||||||
|
ArithmeticOptimizer optimizer;
|
||||||
|
EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
|
||||||
|
OptimizeTwiceAndPrune(&optimizer, &item, &g);
|
||||||
|
EXPECT_EQ(g.node_size(), 4);
|
||||||
|
|
||||||
|
VerifyGraphsMatch(item.graph, g, __LINE__);
|
||||||
|
|
||||||
|
auto result = EvaluateNodes(item.graph, item.fetch,
|
||||||
|
{{"input1", tensor1}, {"input2", tensor2}});
|
||||||
|
ASSERT_EQ(result.size(), 1);
|
||||||
|
test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotOnes) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output input =
|
||||||
|
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 1, 1}));
|
||||||
|
Output ones = ops::Const(s.WithOpName("ones"), 2.0f, {1, 2, 1});
|
||||||
|
Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones);
|
||||||
|
Output output = ops::Identity(s.WithOpName("output"), multiply);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
item.fetch = {"output"};
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
|
||||||
|
auto expected =
|
||||||
|
EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}});
|
||||||
|
ASSERT_EQ(expected.size(), 1);
|
||||||
|
|
||||||
|
GraphDef g;
|
||||||
|
ArithmeticOptimizer optimizer;
|
||||||
|
EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
|
||||||
|
OptimizeTwiceAndPrune(&optimizer, &item, &g);
|
||||||
|
EXPECT_EQ(g.node_size(), 4);
|
||||||
|
|
||||||
|
VerifyGraphsMatch(item.graph, g, __LINE__);
|
||||||
|
|
||||||
|
auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}});
|
||||||
|
ASSERT_EQ(result.size(), 1);
|
||||||
|
test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
|
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
|
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
|
||||||
|
@ -153,6 +153,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
|||||||
optimizer->options_.reorder_cast_like_and_value_preserving = true;
|
optimizer->options_.reorder_cast_like_and_value_preserving = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EnableOnlyReplaceMulWithBroadcastByTile(ArithmeticOptimizer* optimizer) {
|
||||||
|
DisableAllStages(optimizer);
|
||||||
|
optimizer->options_.replace_mul_with_tile = true;
|
||||||
|
}
|
||||||
|
|
||||||
void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
|
void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
|
||||||
DisableAllStages(optimizer);
|
DisableAllStages(optimizer);
|
||||||
optimizer->options_.replace_mul_with_square = true;
|
optimizer->options_.replace_mul_with_square = true;
|
||||||
@ -258,6 +263,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
|||||||
options.remove_negation = false;
|
options.remove_negation = false;
|
||||||
options.remove_logical_not = false;
|
options.remove_logical_not = false;
|
||||||
options.reorder_cast_like_and_value_preserving = false;
|
options.reorder_cast_like_and_value_preserving = false;
|
||||||
|
options.replace_mul_with_tile = false;
|
||||||
options.replace_mul_with_square = false;
|
options.replace_mul_with_square = false;
|
||||||
options.simplify_aggregation = false;
|
options.simplify_aggregation = false;
|
||||||
options.unary_ops_composition = false;
|
options.unary_ops_composition = false;
|
||||||
|
Loading…
Reference in New Issue
Block a user