Merge pull request #37903 from mpjlu:fixOpFuse

PiperOrigin-RevId: 317378663
Change-Id: I02d180fc4443e3fc95e1d2a129322ca1115550da
This commit is contained in:
TensorFlower Gardener 2020-06-19 14:31:10 -07:00
commit 23bb7b48a1
2 changed files with 112 additions and 2 deletions

View File

@ -1593,7 +1593,8 @@ bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) {
// shapes:
// (1) Splitting FusedBatchNorm into primitives.
// (2) Fusing side input and/or activation into FusedBatchNorm.
// (3) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
// (3) Fusing Conv2D with BiasAdd and Relu on GPU.
// (4) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
// Candidate for a FusedBatchNorm splitting.
const auto* node_view = ctx.graph_view.GetNode(node_index);
@ -1609,6 +1610,31 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
return true;
};
// Candidate for a Conv2D + BiasAdd + Relu fusion: starting from a Relu
// root, walk up through BiasAdd to Conv2D, requiring a DT_FLOAT "T" attr
// at every node in the chain.
const auto is_relu_biasadd_conv2d_candidate = [&]() -> bool {
  // The pattern root must be a float Relu with at least one regular input.
  if (!IsRelu(*node_def)) return false;
  if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
  if (node_view->NumRegularFanins() < 1) return false;
  // The Relu's first regular fanin must be a float BiasAdd.
  const auto* bias_add_view = node_view->GetRegularFanin(0).node_view();
  const auto* bias_add_def = bias_add_view->node();
  if (!IsBiasAdd(*bias_add_def)) return false;
  if (GetDataTypeFromAttr(*bias_add_def, "T") != DT_FLOAT) return false;
  if (bias_add_view->NumRegularFanins() < 1) return false;
  // The BiasAdd's first regular fanin must be a float Conv2D.
  const auto* conv2d_def =
      bias_add_view->GetRegularFanin(0).node_view()->node();
  if (!IsConv2D(*conv2d_def)) return false;
  return GetDataTypeFromAttr(*conv2d_def, "T") == DT_FLOAT;
};
// Candidate for a FusedBatchNorm fusion.
const auto is_batch_norm_fusion_candidate = [&]() -> bool {
if (!IsRelu(*node_def)) return false;
@ -1643,7 +1669,8 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
IsConv2DWithAdd(ctx, node_index);
#else
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate();
return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
is_batch_norm_fusion_candidate();
#endif // INTEL_MKL
}
@ -1713,6 +1740,17 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
}
#endif //! INTEL_MKL
// Infer properties lazily in case they are not needed.
// Static shape inference is expensive, so run it at most once per graph
// (guarded by ctx.inferred_graph_properties) and only when the current
// node actually needs inferred shapes (RequiresInferredShapes).
if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i)) {
// At AGGRESSIVE opt level, feed tensors are assumed to carry valid shapes.
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
assume_valid_feeds,
/*aggressive_shape_inference=*/false,
/*include_input_tensor_values=*/true,
/*include_output_tensor_values=*/false));
// Mark done so later iterations skip the (costly) inference above.
ctx.inferred_graph_properties = true;
}
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
ContractionWithBiasAdd contract_with_bias;

View File

@ -449,6 +449,78 @@ TEST_F(RemapperTest, FuseMatMulWithBias) {
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
// Verifies that on GPU the Conv2D -> BiasAdd -> Relu chain is remapped
// into a single _FusedConv2D node with fused_ops = {"BiasAdd", "Relu"}.
TEST_F(RemapperTest, FuseConv2DWithBiasAndActivationOnGPU) {
#if !(GOOGLE_CUDA)
  GTEST_SKIP() << "No CUDA, skip FuseConv2DWithBiasAndActivation on GPU";
#endif  // !GOOGLE_CUDA
  using ::tensorflow::ops::Placeholder;

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Build input -> Conv2D -> BiasAdd -> Relu -> Identity(fetch).
  auto in_shape = Placeholder::Shape({8, 32, 32, 3});
  auto flt_shape = Placeholder::Shape({3, 3, 3, 128});
  auto b_shape = Placeholder::Shape({128});

  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, in_shape);
  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, flt_shape);
  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, b_shape);

  std::vector<int> strides = {1, 1, 1, 1};
  auto conv =
      ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
  auto relu = ops::Relu(s.WithOpName("activation"), bias_add);
  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);

  auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
  auto filter_t = GenerateRandomTensor<DT_FLOAT>({3, 3, 3, 128});
  auto bias_t = GenerateRandomTensor<DT_FLOAT>({128});

  GrapplerItem item;
  item.fetch = {"fetch"};
  item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // Place every node on GPU so the GPU-specific fusion path is exercised.
  for (NodeDef& node : *item.graph.mutable_node()) {
    node.set_device("/device:GPU:0");
  }

  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  // Exactly one node -- the former Relu ("activation") -- must have been
  // rewritten into _FusedConv2D with the bias as its single extra arg.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() != "activation") continue;
    EXPECT_EQ(node.op(), "_FusedConv2D");
    ASSERT_GE(node.input_size(), 3);
    EXPECT_EQ(node.input(0), "input");
    EXPECT_EQ(node.input(1), "filter");
    EXPECT_EQ(node.attr().at("num_args").i(), 1);
    EXPECT_EQ(node.input(2), "bias");
    const auto fused_ops = node.attr().at("fused_ops").list().s();
    ASSERT_EQ(fused_ops.size(), 2);
    EXPECT_EQ(fused_ops[0], "BiasAdd");
    EXPECT_EQ(fused_ops[1], "Relu");
    ++found;
  }
  EXPECT_EQ(found, 1);

  // Only compare numerics when a physical GPU is actually available.
  if (GetNumAvailableGPUs() > 0) {
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_expected.size(), 1);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
using ::tensorflow::ops::Placeholder;