Merge pull request #37903 from mpjlu:fixOpFuse
PiperOrigin-RevId: 317378663 Change-Id: I02d180fc4443e3fc95e1d2a129322ca1115550da
This commit is contained in:
commit
23bb7b48a1
@ -1593,7 +1593,8 @@ bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) {
|
||||
// shapes:
|
||||
// (1) Splitting FusedBatchNorm into primitives.
|
||||
// (2) Fusing side input and/or activation into FusedBatchNorm.
|
||||
// (3) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
|
||||
// (3) Fusing Conv2D biasadd and relu on GPU
|
||||
// (4) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
|
||||
bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
||||
// Candidate for a FusedBatchNorm splitting.
|
||||
const auto* node_view = ctx.graph_view.GetNode(node_index);
|
||||
@ -1609,6 +1610,31 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
||||
return true;
|
||||
};
|
||||
|
||||
const auto is_relu_biasadd_conv2d_candidate = [&]() -> bool {
|
||||
if (!IsRelu(*node_def)) return false;
|
||||
if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
|
||||
|
||||
if (node_view->NumRegularFanins() < 1) return false;
|
||||
const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
|
||||
const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
|
||||
const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
|
||||
|
||||
if (!IsBiasAdd(*relu_fanin_0_node_def)) return false;
|
||||
if (GetDataTypeFromAttr(*relu_fanin_0_node_def, "T") != DT_FLOAT)
|
||||
return false;
|
||||
|
||||
if (relu_fanin_0_node_view->NumRegularFanins() < 1) return false;
|
||||
|
||||
const auto& biasadd_fanin_0 = relu_fanin_0_node_view->GetRegularFanin(0);
|
||||
const auto* biasadd_fanin_0_node_def = biasadd_fanin_0.node_view()->node();
|
||||
|
||||
if (!IsConv2D(*biasadd_fanin_0_node_def)) return false;
|
||||
if (GetDataTypeFromAttr(*biasadd_fanin_0_node_def, "T") != DT_FLOAT)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
// Candidate for a FusedBatchNorm fusion.
|
||||
const auto is_batch_norm_fusion_candidate = [&]() -> bool {
|
||||
if (!IsRelu(*node_def)) return false;
|
||||
@ -1643,7 +1669,8 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
||||
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
|
||||
IsConv2DWithAdd(ctx, node_index);
|
||||
#else
|
||||
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate();
|
||||
return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
|
||||
is_batch_norm_fusion_candidate();
|
||||
#endif // INTEL_MKL
|
||||
}
|
||||
|
||||
@ -1713,6 +1740,17 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
#endif //! INTEL_MKL
|
||||
|
||||
// Infer properties lazily in case they are not needed.
|
||||
if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i)) {
|
||||
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
|
||||
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
|
||||
assume_valid_feeds,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_input_tensor_values=*/true,
|
||||
/*include_output_tensor_values=*/false));
|
||||
ctx.inferred_graph_properties = true;
|
||||
}
|
||||
|
||||
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
|
||||
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
|
||||
ContractionWithBiasAdd contract_with_bias;
|
||||
|
@ -449,6 +449,78 @@ TEST_F(RemapperTest, FuseMatMulWithBias) {
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
|
||||
// Verifies that a Conv2D -> BiasAdd -> Relu chain placed on GPU is remapped
// into a single _FusedConv2D node with fused_ops = {BiasAdd, Relu}, and that
// (when a GPU is actually present) the fused graph is numerically equivalent
// to the original one.
TEST_F(RemapperTest, FuseConv2DWithBiasAndActivationOnGPU) {
#if !(GOOGLE_CUDA)
  GTEST_SKIP() << "No CUDA, skip FuseConv2DWithBiasAndActivation on GPU";
#endif  // !GOOGLE_CUDA
  using ::tensorflow::ops::Placeholder;
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Build the pattern graph: Conv2D -> BiasAdd -> Relu -> Identity, with
  // fully-known placeholder shapes so the remapper can infer them.
  const auto input_shape = Placeholder::Shape({8, 32, 32, 3});
  const auto filter_shape = Placeholder::Shape({3, 3, 3, 128});
  const auto bias_shape = Placeholder::Shape({128});

  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);

  const std::vector<int> strides = {1, 1, 1, 1};
  auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
  auto relu = ops::Relu(s.WithOpName("activation"), bias_add);
  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);

  auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
  auto filter_t = GenerateRandomTensor<DT_FLOAT>({3, 3, 3, 128});
  auto bias_t = GenerateRandomTensor<DT_FLOAT>({128});

  GrapplerItem item;
  item.fetch = {"fetch"};
  item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // Place all nodes on GPU so the GPU-specific fusion path is taken.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device("/device:GPU:0");
  }

  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  // Exactly one node (the Relu, which keeps the name "activation") should
  // have been rewritten into _FusedConv2D carrying the bias as an extra arg.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() != "activation") continue;
    EXPECT_EQ(node.op(), "_FusedConv2D");
    ASSERT_GE(node.input_size(), 3);
    EXPECT_EQ(node.input(0), "input");
    EXPECT_EQ(node.input(1), "filter");

    EXPECT_EQ(node.attr().at("num_args").i(), 1);
    EXPECT_EQ(node.input(2), "bias");

    const auto fused_ops = node.attr().at("fused_ops").list().s();
    ASSERT_EQ(fused_ops.size(), 2);
    EXPECT_EQ(fused_ops[0], "BiasAdd");
    EXPECT_EQ(fused_ops[1], "Relu");
    ++found;
  }
  EXPECT_EQ(found, 1);

  // Numerical equivalence can only be checked when a GPU is present at
  // runtime (the fused kernel has no CPU registration in this path).
  if (GetNumAvailableGPUs() > 0) {
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_expected.size(), 1);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
|
||||
|
||||
TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user