Add optimization to reorder redundant reshapes around unary ops.
PiperOrigin-RevId: 358239215 Change-Id: Ie0723ad4322fd0099185bcf5c92258ea797f917c
This commit is contained in:
parent
a98039c09e
commit
9378d6c3db
tensorflow/core/grappler/optimizers
@ -2142,93 +2142,6 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
|
||||
}
|
||||
};
|
||||
|
||||
// Reorder redundant reshapes around a single unary element-wise op, i.e.,
|
||||
//
|
||||
// input -> reshape A -> unary -> reshape B -> output
|
||||
//
|
||||
// becomes
|
||||
//
|
||||
// input -> unary -> reshape A -> reshape B -> output
|
||||
//
|
||||
// We conservatively consider reshapes to be redundant only if:
|
||||
// 1) The input shape of A is equal to the output shape of B.
|
||||
// 2) Both A and unary have a single output.
|
||||
//
|
||||
// A later pass (RemoveRedundantReshapeOrBroadcastTo) will remove both reshapes
|
||||
//
|
||||
class ReorderRedundantReshapeAroundUnary : public ArithmeticOptimizerStage {
|
||||
public:
|
||||
explicit ReorderRedundantReshapeAroundUnary(
|
||||
const GraphOptimizerContext& ctx,
|
||||
const ArithmeticOptimizerContext& ctx_ext)
|
||||
: ArithmeticOptimizerStage("ReorderRedundantReshapeAroundUnary", ctx,
|
||||
ctx_ext) {}
|
||||
|
||||
~ReorderRedundantReshapeAroundUnary() override = default;
|
||||
|
||||
bool IsSupported(const NodeDef* node) const override {
|
||||
return IsReshape(*node) && !IsInPreserveSet(*node);
|
||||
}
|
||||
|
||||
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
|
||||
// Check that we have a chain of (reshape -> unary -> reshape), with no
|
||||
// additional outputs on either the first reshape or unary op
|
||||
NodeDef* head = node;
|
||||
if (!IsReshape(*head) || IsInPreserveSet(*head)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
NodeDef* unary;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(head->input(0), &unary));
|
||||
if (!IsUnaryElementWise(*unary) || IsInPreserveSet(*unary) ||
|
||||
NumNonControlOutputs(*unary, *ctx().node_map) != 1) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
NodeDef* tail;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(unary->input(0), &tail));
|
||||
if (!IsReshape(*tail) || IsInPreserveSet(*tail) ||
|
||||
NumNonControlOutputs(*tail, *ctx().node_map) != 1) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The reshapes are a no-op if the input and output shapes match
|
||||
NodeDef* input;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &input));
|
||||
if (!InputMatchesOutputShape(*input, *head)) {
|
||||
VLOG(3) << "Input and output shapes are unequal: input=" << input->name()
|
||||
<< ", output=" << head->name();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Swap `unary` and `tail` reshape
|
||||
unary->set_input(0, input->name());
|
||||
ctx().node_map->UpdateInput(unary->name(), tail->name(), input->name());
|
||||
tail->set_input(0, unary->name());
|
||||
ctx().node_map->UpdateInput(tail->name(), input->name(), unary->name());
|
||||
head->set_input(0, tail->name());
|
||||
ctx().node_map->UpdateInput(head->name(), unary->name(), tail->name());
|
||||
|
||||
*simplified_node_name = node->name();
|
||||
AddToOptimizationQueue(node);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns whether the input shape of the first op matches the output shape of
|
||||
// the second op.
|
||||
bool InputMatchesOutputShape(const NodeDef& input, const NodeDef& output) {
|
||||
const OpInfo::TensorProperties* input_props;
|
||||
const OpInfo::TensorProperties* output_props;
|
||||
if (!GetTensorProperties(input.name(), &input_props).ok() ||
|
||||
!GetTensorProperties(output.name(), &output_props).ok()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return ShapesSymbolicallyEqual(input_props->shape(), output_props->shape());
|
||||
}
|
||||
};
|
||||
|
||||
// Fold a multiply of a scalar into the following convolution. This folding
|
||||
// can jump across nodes that merely reorders data (such as reshape and
|
||||
// transpose). For example, we can optimize
|
||||
@ -3919,8 +3832,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
||||
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
|
||||
if (options_.reorder_cast_like_and_value_preserving)
|
||||
pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
|
||||
if (options_.reorder_redundant_reshape_around_unary)
|
||||
pipeline.AddStage<ReorderRedundantReshapeAroundUnary>(ctx, ctx_ext);
|
||||
if (options_.simplify_aggregation)
|
||||
pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
|
||||
if (options_.hoist_cwise_unary_chains)
|
||||
|
@ -78,7 +78,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
|
||||
bool remove_redundant_cast = true;
|
||||
bool remove_redundant_reshape = true;
|
||||
bool reorder_cast_like_and_value_preserving = true;
|
||||
bool reorder_redundant_reshape_around_unary = true;
|
||||
bool replace_mul_with_tile = true;
|
||||
bool replace_mul_with_square = true;
|
||||
bool simplify_aggregation = true;
|
||||
|
@ -1003,103 +1003,6 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
|
||||
test::ExpectTensorNear<complex64>(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnary) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1}));
  Output pre_reshape = ops::Reshape(s.WithOpName("Reshape0"), placeholder,
                                    ops::Const(s, {1, 90000, 1}, {3}));
  Output activation = ops::Sigmoid(s, pre_reshape);
  Output post_reshape = ops::Reshape(s.WithOpName("Reshape1"), activation,
                                     ops::Const(s, {1, 300, 300, 1}, {4}));
  Output fetch = ops::Identity(s.WithOpName("outputs"), post_reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}});
  ASSERT_EQ(expected.size(), 1);

  // Reordering alone leaves both reshapes in the graph, now adjacent.
  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);

  // Reshapes should be removed after pruning
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);

  // The optimized graph must compute the same values as the original.
  auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}});
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnaryNotOutput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1}));
  Output pre_reshape =
      ops::Reshape(s, placeholder, ops::Const(s, {1, 90000, 1}, {3}));
  Output activation = ops::Sigmoid(s.WithOpName("sigmoid"), pre_reshape);
  Output post_reshape =
      ops::Reshape(s, activation, ops::Const(s, {1, 300, 300, 1}, {4}));
  Output fetch = ops::Identity(s.WithOpName("output"), post_reshape);

  GrapplerItem item;
  item.fetch = {"output"};
  item.keep_ops = {"sigmoid"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}});
  ASSERT_EQ(expected.size(), 1);

  // Reshape should not be moved since unary is a keep op
  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Both reshapes survive, and numerics are unchanged.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
  auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}});
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnaryNotIdentity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1}));
  Output pre_reshape =
      ops::Reshape(s, placeholder, ops::Const(s, {1, 90000, 1}, {3}));
  Output activation = ops::Sigmoid(s, pre_reshape);
  // [1, 300, 300, 1] is not equivalent to [1, 300, 1, 300]
  Output post_reshape =
      ops::Reshape(s, activation, ops::Const(s, {1, 300, 1, 300}, {4}));
  Output fetch = ops::Identity(s.WithOpName("outputs"), post_reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}});
  ASSERT_EQ(expected.size(), 1);

  // Mismatched shapes must block the reorder: neither stage may touch the
  // reshapes.
  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
  auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}});
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
|
||||
for (bool is_broadcastto : {false, true}) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
@ -138,12 +138,6 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
||||
optimizer->options_.remove_redundant_cast = true;
|
||||
}
|
||||
|
||||
  // Turns off every optimizer stage, then enables only the
  // ReorderRedundantReshapeAroundUnary stage on the given optimizer.
  void EnableOnlyReorderRedundantReshapeAroundUnary(
      ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.reorder_redundant_reshape_around_unary = true;
  }
|
||||
|
||||
void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
|
||||
DisableAllStages(optimizer);
|
||||
optimizer->options_.remove_redundant_reshape = true;
|
||||
|
Loading…
Reference in New Issue
Block a user