Merge pull request #37624 from Intel-tensorflow:yifeng/depthwise_conv2d_fusion
PiperOrigin-RevId: 307656312 Change-Id: I7098e05501111caea1c7e32329eb4e502292273e
This commit is contained in:
commit
0fa30f59f3
@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
|
csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
|
||||||
csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
|
csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
|
||||||
csinfo_.fused_conv2d = "_FusedConv2D";
|
csinfo_.fused_conv2d = "_FusedConv2D";
|
||||||
|
csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
|
||||||
csinfo_.fused_matmul = "_FusedMatMul";
|
csinfo_.fused_matmul = "_FusedMatMul";
|
||||||
csinfo_.identity = "Identity";
|
csinfo_.identity = "Identity";
|
||||||
csinfo_.leakyrelu = "LeakyRelu";
|
csinfo_.leakyrelu = "LeakyRelu";
|
||||||
@ -295,6 +296,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
csinfo_.mkl_depthwise_conv2d_grad_filter =
|
csinfo_.mkl_depthwise_conv2d_grad_filter =
|
||||||
"_MklDepthwiseConv2dNativeBackpropFilter";
|
"_MklDepthwiseConv2dNativeBackpropFilter";
|
||||||
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
|
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
|
||||||
|
csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
|
||||||
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
|
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
|
||||||
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
|
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
|
||||||
csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
|
csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
|
||||||
@ -479,6 +481,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
|
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
|
||||||
CopyAttrsFusedConv2D, FusedConv2DRewrite,
|
CopyAttrsFusedConv2D, FusedConv2DRewrite,
|
||||||
kRewriteForLayoutPropagation});
|
kRewriteForLayoutPropagation});
|
||||||
|
rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
|
||||||
|
csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D,
|
||||||
|
FusedDepthwiseConv2DRewrite,
|
||||||
|
kRewriteForLayoutPropagation});
|
||||||
rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
|
rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
|
||||||
CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
|
CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
|
||||||
|
|
||||||
@ -925,6 +931,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
string fused_batch_norm_v3;
|
string fused_batch_norm_v3;
|
||||||
string fused_batch_norm_grad_v3;
|
string fused_batch_norm_grad_v3;
|
||||||
string fused_conv2d;
|
string fused_conv2d;
|
||||||
|
string fused_depthwise_conv2d;
|
||||||
string fused_matmul;
|
string fused_matmul;
|
||||||
string identity;
|
string identity;
|
||||||
string leakyrelu;
|
string leakyrelu;
|
||||||
@ -945,6 +952,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
string mkl_depthwise_conv2d_grad_input;
|
string mkl_depthwise_conv2d_grad_input;
|
||||||
string mkl_depthwise_conv2d_grad_filter;
|
string mkl_depthwise_conv2d_grad_filter;
|
||||||
string mkl_fused_conv2d;
|
string mkl_fused_conv2d;
|
||||||
|
string mkl_fused_depthwise_conv2d;
|
||||||
string mkl_fused_matmul;
|
string mkl_fused_matmul;
|
||||||
string mkl_pad_with_conv2d;
|
string mkl_pad_with_conv2d;
|
||||||
string mkl_pad_with_fused_conv2d;
|
string mkl_pad_with_fused_conv2d;
|
||||||
@ -1675,6 +1683,25 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"});
|
fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool FusedDepthwiseConv2DRewrite(const Node* n) {
|
||||||
|
// MKL DNN currently doesn't support all fusions that grappler fuses
|
||||||
|
// together with DepthwiseConv2D (e.g. batchnorm). We rewrite
|
||||||
|
// _FusedDepthwiseConv2dNative only if it includes those we support.
|
||||||
|
DataType T;
|
||||||
|
if (!TryGetNodeAttr(n->def(), "T", &T) ||
|
||||||
|
!mkl_op_registry::IsMklLayoutDependentOp(
|
||||||
|
csinfo_.mkl_fused_depthwise_conv2d, T)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> fused_ops;
|
||||||
|
TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
|
||||||
|
return (fused_ops == std::vector<string>{"BiasAdd"} ||
|
||||||
|
fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
|
||||||
|
fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
|
||||||
|
fused_ops == std::vector<string>{"BiasAdd", "Elu"});
|
||||||
|
}
|
||||||
|
|
||||||
// Rewrites input node to a new node specified by its matching rewrite info.
|
// Rewrites input node to a new node specified by its matching rewrite info.
|
||||||
//
|
//
|
||||||
// Method first searches matching rewrite info for input node and then
|
// Method first searches matching rewrite info for input node and then
|
||||||
@ -3703,6 +3730,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
|
|||||||
n->type_string() != csinfo_.pad_with_fused_conv2d &&
|
n->type_string() != csinfo_.pad_with_fused_conv2d &&
|
||||||
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
|
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
|
||||||
n->type_string() != csinfo_.fused_conv2d &&
|
n->type_string() != csinfo_.fused_conv2d &&
|
||||||
|
n->type_string() != csinfo_.fused_depthwise_conv2d &&
|
||||||
n->type_string() != csinfo_.fused_matmul &&
|
n->type_string() != csinfo_.fused_matmul &&
|
||||||
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
|
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
|
||||||
T)) {
|
T)) {
|
||||||
|
@ -1789,6 +1789,56 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive6);
|
|||||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
|
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
|
||||||
#undef REGISTER_TEST
|
#undef REGISTER_TEST
|
||||||
|
|
||||||
|
// Rewrite test for _FusedDepthwiseConv2dNative Op fusion
|
||||||
|
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||||
|
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||||
|
InitGraph( \
|
||||||
|
"node { name: 'A' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'B' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'C' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
|
||||||
|
" attr { key: 'T' value { type: " #T " } }" \
|
||||||
|
" attr { key: 'num_args' value { i: 1 } }" \
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }" \
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||||
|
"} }" \
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }" \
|
||||||
|
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||||
|
"} }" \
|
||||||
|
" attr { key: 'fused_ops' value { list: " FUSED_OPS " } }" \
|
||||||
|
" attr { key: 'epsilon' value { f: 0.001 }}" \
|
||||||
|
" input: ['A', 'B', 'C']}" \
|
||||||
|
"node { name: 'E' op: 'Zeta'" \
|
||||||
|
"attr { key: 'T' value { type: " #T " } }" \
|
||||||
|
" input: ['D', 'C'] }"); \
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
|
||||||
|
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
|
||||||
|
"D(_MklFusedDepthwiseConv2dNative);" \
|
||||||
|
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
|
||||||
|
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
|
||||||
|
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
|
||||||
|
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
|
||||||
|
}
|
||||||
|
|
||||||
|
// BiasAdd fusion
|
||||||
|
#define FUSED_OPS "{s: 'BiasAdd'}"
|
||||||
|
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
|
||||||
|
|
||||||
|
// BiasAdd + Relu fusion
|
||||||
|
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu'}"
|
||||||
|
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive2);
|
||||||
|
|
||||||
|
// BiasAdd + Relu6 fusion
|
||||||
|
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu6'}"
|
||||||
|
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
|
||||||
|
|
||||||
|
// BiasAdd + Elu fusion
|
||||||
|
#define FUSED_OPS "{s: 'BiasAdd', s: 'Elu'}"
|
||||||
|
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
|
||||||
|
|
||||||
|
#undef FUSED_OPS
|
||||||
|
#undef REGISTER_TEST
|
||||||
|
|
||||||
// Rewrite test for _FusedConv2D Op with unsupported fusion
|
// Rewrite test for _FusedConv2D Op with unsupported fusion
|
||||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||||
@ -1818,6 +1868,36 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
|
|||||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
|
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
|
||||||
#undef REGISTER_TEST
|
#undef REGISTER_TEST
|
||||||
|
|
||||||
|
// Rewrite test for _FusedDepthwiseConv2dNative with unsupported fusion
|
||||||
|
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||||
|
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||||
|
InitGraph( \
|
||||||
|
"node { name: 'A' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'B' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'C' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
|
||||||
|
" attr { key: 'T' value { type: " #T " } }" \
|
||||||
|
" attr { key: 'num_args' value { i: 1 } }" \
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }" \
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||||
|
"} }" \
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }" \
|
||||||
|
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||||
|
"} }" \
|
||||||
|
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }" \
|
||||||
|
" attr { key: 'epsilon' value { f: 0.001 }}" \
|
||||||
|
" input: ['A', 'B', 'C']}" \
|
||||||
|
"node { name: 'E' op: 'Zeta'" \
|
||||||
|
"attr { key: 'T' value { type: " #T " } }" \
|
||||||
|
" input: ['D', 'C'] }"); \
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
|
||||||
|
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
|
||||||
|
"D(_FusedDepthwiseConv2dNative);" \
|
||||||
|
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
|
||||||
|
}
|
||||||
|
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Negative1);
|
||||||
|
#undef REGISTER_TEST
|
||||||
|
|
||||||
// Rewrite test for _FusedConv2D Op with unsupported type
|
// Rewrite test for _FusedConv2D Op with unsupported type
|
||||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||||
@ -1847,6 +1927,37 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
|
|||||||
REGISTER_TEST(NodeRewrite_FusedConv2D_Negative2, DT_DOUBLE, DoubleInput);
|
REGISTER_TEST(NodeRewrite_FusedConv2D_Negative2, DT_DOUBLE, DoubleInput);
|
||||||
#undef REGISTER_TEST
|
#undef REGISTER_TEST
|
||||||
|
|
||||||
|
// Rewrite test for _FusedDepthwiseConv2dNative Op with unsupported type
|
||||||
|
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||||
|
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||||
|
InitGraph( \
|
||||||
|
"node { name: 'A' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'B' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'C' op: '" #INPUT "'}" \
|
||||||
|
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
|
||||||
|
" attr { key: 'T' value { type:" #T "} }" \
|
||||||
|
" attr { key: 'num_args' value { i: 1 } }" \
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }" \
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||||
|
"} }" \
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }" \
|
||||||
|
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||||
|
"} }" \
|
||||||
|
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
|
||||||
|
" attr { key: 'epsilon' value { f: 0.001 }}" \
|
||||||
|
" input: ['A', 'B', 'C']}" \
|
||||||
|
"node { name: 'E' op: 'Zeta'" \
|
||||||
|
"attr { key: 'T' value { type: " #T "} }" \
|
||||||
|
" input: ['D', 'C'] }"); \
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
|
||||||
|
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
|
||||||
|
"D(_FusedDepthwiseConv2dNative);" \
|
||||||
|
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
|
||||||
|
}
|
||||||
|
REGISTER_TEST(NodeRewrite_FusedDepthwiseConv2dNative_Negative2,
|
||||||
|
DT_DOUBLE, DoubleInput);
|
||||||
|
#undef REGISTER_TEST
|
||||||
|
|
||||||
// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
|
// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
|
||||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||||
@ -4240,6 +4351,33 @@ TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Positive) {
|
|||||||
"_MklFusedConv2D"));
|
"_MklFusedConv2D"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is a constant.
|
||||||
|
TEST_F(MklLayoutPassTest,
|
||||||
|
FusedDepthwiseConv2dNativeWithBias_FilterCaching_Positive) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Const'" // Filter
|
||||||
|
" attr { key: 'dtype' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'value' value { "
|
||||||
|
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
|
||||||
|
" int_val: 0 } } } }"
|
||||||
|
"node { name: 'C' op: 'Input'}"
|
||||||
|
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'num_args' value { i: 1 } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||||
|
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
|
||||||
|
" attr { key: 'epsilon' value { f: 0.001 }}"
|
||||||
|
" input: ['A', 'B', 'C']}"
|
||||||
|
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['D', 'C'] }");
|
||||||
|
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
|
||||||
|
"is_filter_const", "_MklFusedDepthwiseConv2dNative"));
|
||||||
|
}
|
||||||
|
|
||||||
// _FusedConv2D + BiasAdd fusion where filter is NOT a constant.
|
// _FusedConv2D + BiasAdd fusion where filter is NOT a constant.
|
||||||
TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
|
TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
|
||||||
InitGraph(
|
InitGraph(
|
||||||
@ -4262,6 +4400,28 @@ TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
|
|||||||
"_MklFusedConv2D"));
|
"_MklFusedConv2D"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is NOT a constant.
|
||||||
|
TEST_F(MklLayoutPassTest,
|
||||||
|
FusedDepthwiseConv2dNativeWithBias_FilterCaching_Negative) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Input'}" // Filter
|
||||||
|
"node { name: 'C' op: 'Input'}"
|
||||||
|
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'num_args' value { i: 1 } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||||
|
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
|
||||||
|
" attr { key: 'epsilon' value { f: 0.001 }}"
|
||||||
|
" input: ['A', 'B', 'C']}"
|
||||||
|
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['D', 'C'] }");
|
||||||
|
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
|
||||||
|
"is_filter_const", "_MklFusedDepthwiseConv2dNative"));
|
||||||
|
}
|
||||||
// Depthwise Conv2D op where filter is a constant.
|
// Depthwise Conv2D op where filter is a constant.
|
||||||
TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Positive) {
|
TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Positive) {
|
||||||
InitGraph(
|
InitGraph(
|
||||||
|
@ -206,6 +206,94 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);
|
|||||||
#undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
|
#undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
|
||||||
#undef CREATE_CONV2DFUSION_TEST
|
#undef CREATE_CONV2DFUSION_TEST
|
||||||
|
|
||||||
|
TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
|
||||||
|
using ::tensorflow::ops::Placeholder;
|
||||||
|
|
||||||
|
for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
|
||||||
|
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
|
||||||
|
auto filter_shape = Placeholder::Shape({1, 1, 3, 1});
|
||||||
|
auto bias_shape = Placeholder::Shape({3});
|
||||||
|
|
||||||
|
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||||
|
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
|
||||||
|
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
|
||||||
|
|
||||||
|
std::vector<int> strides = {1, 1, 1, 1};
|
||||||
|
auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
|
||||||
|
input, filter, strides, "SAME");
|
||||||
|
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
|
||||||
|
|
||||||
|
ops::Identity fetch = [&]() -> ops::Identity {
|
||||||
|
auto activate = s.WithOpName("activation");
|
||||||
|
auto fetch = s.WithOpName("fetch");
|
||||||
|
|
||||||
|
if (activation == "Relu") {
|
||||||
|
return ops::Identity(fetch, ops::Relu(activate, bias_add));
|
||||||
|
} else if (activation == "Relu6") {
|
||||||
|
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
|
||||||
|
} else if (activation == "Elu") {
|
||||||
|
return ops::Identity(fetch, ops::Elu(activate, bias_add));
|
||||||
|
}
|
||||||
|
|
||||||
|
DCHECK(activation == "None");
|
||||||
|
return ops::Identity(fetch, bias_add);
|
||||||
|
}();
|
||||||
|
|
||||||
|
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
|
||||||
|
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});
|
||||||
|
auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
item.fetch = {"fetch"};
|
||||||
|
item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
|
||||||
|
// Place all nodes on CPU.
|
||||||
|
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||||
|
item.graph.mutable_node(i)->set_device("/device:CPU:0");
|
||||||
|
}
|
||||||
|
|
||||||
|
Remapper optimizer(RewriterConfig::ON);
|
||||||
|
GraphDef output;
|
||||||
|
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
|
|
||||||
|
int found = 0;
|
||||||
|
for (const NodeDef& node : output.node()) {
|
||||||
|
if (node.name() != "bias_add" && node.name() != "activation") continue;
|
||||||
|
|
||||||
|
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
|
||||||
|
ASSERT_EQ(node.input_size(), 3);
|
||||||
|
EXPECT_EQ(node.input(0), "input");
|
||||||
|
EXPECT_EQ(node.input(1), "filter");
|
||||||
|
|
||||||
|
EXPECT_EQ(node.attr().at("num_args").i(), 1);
|
||||||
|
EXPECT_EQ(node.input(2), "bias");
|
||||||
|
|
||||||
|
const auto fused_ops = node.attr().at("fused_ops").list().s();
|
||||||
|
if (node.name() == "bias_add") {
|
||||||
|
ASSERT_EQ(fused_ops.size(), 1);
|
||||||
|
EXPECT_EQ(fused_ops[0], "BiasAdd");
|
||||||
|
found++;
|
||||||
|
}
|
||||||
|
if (node.name() == "activation") {
|
||||||
|
ASSERT_EQ(fused_ops.size(), 2);
|
||||||
|
EXPECT_EQ(fused_ops[0], "BiasAdd");
|
||||||
|
EXPECT_EQ(fused_ops[1], activation);
|
||||||
|
found++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_EQ(found, 1);
|
||||||
|
|
||||||
|
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||||
|
ASSERT_EQ(tensors_expected.size(), 1);
|
||||||
|
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||||
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
|
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
#endif // INTEL_MKL
|
#endif // INTEL_MKL
|
||||||
|
@ -57,6 +57,7 @@ namespace {
|
|||||||
|
|
||||||
constexpr char kFusedConv2D[] = "_FusedConv2D";
|
constexpr char kFusedConv2D[] = "_FusedConv2D";
|
||||||
constexpr char kFusedMatMul[] = "_FusedMatMul";
|
constexpr char kFusedMatMul[] = "_FusedMatMul";
|
||||||
|
constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
|
||||||
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
|
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
|
||||||
|
|
||||||
constexpr char kDataFormat[] = "data_format";
|
constexpr char kDataFormat[] = "data_format";
|
||||||
@ -279,12 +280,24 @@ bool IsCpuCompatibleMatMul(const NodeDef* matmul) {
|
|||||||
return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
|
return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
|
||||||
|
DCHECK(IsDepthwiseConv2dNative(*dw_conv2d))
|
||||||
|
<< "Expected DepthwiseConv2dNative op";
|
||||||
|
return NodeIsOnCpu(dw_conv2d) && IsCpuCompatibleDataType(dw_conv2d);
|
||||||
|
}
|
||||||
|
|
||||||
// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
|
// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
|
||||||
template <typename Pattern>
|
template <typename Pattern>
|
||||||
bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
|
bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
|
||||||
const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
|
const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
|
||||||
if (IsConv2D(node)) {
|
if (IsConv2D(node)) {
|
||||||
return IsCpuCompatibleConv2D(&node);
|
return IsCpuCompatibleConv2D(&node);
|
||||||
|
} else if (IsDepthwiseConv2dNative(node)) {
|
||||||
|
#ifdef INTEL_MKL
|
||||||
|
return IsCpuCompatibleDepthwiseConv2dNative(&node);
|
||||||
|
#else
|
||||||
|
return false;
|
||||||
|
#endif // INTEL_MKL
|
||||||
} else if (IsMatMul(node)) {
|
} else if (IsMatMul(node)) {
|
||||||
return IsCpuCompatibleMatMul(&node);
|
return IsCpuCompatibleMatMul(&node);
|
||||||
} else {
|
} else {
|
||||||
@ -381,11 +394,12 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
|
|||||||
const auto* contraction_node_view = regular_fanin_0.node_view();
|
const auto* contraction_node_view = regular_fanin_0.node_view();
|
||||||
const auto* contraction_node_def = contraction_node_view->node();
|
const auto* contraction_node_def = contraction_node_view->node();
|
||||||
|
|
||||||
bool is_conv2d_or_matmul =
|
// Conv2D, MatMul or DepthwiseConv2D
|
||||||
IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def);
|
bool is_contraction = IsConv2D(*contraction_node_def) ||
|
||||||
|
IsMatMul(*contraction_node_def) ||
|
||||||
|
IsDepthwiseConv2dNative(*contraction_node_def);
|
||||||
|
|
||||||
if (!is_conv2d_or_matmul ||
|
if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
|
||||||
!HaveSameDataType(node_def, contraction_node_def) ||
|
|
||||||
HasControlFaninOrFanout(*contraction_node_view) ||
|
HasControlFaninOrFanout(*contraction_node_view) ||
|
||||||
!HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
|
!HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
|
||||||
IsInPreserveSet(ctx, contraction_node_def))
|
IsInPreserveSet(ctx, contraction_node_def))
|
||||||
@ -902,6 +916,21 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
|
|||||||
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
|
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
|
||||||
|
NodeDef* fused_dw_conv2d) {
|
||||||
|
DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
|
||||||
|
<< "Input node must be a DepthwiseConv2dNative";
|
||||||
|
|
||||||
|
auto* attr = fused_dw_conv2d->mutable_attr();
|
||||||
|
auto& src_attr = dw_conv2d.attr();
|
||||||
|
|
||||||
|
(*attr)["T"] = src_attr.at("T");
|
||||||
|
(*attr)["strides"] = src_attr.at("strides");
|
||||||
|
(*attr)["padding"] = src_attr.at("padding");
|
||||||
|
(*attr)["dilations"] = src_attr.at("dilations");
|
||||||
|
(*attr)["data_format"] = src_attr.at("data_format");
|
||||||
|
}
|
||||||
|
|
||||||
void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
|
void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
|
||||||
NodeDef* fused_batch_norm_ex) {
|
NodeDef* fused_batch_norm_ex) {
|
||||||
DCHECK(IsFusedBatchNorm(fused_batch_norm))
|
DCHECK(IsFusedBatchNorm(fused_batch_norm))
|
||||||
@ -966,6 +995,9 @@ Status AddFusedContractionNode(RemapperContext* ctx,
|
|||||||
if (IsConv2D(contraction)) {
|
if (IsConv2D(contraction)) {
|
||||||
fused_op.set_op(kFusedConv2D);
|
fused_op.set_op(kFusedConv2D);
|
||||||
CopyConv2DAttributes(contraction, &fused_op);
|
CopyConv2DAttributes(contraction, &fused_op);
|
||||||
|
} else if (IsDepthwiseConv2dNative(contraction)) {
|
||||||
|
fused_op.set_op(kFusedDepthwiseConv2dNative);
|
||||||
|
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
|
||||||
} else if (IsMatMul(contraction)) {
|
} else if (IsMatMul(contraction)) {
|
||||||
fused_op.set_op(kFusedMatMul);
|
fused_op.set_op(kFusedMatMul);
|
||||||
CopyMatMulAttributes(contraction, &fused_op);
|
CopyMatMulAttributes(contraction, &fused_op);
|
||||||
@ -1010,6 +1042,9 @@ Status AddFusedContractionNode(
|
|||||||
if (IsConv2D(contraction)) {
|
if (IsConv2D(contraction)) {
|
||||||
fused_op.set_op(kFusedConv2D);
|
fused_op.set_op(kFusedConv2D);
|
||||||
CopyConv2DAttributes(contraction, &fused_op);
|
CopyConv2DAttributes(contraction, &fused_op);
|
||||||
|
} else if (IsDepthwiseConv2dNative(contraction)) {
|
||||||
|
fused_op.set_op(kFusedDepthwiseConv2dNative);
|
||||||
|
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
|
||||||
} else if (IsMatMul(contraction)) {
|
} else if (IsMatMul(contraction)) {
|
||||||
fused_op.set_op(kFusedMatMul);
|
fused_op.set_op(kFusedMatMul);
|
||||||
CopyMatMulAttributes(contraction, &fused_op);
|
CopyMatMulAttributes(contraction, &fused_op);
|
||||||
@ -1660,7 +1695,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
#endif //! INTEL_MKL
|
#endif //! INTEL_MKL
|
||||||
|
|
||||||
// Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
|
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
|
||||||
|
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
|
||||||
ContractionWithBiasAdd contract_with_bias;
|
ContractionWithBiasAdd contract_with_bias;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindContractionWithBias(ctx, i, &contract_with_bias)) {
|
FindContractionWithBias(ctx, i, &contract_with_bias)) {
|
||||||
@ -1669,7 +1705,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
|
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd+Activation into the
|
||||||
|
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}.
|
||||||
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindContractionWithBiasAndActivation(
|
FindContractionWithBiasAndActivation(
|
||||||
|
@ -1363,6 +1363,58 @@ class MklFusedConvOp
|
|||||||
virtual ~MklFusedConvOp() {}
|
virtual ~MklFusedConvOp() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
|
||||||
|
typename Toutput, typename Ttemp_output, typename Tpadding,
|
||||||
|
bool pad_enabled, bool bias_enabled, bool is_depthwise>
|
||||||
|
class MklFusedDepthwiseConvOp
|
||||||
|
: public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
|
||||||
|
Tpadding, bias_enabled, false, is_depthwise, false> {
|
||||||
|
public:
|
||||||
|
explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context)
|
||||||
|
: MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
|
||||||
|
Tpadding, bias_enabled, false, is_depthwise, false>(context) {
|
||||||
|
// Since we came here through the registration of
|
||||||
|
// _MklFusedDepthwiseConv2dNative, get all
|
||||||
|
// information from 'fused_ops' and 'num_args'
|
||||||
|
std::vector<string> fused_ops;
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
|
||||||
|
|
||||||
|
int num_args;
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
|
||||||
|
OP_REQUIRES(context, !fused_ops.empty(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Fused DepthwiseConv2D must have at least one fused op."));
|
||||||
|
|
||||||
|
if (fused_ops == std::vector<string>{"BiasAdd"}) {
|
||||||
|
this->set_fuse_biasadd(true);
|
||||||
|
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
|
||||||
|
this->set_fuse_biasadd(true);
|
||||||
|
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
|
||||||
|
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
|
||||||
|
this->set_fuse_biasadd(true);
|
||||||
|
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
|
||||||
|
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
|
||||||
|
this->set_fuse_biasadd(true);
|
||||||
|
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES(context, false,
|
||||||
|
errors::Unimplemented("Fusion is not implemented: [",
|
||||||
|
absl::StrJoin(fused_ops, ","), "]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, num_args == 1,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Fused DepthwiseConv2D must have one extra argument: bias."));
|
||||||
|
|
||||||
|
if (pad_enabled) {
|
||||||
|
this->set_fuse_pad(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~MklFusedDepthwiseConvOp() {}
|
||||||
|
};
|
||||||
|
|
||||||
// We create new class for each version of Quantized Convolution and inherit
|
// We create new class for each version of Quantized Convolution and inherit
|
||||||
// from the FP32 version of the base class
|
// from the FP32 version of the base class
|
||||||
template <typename Device, typename Tinput, typename Tbias, typename Toutput,
|
template <typename Device, typename Tinput, typename Tbias, typename Toutput,
|
||||||
@ -2253,6 +2305,11 @@ REGISTER_KERNEL_BUILDER(
|
|||||||
.TypeConstraint<quint8>("out_type"),
|
.TypeConstraint<quint8>("out_type"),
|
||||||
NoOp);
|
NoOp);
|
||||||
|
|
||||||
|
// Registers NoOp as the CPU kernel for the float variant of
// `_FusedDepthwiseConv2dNative`.
// NOTE(review): presumably a placeholder so the op can be placed on CPU until
// the MKL layout pass rewrites it to `_MklFusedDepthwiseConv2dNative` --
// confirm against the layout pass.
REGISTER_KERNEL_BUILDER(
    Name("_FusedDepthwiseConv2dNative")
        .Device(DEVICE_CPU)
        .TypeConstraint<float>("T"),
    NoOp);
|
||||||
|
|
||||||
// Register templatized MKL kernels for non-fused and fused-versions of
|
// Register templatized MKL kernels for non-fused and fused-versions of
|
||||||
// QuantizedDepthwiseConv2D.
|
// QuantizedDepthwiseConv2D.
|
||||||
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
|
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
|
||||||
@ -2306,6 +2363,14 @@ REGISTER_KERNEL_BUILDER(
|
|||||||
MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
|
MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
|
||||||
true>);
|
true>);
|
||||||
|
|
||||||
|
// Registers the float CPU kernel for `_MklFusedDepthwiseConv2dNative`. The
// registration carries the MKL layout-dependent op label and is backed by
// MklFusedDepthwiseConvOp.
// NOTE(review): the meaning of the trailing `false, true, true` template
// arguments is defined by MklFusedDepthwiseConvOp's template parameter list,
// which is not visible here -- confirm before changing them.
REGISTER_KERNEL_BUILDER(Name("_MklFusedDepthwiseConv2dNative")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklLayoutDependentOpLabel),
                        MklFusedDepthwiseConvOp<CPUDevice, float, float, float,
                                                float, float, int32, false,
                                                true, true>);
|
||||||
|
|
||||||
// Register 2D operations
|
// Register 2D operations
|
||||||
#define REGISTER_MKL_CPU_2D(T) \
|
#define REGISTER_MKL_CPU_2D(T) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
@ -134,6 +134,7 @@ class CommonTestUtilities : public OpsTestBase {
|
|||||||
static void VerifyFusedTensorsClose(int depth, int image_width,
|
static void VerifyFusedTensorsClose(int depth, int image_width,
|
||||||
int image_height, int image_batch_count,
|
int image_height, int image_batch_count,
|
||||||
int filter_size, int filter_count,
|
int filter_size, int filter_count,
|
||||||
|
int bias_size,
|
||||||
const std::vector<string>& fused_ops,
|
const std::vector<string>& fused_ops,
|
||||||
const FusedGraphRunner& run_default,
|
const FusedGraphRunner& run_default,
|
||||||
const FusedGraphRunner& run_fused) {
|
const FusedGraphRunner& run_fused) {
|
||||||
@ -145,7 +146,6 @@ class CommonTestUtilities : public OpsTestBase {
|
|||||||
Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
|
Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
|
||||||
filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();
|
filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();
|
||||||
|
|
||||||
const int bias_size = filter_count;
|
|
||||||
Tensor bias(dtype, {bias_size});
|
Tensor bias(dtype, {bias_size});
|
||||||
bias.flat<T>() = bias.flat<T>().template setRandom<random_gen_>();
|
bias.flat<T>() = bias.flat<T>().template setRandom<random_gen_>();
|
||||||
|
|
||||||
@ -321,9 +321,10 @@ class MklFusedConv2DOpTest : public OpsTestBase {
|
|||||||
out);
|
out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const int bias_size = filter_count;
|
||||||
CommonTestUtilities<T>::VerifyFusedTensorsClose(
|
CommonTestUtilities<T>::VerifyFusedTensorsClose(
|
||||||
depth, image_width, image_height, image_batch_count, filter_size,
|
depth, image_width, image_height, image_batch_count, filter_size,
|
||||||
filter_count, fused_ops, run_default, run_fused);
|
filter_count, bias_size, fused_ops, run_default, run_fused);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -449,6 +450,223 @@ REGISTER_TYPED_TEST_CASE_P(
|
|||||||
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
|
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
|
||||||
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
|
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
|
||||||
MklFusedBiasAddDataTypes);
|
MklFusedBiasAddDataTypes);
|
||||||
|
|
||||||
|
// Testing MKL's fused depthwise convolution ops
|
||||||
|
template <typename T>
|
||||||
|
class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
|
||||||
|
protected:
|
||||||
|
static constexpr int kDepth = 3;
|
||||||
|
static constexpr int kImageWidth = 32;
|
||||||
|
static constexpr int kImageHeight = 32;
|
||||||
|
static constexpr int kImageBatchCount = 8;
|
||||||
|
|
||||||
|
void RunDepthwiseConv2DUnfused(const Tensor& input_data,
|
||||||
|
const Tensor& filter_data,
|
||||||
|
const Tensor& bias_data,
|
||||||
|
const std::vector<string>& fused_ops,
|
||||||
|
Tensor* output, int stride = 1) {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
auto input_data_op =
|
||||||
|
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
|
||||||
|
Output next_op = ops::DepthwiseConv2dNative(
|
||||||
|
root.WithOpName("depthwise_conv"), input_data_op,
|
||||||
|
ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
|
||||||
|
{1, stride, stride, 1}, "SAME");
|
||||||
|
|
||||||
|
string last_op = "";
|
||||||
|
if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
|
||||||
|
fused_ops.end()) {
|
||||||
|
last_op = "with_bias";
|
||||||
|
next_op = ops::BiasAdd(
|
||||||
|
root.WithOpName(last_op), next_op,
|
||||||
|
ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (std::find(fused_ops.begin(), fused_ops.end(), "Relu") !=
|
||||||
|
fused_ops.end()) {
|
||||||
|
last_op = "with_relu";
|
||||||
|
next_op = ops::Relu(root.WithOpName(last_op), next_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (std::find(fused_ops.begin(), fused_ops.end(), "Relu6") !=
|
||||||
|
fused_ops.end()) {
|
||||||
|
last_op = "with_relu6";
|
||||||
|
next_op = ops::Relu6(root.WithOpName(last_op), next_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (std::find(fused_ops.begin(), fused_ops.end(), "Elu") !=
|
||||||
|
fused_ops.end()) {
|
||||||
|
last_op = "with_elu";
|
||||||
|
next_op = ops::Elu(root.WithOpName(last_op), next_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunMklFusedDepthwiseConv2DOp(const Tensor& image, const Tensor& filter,
|
||||||
|
const std::vector<Tensor>& args,
|
||||||
|
const std::vector<string>& fused_ops,
|
||||||
|
Tensor* output, int stride = 1) {
|
||||||
|
DataType dtype = DataTypeToEnum<T>::v();
|
||||||
|
int num_args = static_cast<int>(args.size());
|
||||||
|
|
||||||
|
TF_EXPECT_OK(NodeDefBuilder("fused_depthwise_conv_op",
|
||||||
|
"_MklFusedDepthwiseConv2dNative")
|
||||||
|
.Input(FakeInput(dtype))
|
||||||
|
.Input(FakeInput(dtype))
|
||||||
|
.Input(FakeInput(num_args, dtype))
|
||||||
|
.Input(FakeInput(DT_UINT8))
|
||||||
|
.Input(FakeInput(DT_UINT8))
|
||||||
|
.Input(FakeInput(num_args, DT_UINT8))
|
||||||
|
.Attr("T", dtype)
|
||||||
|
.Attr("num_args", num_args)
|
||||||
|
.Attr("strides", {1, stride, stride, 1})
|
||||||
|
.Attr("padding", "SAME")
|
||||||
|
.Attr("fused_ops", fused_ops)
|
||||||
|
.Attr("_kernel", "MklLayoutDependentOp")
|
||||||
|
.Finalize(node_def()));
|
||||||
|
|
||||||
|
TF_EXPECT_OK(InitOp());
|
||||||
|
|
||||||
|
AddInputFromArray<T>(image.shape(), image.flat<T>());
|
||||||
|
AddInputFromArray<T>(filter.shape(), filter.flat<T>());
|
||||||
|
for (const Tensor& arg : args)
|
||||||
|
AddInputFromArray<T>(arg.shape(), arg.flat<T>());
|
||||||
|
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||||
|
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||||
|
for (const Tensor& arg : args)
|
||||||
|
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
|
// Compare output to expected results
|
||||||
|
const Tensor& output_tensor = *GetOutput(0);
|
||||||
|
// Index 2 will need to be changed if the number of outputs produced
|
||||||
|
// by MklDepthwiseConv2D change.
|
||||||
|
const Tensor& output_meta_tensor = *GetOutput(2);
|
||||||
|
CommonTestUtilities<T> test_util;
|
||||||
|
test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
|
||||||
|
output);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verifies computing unfused ops in a graph is identical to
|
||||||
|
// FusedDepthwiseConv2D.
|
||||||
|
void VerifyFusedDepthwiseConv2D(int filter_size, int filter_count,
|
||||||
|
int bias_size,
|
||||||
|
const std::vector<string>& fused_ops,
|
||||||
|
int depth = kDepth,
|
||||||
|
int image_width = kImageWidth,
|
||||||
|
int image_height = kImageHeight,
|
||||||
|
int image_batch_count = kImageBatchCount) {
|
||||||
|
const FusedGraphRunner run_default =
|
||||||
|
[this](const Tensor& input_data, const Tensor& filter_data,
|
||||||
|
const Tensor& bias_data, const std::vector<string>& fused_ops,
|
||||||
|
Tensor* out) {
|
||||||
|
RunDepthwiseConv2DUnfused(input_data, filter_data, bias_data,
|
||||||
|
fused_ops, out);
|
||||||
|
};
|
||||||
|
|
||||||
|
const FusedGraphRunner run_fused =
|
||||||
|
[this](const Tensor& input_data, const Tensor& filter_data,
|
||||||
|
const Tensor& bias_data, const std::vector<string>& fused_ops,
|
||||||
|
Tensor* out) {
|
||||||
|
std::vector<Tensor> fused_input = {bias_data};
|
||||||
|
RunMklFusedDepthwiseConv2DOp(input_data, filter_data, fused_input,
|
||||||
|
fused_ops, out);
|
||||||
|
};
|
||||||
|
|
||||||
|
CommonTestUtilities<T>::VerifyFusedTensorsClose(
|
||||||
|
depth, image_width, image_height, image_batch_count, filter_size,
|
||||||
|
filter_count, bias_size, fused_ops, run_default, run_fused);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class MklFusedDepthwiseConv2DWithBiasOpTest
|
||||||
|
: public MklFusedDepthwiseConv2DOpTest<T> {};
|
||||||
|
|
||||||
|
TYPED_TEST_SUITE_P(MklFusedDepthwiseConv2DWithBiasOpTest);
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------- //
|
||||||
|
// DepthwiseConv2D + BiasAdd + {Activation} //
|
||||||
|
// -------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
// 1x1 depthwise convolution fused with BiasAdd.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution) {
  const std::vector<string> fused_ops = {"BiasAdd"};
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/1, /*filter_count=*/1,
                                   /*bias_size=*/3, fused_ops);
}
|
||||||
|
|
||||||
|
// 3x3 depthwise convolution fused with BiasAdd.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolution) {
  const std::vector<string> fused_ops = {"BiasAdd"};
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/3, /*filter_count=*/1,
                                   /*bias_size=*/3, fused_ops);
}
|
||||||
|
|
||||||
|
// 1x1 depthwise convolution fused with BiasAdd followed by Relu.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             OneByOneConvolutionAndRelu) {
  const std::vector<string> fused_ops = {"BiasAdd", "Relu"};
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/1, /*filter_count=*/1,
                                   /*bias_size=*/3, fused_ops);
}
|
||||||
|
|
||||||
|
// 3x3 depthwise convolution fused with BiasAdd followed by Relu.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
  const std::vector<string> fused_ops = {"BiasAdd", "Relu"};
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/3, /*filter_count=*/1,
                                   /*bias_size=*/3, fused_ops);
}
|
||||||
|
|
||||||
|
// 1x1 depthwise convolution fused with BiasAdd followed by Relu6.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             OneByOneConvolutionAndRelu6) {
  const std::vector<string> fused_ops = {"BiasAdd", "Relu6"};
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/1, /*filter_count=*/1,
                                   /*bias_size=*/3, fused_ops);
}
|
||||||
|
|
||||||
|
// 3x3 depthwise convolution fused with BiasAdd followed by Relu6.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             SpatialConvolutionAndRelu6) {
  const std::vector<string> fused_ops = {"BiasAdd", "Relu6"};
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/3, /*filter_count=*/1,
                                   /*bias_size=*/3, fused_ops);
}
|
||||||
|
|
||||||
|
// 1x1 depthwise convolution fused with BiasAdd followed by Elu.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolutionAndElu) {
  const std::vector<string> fused_ops = {"BiasAdd", "Elu"};
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/1, /*filter_count=*/1,
                                   /*bias_size=*/3, fused_ops);
}
|
||||||
|
|
||||||
|
// 3x3 depthwise convolution fused with BiasAdd followed by Elu.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndElu) {
  const std::vector<string> fused_ops = {"BiasAdd", "Elu"};
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/3, /*filter_count=*/1,
                                   /*bias_size=*/3, fused_ops);
}
|
||||||
|
|
||||||
|
REGISTER_TYPED_TEST_SUITE_P(
|
||||||
|
MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution,
|
||||||
|
SpatialConvolution, OneByOneConvolutionAndRelu, SpatialConvolutionAndRelu,
|
||||||
|
OneByOneConvolutionAndRelu6, SpatialConvolutionAndRelu6,
|
||||||
|
OneByOneConvolutionAndElu, SpatialConvolutionAndElu);
|
||||||
|
|
||||||
|
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
|
||||||
|
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
|
||||||
|
MklFusedBiasAddDataTypes);
|
||||||
|
|
||||||
// Testing fusion of pad and convolution
|
// Testing fusion of pad and convolution
|
||||||
|
|
||||||
class FusedPadConvOpTest : public OpsTestBase {
|
class FusedPadConvOpTest : public OpsTestBase {
|
||||||
|
@ -61,6 +61,30 @@ REGISTER_OP("_MklFusedConv2D")
|
|||||||
is expected to create these operators.
|
is expected to create these operators.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
// MKL variant of the fused depthwise convolution op.
//
// Takes `input`, `filter` and `num_args` extra tensors of type T, each paired
// with a uint8 `mkl_*` input, and produces `output` / `filter_output` of type
// T together with their uint8 `mkl_*` counterparts. `fused_ops` lists the ops
// fused into the convolution. Shape inference is delegated to
// DepthwiseConv2DNativeShape.
// NOTE(review): the uint8 tensors presumably carry MKL layout metadata --
// confirm against the MKL op conventions.
REGISTER_OP("_MklFusedDepthwiseConv2dNative")
    // Data inputs.
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    // uint8 companions of the data inputs, in the same order.
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_args: num_args * uint8")
    // Data outputs followed by their uint8 companions.
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
|
||||||
|
|
||||||
REGISTER_OP("_MklFusedMatMul")
|
REGISTER_OP("_MklFusedMatMul")
|
||||||
.Input("a: T")
|
.Input("a: T")
|
||||||
.Input("b: T")
|
.Input("b: T")
|
||||||
|
@ -596,6 +596,23 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Depthwise convolution with additional ops fused into it.
//
// Takes `input`, `filter` and `num_args` extra tensors of type T (for example
// the bias of a fused BiasAdd) and produces a single `output`. `fused_ops`
// lists the ops fused into the convolution. Shape inference is delegated to
// DepthwiseConv2DNativeShape.
REGISTER_OP("_FusedDepthwiseConv2dNative")
    // Inputs and output.
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    // Type and convolution attributes.
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
REGISTER_OP("Conv3D")
|
REGISTER_OP("Conv3D")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
|
Loading…
Reference in New Issue
Block a user