[Grappler] Do not hoist Relu from Concat if it can be fused with preceding ops.
PiperOrigin-RevId: 285088679 Change-Id: I47cb205e6de7ba30cc1ea0050c52b440c06b9e33
This commit is contained in:
parent
1afe0cb258
commit
824250a7ed
tensorflow/core/grappler/optimizers
@ -1631,6 +1631,17 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
|
||||
// TODO(rmlarsen): Allow outgoing control edges.
|
||||
return false;
|
||||
}
|
||||
// Do not hoist Relu if it can be fused with its predecessors. This is
|
||||
// important because remapping runs after arithmetic.
|
||||
if (IsRelu(*op) || IsRelu6(*op)) {
|
||||
NodeDef* operand = nullptr;
|
||||
if (!GetInputNode(op->input(0), &operand).ok()) {
|
||||
return false;
|
||||
}
|
||||
if (IsFusedBatchNorm(*operand) || IsBiasAdd(*operand)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -3077,6 +3077,47 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, DoNotHoistReluFromConcat) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output weights1 = ops::Const(s.WithOpName("weights1"),
|
||||
Input::Initializer(1.0f, {5, 5, 3, 4}));
|
||||
Output weights2 = ops::Const(s.WithOpName("weights2"),
|
||||
Input::Initializer(2.0f, {5, 5, 3, 4}));
|
||||
Output biases =
|
||||
ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
|
||||
Output axis = ops::Const(s.WithOpName("axis"), 3, {});
|
||||
Output input = ops::Const(s.WithOpName("input"),
|
||||
Input::Initializer(1.0f, {1, 28, 28, 3}));
|
||||
Output branch1 =
|
||||
ops::Conv2D(s.WithOpName("conv1"), input, weights1, {1, 1, 1, 1}, "SAME");
|
||||
branch1 = ops::BiasAdd(s.WithOpName("biasadd1"), branch1, biases);
|
||||
branch1 = ops::Relu(s.WithOpName("relu1"), branch1);
|
||||
Output branch2 =
|
||||
ops::Conv2D(s.WithOpName("conv2"), input, weights2, {1, 1, 1, 1}, "SAME");
|
||||
branch2 = ops::BiasAdd(s.WithOpName("biasadd2"), branch2, biases);
|
||||
branch2 = ops::Relu(s.WithOpName("relu2"), branch2);
|
||||
Output concat = ops::Concat(s.WithOpName("concat"), {branch1, branch2}, axis);
|
||||
Output output = ops::Identity(s.WithOpName("output"), concat);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"output"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
|
||||
GraphDef new_graph;
|
||||
ArithmeticOptimizer optimizer;
|
||||
OptimizeAndPrune(&optimizer, &item, &new_graph);
|
||||
|
||||
// Verify that the two Relus are not hoisted.
|
||||
EXPECT_EQ(CountOpNodes(new_graph, "Relu"), 2);
|
||||
|
||||
auto tensors = EvaluateNodes(new_graph, item.fetch);
|
||||
for (int i = 0; i < item.fetch.size(); ++i) {
|
||||
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
|
||||
|
Loading…
Reference in New Issue
Block a user