Merge pull request #37624 from Intel-tensorflow:yifeng/depthwise_conv2d_fusion

PiperOrigin-RevId: 307656312
Change-Id: I7098e05501111caea1c7e32329eb4e502292273e
TensorFlower Gardener 2020-04-21 12:22:33 -07:00
commit 0fa30f59f3
8 changed files with 645 additions and 8 deletions
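In short, this change teaches Grappler's remapper and the MKL layout pass to fuse DepthwiseConv2dNative + BiasAdd (optionally followed by Relu, Relu6, or Elu) into _FusedDepthwiseConv2dNative, which is then rewritten to _MklFusedDepthwiseConv2dNative and executed by a new MKL kernel. A minimal sketch of the unfused pattern, mirroring the remapper test added below (the helper function and node names here are illustrative only, not part of the change):

#include <vector>

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace ops = tensorflow::ops;

// Unfused pattern: DepthwiseConv2dNative -> BiasAdd -> Relu. With this change,
// the remapper collapses "bias_add" and "activation" into a single
// _FusedDepthwiseConv2dNative node with fused_ops = ["BiasAdd", "Relu"] when
// the nodes are placed on CPU and TensorFlow is built with INTEL_MKL.
void BuildDepthwiseFusionPattern(const tensorflow::Scope& s) {
  auto input = ops::Placeholder(s.WithOpName("input"), tensorflow::DT_FLOAT);
  auto filter = ops::Placeholder(s.WithOpName("filter"), tensorflow::DT_FLOAT);
  auto bias = ops::Placeholder(s.WithOpName("bias"), tensorflow::DT_FLOAT);
  std::vector<int> strides = {1, 1, 1, 1};
  auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"), input,
                                         filter, strides, "SAME");
  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
  auto activation = ops::Relu(s.WithOpName("activation"), bias_add);
  ops::Identity(s.WithOpName("fetch"), activation);
}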


@@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
csinfo_.fused_conv2d = "_FusedConv2D";
csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
csinfo_.fused_matmul = "_FusedMatMul";
csinfo_.identity = "Identity";
csinfo_.leakyrelu = "LeakyRelu";
@@ -295,6 +296,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_depthwise_conv2d_grad_filter =
"_MklDepthwiseConv2dNativeBackpropFilter";
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
@@ -479,6 +481,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
CopyAttrsFusedConv2D, FusedConv2DRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D,
FusedDepthwiseConv2DRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
@@ -925,6 +931,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string fused_batch_norm_v3;
string fused_batch_norm_grad_v3;
string fused_conv2d;
string fused_depthwise_conv2d;
string fused_matmul;
string identity;
string leakyrelu;
@@ -945,6 +952,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string mkl_depthwise_conv2d_grad_input;
string mkl_depthwise_conv2d_grad_filter;
string mkl_fused_conv2d;
string mkl_fused_depthwise_conv2d;
string mkl_fused_matmul;
string mkl_pad_with_conv2d;
string mkl_pad_with_fused_conv2d;
@@ -1675,6 +1683,25 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"});
}
static bool FusedDepthwiseConv2DRewrite(const Node* n) {
// MKL DNN currently doesn't support all fusions that grappler fuses
// together with DepthwiseConv2D (ex. batchnorm). We rewrite
// _FusedDepthwiseConv2DNative only if it includes those we support.
DataType T;
if (!TryGetNodeAttr(n->def(), "T", &T) ||
!mkl_op_registry::IsMklLayoutDependentOp(
csinfo_.mkl_fused_depthwise_conv2d, T)) {
return false;
}
std::vector<string> fused_ops;
TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
return (fused_ops == std::vector<string>{"BiasAdd"} ||
fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
fused_ops == std::vector<string>{"BiasAdd", "Elu"});
}
// Rewrites input node to a new node specified by its matching rewrite info.
//
// Method first searches matching rewrite info for input node and then
@@ -3703,6 +3730,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
n->type_string() != csinfo_.pad_with_fused_conv2d &&
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
n->type_string() != csinfo_.fused_conv2d &&
n->type_string() != csinfo_.fused_depthwise_conv2d &&
n->type_string() != csinfo_.fused_matmul &&
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
T)) {


@@ -1789,6 +1789,56 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive6);
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative Op fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: " FUSED_OPS " } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_MklFusedDepthwiseConv2dNative);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
// BiasAdd fusion
#define FUSED_OPS "{s: 'BiasAdd'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
// BiasAdd + Relu fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive2);
// BiasAdd + Relu6 fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu6'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
// BiasAdd + Elu fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Elu'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
#undef FUSED_OPS
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with unsupported fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
@@ -1818,6 +1868,36 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative with unsupported fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_FusedDepthwiseConv2dNative);" \
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Negative1);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with unsupported type
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
@@ -1847,6 +1927,37 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
REGISTER_TEST(NodeRewrite_FusedConv2D_Negative2, DT_DOUBLE, DoubleInput);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNativeOp with unsupported type
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type:" #T "} }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T "} }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_FusedDepthwiseConv2dNative);" \
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
}
REGISTER_TEST(NodeRewrite_FusedDepthwiseConv2dNative_Negative2,
DT_DOUBLE, DoubleInput);
#undef REGISTER_TEST
// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
@@ -4240,6 +4351,33 @@ TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Positive) {
"_MklFusedConv2D"));
}
// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is a constant.
TEST_F(MklLayoutPassTest,
FusedDepthwiseConv2dNativeWithBias_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklFusedDepthwiseConv2dNative"));
}
// _FusedConv2D + BiasAdd fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
InitGraph(
@@ -4262,6 +4400,28 @@ TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
"_MklFusedConv2D"));
}
// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest,
FusedDepthwiseConv2dNativeWithBias_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklFusedDepthwiseConv2dNative"));
}
// Depthwise Conv2D op where filter is a constant.
TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Positive) {
InitGraph(


@@ -206,6 +206,94 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);
#undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
#undef CREATE_CONV2DFUSION_TEST
TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
using ::tensorflow::ops::Placeholder;
for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
auto filter_shape = Placeholder::Shape({1, 1, 3, 1});
auto bias_shape = Placeholder::Shape({3});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
std::vector<int> strides = {1, 1, 1, 1};
auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
input, filter, strides, "SAME");
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
ops::Identity fetch = [&]() -> ops::Identity {
auto activate = s.WithOpName("activation");
auto fetch = s.WithOpName("fetch");
if (activation == "Relu") {
return ops::Identity(fetch, ops::Relu(activate, bias_add));
} else if (activation == "Relu6") {
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
} else if (activation == "Elu") {
return ops::Identity(fetch, ops::Elu(activate, bias_add));
}
DCHECK(activation == "None");
return ops::Identity(fetch, bias_add);
}();
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});
auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
Remapper optimizer(RewriterConfig::ON);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() != "bias_add" && node.name() != "activation") continue;
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
ASSERT_EQ(node.input_size(), 3);
EXPECT_EQ(node.input(0), "input");
EXPECT_EQ(node.input(1), "filter");
EXPECT_EQ(node.attr().at("num_args").i(), 1);
EXPECT_EQ(node.input(2), "bias");
const auto fused_ops = node.attr().at("fused_ops").list().s();
if (node.name() == "bias_add") {
ASSERT_EQ(fused_ops.size(), 1);
EXPECT_EQ(fused_ops[0], "BiasAdd");
found++;
}
if (node.name() == "activation") {
ASSERT_EQ(fused_ops.size(), 2);
EXPECT_EQ(fused_ops[0], "BiasAdd");
EXPECT_EQ(fused_ops[1], activation);
found++;
}
}
EXPECT_EQ(found, 1);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
ASSERT_EQ(tensors_expected.size(), 1);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
ASSERT_EQ(tensors.size(), 1);
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
}
} // namespace grappler
} // namespace tensorflow
#endif // INTEL_MKL


@@ -57,6 +57,7 @@ namespace {
constexpr char kFusedConv2D[] = "_FusedConv2D";
constexpr char kFusedMatMul[] = "_FusedMatMul";
constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
constexpr char kDataFormat[] = "data_format";
@@ -279,12 +280,24 @@ bool IsCpuCompatibleMatMul(const NodeDef* matmul) {
return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
}
bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
DCHECK(IsDepthwiseConv2dNative(*dw_conv2d))
<< "Expected DepthwiseConv2dNative op";
return NodeIsOnCpu(dw_conv2d) && IsCpuCompatibleDataType(dw_conv2d);
}
// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
template <typename Pattern>
bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
if (IsConv2D(node)) {
return IsCpuCompatibleConv2D(&node);
} else if (IsDepthwiseConv2dNative(node)) {
#ifdef INTEL_MKL
return IsCpuCompatibleDepthwiseConv2dNative(&node);
#else
return false;
#endif // INTEL_MKL
} else if (IsMatMul(node)) {
return IsCpuCompatibleMatMul(&node);
} else {
@@ -381,11 +394,12 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
const auto* contraction_node_view = regular_fanin_0.node_view();
const auto* contraction_node_def = contraction_node_view->node();
// Conv2D, MatMul or DepthwiseConv2D
bool is_contraction = IsConv2D(*contraction_node_def) ||
IsMatMul(*contraction_node_def) ||
IsDepthwiseConv2dNative(*contraction_node_def);
if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
HasControlFaninOrFanout(*contraction_node_view) ||
!HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
IsInPreserveSet(ctx, contraction_node_def))
@@ -902,6 +916,21 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
}
void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
NodeDef* fused_dw_conv2d) {
DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
<< "Input node must be a DepthwiseConv2dNative";
auto* attr = fused_dw_conv2d->mutable_attr();
auto& src_attr = dw_conv2d.attr();
(*attr)["T"] = src_attr.at("T");
(*attr)["strides"] = src_attr.at("strides");
(*attr)["padding"] = src_attr.at("padding");
(*attr)["dilations"] = src_attr.at("dilations");
(*attr)["data_format"] = src_attr.at("data_format");
}
void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
NodeDef* fused_batch_norm_ex) {
DCHECK(IsFusedBatchNorm(fused_batch_norm))
@@ -966,6 +995,9 @@ Status AddFusedContractionNode(RemapperContext* ctx,
if (IsConv2D(contraction)) {
fused_op.set_op(kFusedConv2D);
CopyConv2DAttributes(contraction, &fused_op);
} else if (IsDepthwiseConv2dNative(contraction)) {
fused_op.set_op(kFusedDepthwiseConv2dNative);
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
} else if (IsMatMul(contraction)) {
fused_op.set_op(kFusedMatMul);
CopyMatMulAttributes(contraction, &fused_op);
@@ -1010,6 +1042,9 @@ Status AddFusedContractionNode(
if (IsConv2D(contraction)) {
fused_op.set_op(kFusedConv2D);
CopyConv2DAttributes(contraction, &fused_op);
} else if (IsDepthwiseConv2dNative(contraction)) {
fused_op.set_op(kFusedDepthwiseConv2dNative);
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
} else if (IsMatMul(contraction)) {
fused_op.set_op(kFusedMatMul);
CopyMatMulAttributes(contraction, &fused_op);
@@ -1660,7 +1695,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
}
#endif //! INTEL_MKL
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
ContractionWithBiasAdd contract_with_bias;
if (allow_non_differentiable_rewrites &&
FindContractionWithBias(ctx, i, &contract_with_bias)) {
@@ -1669,7 +1705,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd+Activation into the
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}.
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
if (allow_non_differentiable_rewrites &&
FindContractionWithBiasAndActivation(


@@ -1363,6 +1363,58 @@ class MklFusedConvOp
virtual ~MklFusedConvOp() {}
};
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
typename Toutput, typename Ttemp_output, typename Tpadding,
bool pad_enabled, bool bias_enabled, bool is_depthwise>
class MklFusedDepthwiseConvOp
: public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
Tpadding, bias_enabled, false, is_depthwise, false> {
public:
explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context)
: MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
Tpadding, bias_enabled, false, is_depthwise, false>(context) {
// Since we came here through the registration of
// _MklFusedDepthwiseConv2dNative, get all
// information from 'fused_ops' and 'num_args'
std::vector<string> fused_ops;
OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
int num_args;
OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
OP_REQUIRES(context, !fused_ops.empty(),
errors::InvalidArgument(
"Fused DepthwiseConv2D must have at least one fused op."));
if (fused_ops == std::vector<string>{"BiasAdd"}) {
this->set_fuse_biasadd(true);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
} else {
OP_REQUIRES(context, false,
errors::Unimplemented("Fusion is not implemented: [",
absl::StrJoin(fused_ops, ","), "]"));
}
OP_REQUIRES(
context, num_args == 1,
errors::InvalidArgument(
"Fused DepthwiseConv2D must have one extra argument: bias."));
if (pad_enabled) {
this->set_fuse_pad(true);
}
}
virtual ~MklFusedDepthwiseConvOp() {}
};
// We create new class for each version of Quantized Convolution and inherit
// from the FP32 version of the base class
template <typename Device, typename Tinput, typename Tbias, typename Toutput,
@@ -2253,6 +2305,11 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<quint8>("out_type"),
NoOp);
REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
NoOp);
// Register templatized MKL kernels for non-fused and fused-versions of
// QuantizedDepthwiseConv2D.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
@@ -2306,6 +2363,14 @@ REGISTER_KERNEL_BUILDER(
MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
true>);
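// Template arguments in the registration below: Tinput/Tfilter/Tbias/Toutput/
// Ttemp_output = float, Tpadding = int32, pad_enabled = false,
// bias_enabled = true, is_depthwise = true (see MklFusedDepthwiseConvOp above).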
REGISTER_KERNEL_BUILDER(
Name("_MklFusedDepthwiseConv2dNative")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T")
.Label(mkl_op_registry::kMklLayoutDependentOpLabel),
MklFusedDepthwiseConvOp<CPUDevice, float, float, float, float, float, int32,
false, true, true>);
// Register 2D operations
#define REGISTER_MKL_CPU_2D(T) \
REGISTER_KERNEL_BUILDER( \


@@ -134,6 +134,7 @@ class CommonTestUtilities : public OpsTestBase {
static void VerifyFusedTensorsClose(int depth, int image_width,
int image_height, int image_batch_count,
int filter_size, int filter_count,
int bias_size,
const std::vector<string>& fused_ops,
const FusedGraphRunner& run_default,
const FusedGraphRunner& run_fused) {
@@ -145,7 +146,6 @@ class CommonTestUtilities : public OpsTestBase {
Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();
const int bias_size = filter_count;
Tensor bias(dtype, {bias_size});
bias.flat<T>() = bias.flat<T>().template setRandom<random_gen_>();
@@ -321,9 +321,10 @@ class MklFusedConv2DOpTest : public OpsTestBase {
out);
};
const int bias_size = filter_count;
CommonTestUtilities<T>::VerifyFusedTensorsClose(
depth, image_width, image_height, image_batch_count, filter_size,
filter_count, bias_size, fused_ops, run_default, run_fused);
}
};
@@ -449,6 +450,223 @@ REGISTER_TYPED_TEST_CASE_P(
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
MklFusedBiasAddDataTypes);
// Testing MKL's fused depthwise convolution ops
template <typename T>
class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
protected:
static constexpr int kDepth = 3;
static constexpr int kImageWidth = 32;
static constexpr int kImageHeight = 32;
static constexpr int kImageBatchCount = 8;
void RunDepthwiseConv2DUnfused(const Tensor& input_data,
const Tensor& filter_data,
const Tensor& bias_data,
const std::vector<string>& fused_ops,
Tensor* output, int stride = 1) {
auto root = tensorflow::Scope::NewRootScope();
auto input_data_op =
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
Output next_op = ops::DepthwiseConv2dNative(
root.WithOpName("depthwise_conv"), input_data_op,
ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
{1, stride, stride, 1}, "SAME");
string last_op = "";
if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
fused_ops.end()) {
last_op = "with_bias";
next_op = ops::BiasAdd(
root.WithOpName(last_op), next_op,
ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
}
if (std::find(fused_ops.begin(), fused_ops.end(), "Relu") !=
fused_ops.end()) {
last_op = "with_relu";
next_op = ops::Relu(root.WithOpName(last_op), next_op);
}
if (std::find(fused_ops.begin(), fused_ops.end(), "Relu6") !=
fused_ops.end()) {
last_op = "with_relu6";
next_op = ops::Relu6(root.WithOpName(last_op), next_op);
}
if (std::find(fused_ops.begin(), fused_ops.end(), "Elu") !=
fused_ops.end()) {
last_op = "with_elu";
next_op = ops::Elu(root.WithOpName(last_op), next_op);
}
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
}
void RunMklFusedDepthwiseConv2DOp(const Tensor& image, const Tensor& filter,
const std::vector<Tensor>& args,
const std::vector<string>& fused_ops,
Tensor* output, int stride = 1) {
DataType dtype = DataTypeToEnum<T>::v();
int num_args = static_cast<int>(args.size());
TF_EXPECT_OK(NodeDefBuilder("fused_depthwise_conv_op",
"_MklFusedDepthwiseConv2dNative")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(num_args, DT_UINT8))
.Attr("T", dtype)
.Attr("num_args", num_args)
.Attr("strides", {1, stride, stride, 1})
.Attr("padding", "SAME")
.Attr("fused_ops", fused_ops)
.Attr("_kernel", "MklLayoutDependentOp")
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(image.shape(), image.flat<T>());
AddInputFromArray<T>(filter.shape(), filter.flat<T>());
for (const Tensor& arg : args)
AddInputFromArray<T>(arg.shape(), arg.flat<T>());
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
for (const Tensor& arg : args)
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
// Compare output to expected results
const Tensor& output_tensor = *GetOutput(0);
// Index 2 will need to be changed if the number of outputs produced
// by MklDepthwiseConv2D change.
const Tensor& output_meta_tensor = *GetOutput(2);
CommonTestUtilities<T> test_util;
test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
output);
}
// Verifies computing unfused ops in a graph is identical to
// FusedDepthwiseConv2D.
void VerifyFusedDepthwiseConv2D(int filter_size, int filter_count,
int bias_size,
const std::vector<string>& fused_ops,
int depth = kDepth,
int image_width = kImageWidth,
int image_height = kImageHeight,
int image_batch_count = kImageBatchCount) {
const FusedGraphRunner run_default =
[this](const Tensor& input_data, const Tensor& filter_data,
const Tensor& bias_data, const std::vector<string>& fused_ops,
Tensor* out) {
RunDepthwiseConv2DUnfused(input_data, filter_data, bias_data,
fused_ops, out);
};
const FusedGraphRunner run_fused =
[this](const Tensor& input_data, const Tensor& filter_data,
const Tensor& bias_data, const std::vector<string>& fused_ops,
Tensor* out) {
std::vector<Tensor> fused_input = {bias_data};
RunMklFusedDepthwiseConv2DOp(input_data, filter_data, fused_input,
fused_ops, out);
};
CommonTestUtilities<T>::VerifyFusedTensorsClose(
depth, image_width, image_height, image_batch_count, filter_size,
filter_count, bias_size, fused_ops, run_default, run_fused);
}
};
template <typename T>
class MklFusedDepthwiseConv2DWithBiasOpTest
: public MklFusedDepthwiseConv2DOpTest<T> {};
TYPED_TEST_SUITE_P(MklFusedDepthwiseConv2DWithBiasOpTest);
// -------------------------------------------------------------------------- //
// DepthwiseConv2D + BiasAdd + {Activation} //
// -------------------------------------------------------------------------- //
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution) {
const int kFilterSize = 1;
const int kFilterCount = 1;
const int kBiasSize = 3;
this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
{"BiasAdd"});
}
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolution) {
const int kFilterSize = 3;
const int kFilterCount = 1;
const int kBiasSize = 3;
this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
{"BiasAdd"});
}
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
OneByOneConvolutionAndRelu) {
const int kFilterSize = 1;
const int kFilterCount = 1;
const int kBiasSize = 3;
this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
{"BiasAdd", "Relu"});
}
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
const int kFilterSize = 3;
const int kFilterCount = 1;
const int kBiasSize = 3;
this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
{"BiasAdd", "Relu"});
}
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
OneByOneConvolutionAndRelu6) {
const int kFilterSize = 1;
const int kFilterCount = 1;
const int kBiasSize = 3;
this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
{"BiasAdd", "Relu6"});
}
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
SpatialConvolutionAndRelu6) {
const int kFilterSize = 3;
const int kFilterCount = 1;
const int kBiasSize = 3;
this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
{"BiasAdd", "Relu6"});
}
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolutionAndElu) {
const int kFilterSize = 1;
const int kFilterCount = 1;
const int kBiasSize = 3;
this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
{"BiasAdd", "Elu"});
}
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndElu) {
const int kFilterSize = 3;
const int kFilterCount = 1;
const int kBiasSize = 3;
this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
{"BiasAdd", "Elu"});
}
REGISTER_TYPED_TEST_SUITE_P(
MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution,
SpatialConvolution, OneByOneConvolutionAndRelu, SpatialConvolutionAndRelu,
OneByOneConvolutionAndRelu6, SpatialConvolutionAndRelu6,
OneByOneConvolutionAndElu, SpatialConvolutionAndElu);
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
MklFusedBiasAddDataTypes);
// Testing fusion of pad and convolution
class FusedPadConvOpTest : public OpsTestBase {


@@ -61,6 +61,30 @@ REGISTER_OP("_MklFusedConv2D")
is expected to create these operators.
)doc");
REGISTER_OP("_MklFusedDepthwiseConv2dNative")
.Input("input: T")
.Input("filter: T")
.Input("args: num_args * T")
.Input("mkl_input: uint8")
.Input("mkl_filter: uint8")
.Input("mkl_args: num_args * uint8")
.Output("output: T")
.Output("filter_output: T")
.Output("mkl_output: uint8")
.Output("mkl_filter_output: uint8")
.Attr("T: {bfloat16, float}")
.Attr("num_args: int >= 0")
.Attr("strides: list(int)")
.Attr("is_filter_const: bool = false")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1]")
.Attr("fused_ops: list(string) = []")
// Attributes for the FusedBatchNorm ------------------------------------ //
.Attr("epsilon: float = 0.0001")
// ---------------------------------------------------------------------- //
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
REGISTER_OP("_MklFusedMatMul")
.Input("a: T")
.Input("b: T")


@@ -596,6 +596,23 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
return Status::OK();
});
REGISTER_OP("_FusedDepthwiseConv2dNative")
.Input("input: T")
.Input("filter: T")
.Input("args: num_args * T")
.Output("output: T")
.Attr("T: {half, bfloat16, float, double}")
.Attr("num_args: int >= 0")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1]")
.Attr("fused_ops: list(string) = []")
// Attributes for the FusedBatchNorm ------------------------------------ //
.Attr("epsilon: float = 0.0001")
// ---------------------------------------------------------------------- //
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
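For reference, a node satisfying this registration could be assembled with NodeDefBuilder roughly as sketched below; the call chain mirrors the op signature above, while the node and input names are made up for illustration and are not part of the change:

#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"

// Builds a _FusedDepthwiseConv2dNative NodeDef with one extra arg (the bias)
// and fused_ops = ["BiasAdd", "Relu"], similar to what the remapper emits.
tensorflow::Status BuildFusedDepthwiseConvNodeDef(tensorflow::NodeDef* node_def) {
  using tensorflow::NodeDefBuilder;
  std::vector<NodeDefBuilder::NodeOut> args;
  args.emplace_back("bias", 0, tensorflow::DT_FLOAT);  // num_args == 1
  return NodeDefBuilder("fused_depthwise_conv", "_FusedDepthwiseConv2dNative")
      .Input("input", 0, tensorflow::DT_FLOAT)
      .Input("filter", 0, tensorflow::DT_FLOAT)
      .Input(args)
      .Attr("T", tensorflow::DT_FLOAT)
      .Attr("num_args", 1)
      .Attr("strides", {1, 1, 1, 1})
      .Attr("padding", "SAME")
      .Attr("data_format", "NHWC")
      .Attr("dilations", {1, 1, 1, 1})
      .Attr("fused_ops", std::vector<tensorflow::string>{"BiasAdd", "Relu"})
      .Finalize(node_def);
}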
// --------------------------------------------------------------------------
REGISTER_OP("Conv3D")
.Input("input: T")