[Grappler] Do not hoist Relu from Concat if it can be fused with preceding ops.

PiperOrigin-RevId: 285088679
Change-Id: I47cb205e6de7ba30cc1ea0050c52b440c06b9e33
Author: Jingyue Wu, 2019-12-11 16:39:02 -08:00; committed by TensorFlower Gardener
parent 1afe0cb258
commit 824250a7ed
2 changed files with 52 additions and 0 deletions
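
The rewrite being suppressed is the one HoistCWiseUnaryChainsStage normally performs: pulling an elementwise op that appears on every Concat input (here Relu) to the other side of the Concat, e.g. turning Concat(Relu(x), Relu(y)) into Relu(Concat(x, y)). When each Relu sits directly on a BiasAdd (or FusedBatchNorm), that rewrite separates ops a later fusion pass would otherwise combine. A minimal sketch of the two graph shapes, written against the same C++ ops API the new test uses; the helper functions and node layout below are illustrative, not code from this commit:

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

using tensorflow::Output;
using tensorflow::Scope;
namespace ops = tensorflow::ops;

// Fusion-friendly shape: each branch keeps Conv2D -> BiasAdd -> Relu adjacent,
// so a later fusion pass can collapse each branch into a single fused node.
Output BranchesThenConcat(const Scope& s, Output input, Output w1, Output w2,
                          Output biases, Output axis) {
  Output conv1 = ops::Conv2D(s, input, w1, {1, 1, 1, 1}, "SAME");
  Output bias1 = ops::BiasAdd(s, conv1, biases);
  Output relu1 = ops::Relu(s, bias1);
  Output conv2 = ops::Conv2D(s, input, w2, {1, 1, 1, 1}, "SAME");
  Output bias2 = ops::BiasAdd(s, conv2, biases);
  Output relu2 = ops::Relu(s, bias2);
  return ops::Concat(s, {relu1, relu2}, axis);
}

// Shape after hoisting: a single Relu above the Concat. One fewer Relu node,
// but the Relu is no longer adjacent to the BiasAdds, so the per-branch
// fusion opportunity is lost.
Output ConcatThenRelu(const Scope& s, Output input, Output w1, Output w2,
                      Output biases, Output axis) {
  Output conv1 = ops::Conv2D(s, input, w1, {1, 1, 1, 1}, "SAME");
  Output bias1 = ops::BiasAdd(s, conv1, biases);
  Output conv2 = ops::Conv2D(s, input, w2, {1, 1, 1, 1}, "SAME");
  Output bias2 = ops::BiasAdd(s, conv2, biases);
  Output concat = ops::Concat(s, {bias1, bias2}, axis);
  return ops::Relu(s, concat);
}

OptimizeAndPrune in the new test is expected to keep the first shape (its CountOpNodes check requires both Relu nodes to survive); before this change the optimizer would have produced the second.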
tensorflow/core/grappler/optimizers


@@ -1631,6 +1631,17 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
        // TODO(rmlarsen): Allow outgoing control edges.
        return false;
      }
      // Do not hoist Relu if it can be fused with its predecessors. This is
      // important because remapping runs after the arithmetic optimizer.
      if (IsRelu(*op) || IsRelu6(*op)) {
        NodeDef* operand = nullptr;
        if (!GetInputNode(op->input(0), &operand).ok()) {
          return false;
        }
        if (IsFusedBatchNorm(*operand) || IsBiasAdd(*operand)) {
          return false;
        }
      }
    }
    return true;
  }
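
For context on the comment above: the fusion in question is done by Grappler's remapper pass, which runs after the arithmetic optimizer, so a Relu that has been hoisted past the Concat can never be fused back. A hedged sketch of roughly the kind of node the remapper emits when a Conv2D -> BiasAdd -> Relu chain is left intact; the exact node construction and attributes here are illustrative, not part of this change:

#include "tensorflow/core/framework/node_def.pb.h"

// Roughly what remapping can turn conv1 + biasadd1 + relu1 into when the
// chain stays adjacent: a single _FusedConv2D whose fused_ops attribute
// records the BiasAdd and Relu that were folded in.
tensorflow::NodeDef FusedConvSketch() {
  tensorflow::NodeDef fused;
  fused.set_op("_FusedConv2D");
  fused.set_name("conv1");
  fused.add_input("input");
  fused.add_input("weights1");
  fused.add_input("biases");
  auto* fused_ops = (*fused.mutable_attr())["fused_ops"].mutable_list();
  fused_ops->add_s("BiasAdd");
  fused_ops->add_s("Relu");
  (*fused.mutable_attr())["num_args"].set_i(1);  // one extra input: the bias
  return fused;
}

Once the Relu has been hoisted above the Concat, this pattern no longer matches, which is why the stage now declines to hoist in that case.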


@@ -3077,6 +3077,47 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}

TEST_F(ArithmeticOptimizerTest, DoNotHoistReluFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output weights1 = ops::Const(s.WithOpName("weights1"),
                               Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output weights2 = ops::Const(s.WithOpName("weights2"),
                               Input::Initializer(2.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output axis = ops::Const(s.WithOpName("axis"), 3, {});
  Output input = ops::Const(s.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  Output branch1 =
      ops::Conv2D(s.WithOpName("conv1"), input, weights1, {1, 1, 1, 1}, "SAME");
  branch1 = ops::BiasAdd(s.WithOpName("biasadd1"), branch1, biases);
  branch1 = ops::Relu(s.WithOpName("relu1"), branch1);
  Output branch2 =
      ops::Conv2D(s.WithOpName("conv2"), input, weights2, {1, 1, 1, 1}, "SAME");
  branch2 = ops::BiasAdd(s.WithOpName("biasadd2"), branch2, biases);
  branch2 = ops::Relu(s.WithOpName("relu2"), branch2);
  Output concat = ops::Concat(s.WithOpName("concat"), {branch1, branch2}, axis);
  Output output = ops::Identity(s.WithOpName("output"), concat);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &new_graph);

  // Verify that the two Relus are not hoisted.
  EXPECT_EQ(CountOpNodes(new_graph, "Relu"), 2);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}

TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});