Merge pull request #37624 from Intel-tensorflow:yifeng/depthwise_conv2d_fusion
PiperOrigin-RevId: 307656312 Change-Id: I7098e05501111caea1c7e32329eb4e502292273e
This commit is contained in:
commit
0fa30f59f3
@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
|
||||
csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
|
||||
csinfo_.fused_conv2d = "_FusedConv2D";
|
||||
csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
|
||||
csinfo_.fused_matmul = "_FusedMatMul";
|
||||
csinfo_.identity = "Identity";
|
||||
csinfo_.leakyrelu = "LeakyRelu";
|
||||
@ -295,6 +296,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
csinfo_.mkl_depthwise_conv2d_grad_filter =
|
||||
"_MklDepthwiseConv2dNativeBackpropFilter";
|
||||
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
|
||||
csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
|
||||
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
|
||||
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
|
||||
csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
|
||||
@ -479,6 +481,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
|
||||
CopyAttrsFusedConv2D, FusedConv2DRewrite,
|
||||
kRewriteForLayoutPropagation});
|
||||
rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
|
||||
csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D,
|
||||
FusedDepthwiseConv2DRewrite,
|
||||
kRewriteForLayoutPropagation});
|
||||
rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
|
||||
CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
|
||||
|
||||
@ -925,6 +931,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
string fused_batch_norm_v3;
|
||||
string fused_batch_norm_grad_v3;
|
||||
string fused_conv2d;
|
||||
string fused_depthwise_conv2d;
|
||||
string fused_matmul;
|
||||
string identity;
|
||||
string leakyrelu;
|
||||
@ -945,6 +952,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
string mkl_depthwise_conv2d_grad_input;
|
||||
string mkl_depthwise_conv2d_grad_filter;
|
||||
string mkl_fused_conv2d;
|
||||
string mkl_fused_depthwise_conv2d;
|
||||
string mkl_fused_matmul;
|
||||
string mkl_pad_with_conv2d;
|
||||
string mkl_pad_with_fused_conv2d;
|
||||
@ -1675,6 +1683,25 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"});
|
||||
}
|
||||
|
||||
static bool FusedDepthwiseConv2DRewrite(const Node* n) {
|
||||
// MKL DNN currently doesn't support all fusions that grappler fuses
|
||||
// together with DepthwiseConv2D (ex. batchnorm). We rewrite
|
||||
// _FusedDepthwiseConv2DNative only if it includes those we support.
|
||||
DataType T;
|
||||
if (!TryGetNodeAttr(n->def(), "T", &T) ||
|
||||
!mkl_op_registry::IsMklLayoutDependentOp(
|
||||
csinfo_.mkl_fused_depthwise_conv2d, T)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<string> fused_ops;
|
||||
TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
|
||||
return (fused_ops == std::vector<string>{"BiasAdd"} ||
|
||||
fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
|
||||
fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
|
||||
fused_ops == std::vector<string>{"BiasAdd", "Elu"});
|
||||
}
|
||||
|
||||
// Rewrites input node to a new node specified by its matching rewrite info.
|
||||
//
|
||||
// Method first searches matching rewrite info for input node and then
|
||||
@ -3703,6 +3730,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
|
||||
n->type_string() != csinfo_.pad_with_fused_conv2d &&
|
||||
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
|
||||
n->type_string() != csinfo_.fused_conv2d &&
|
||||
n->type_string() != csinfo_.fused_depthwise_conv2d &&
|
||||
n->type_string() != csinfo_.fused_matmul &&
|
||||
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
|
||||
T)) {
|
||||
|
@ -1789,6 +1789,56 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive6);
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
// Rewrite test for _FusedDepthwiseConv2dNative Op fusion
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
InitGraph( \
|
||||
"node { name: 'A' op: '" #INPUT "'}" \
|
||||
"node { name: 'B' op: '" #INPUT "'}" \
|
||||
"node { name: 'C' op: '" #INPUT "'}" \
|
||||
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
|
||||
" attr { key: 'T' value { type: " #T " } }" \
|
||||
" attr { key: 'num_args' value { i: 1 } }" \
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }" \
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||
"} }" \
|
||||
" attr { key: 'padding' value { s: 'SAME' } }" \
|
||||
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||
"} }" \
|
||||
" attr { key: 'fused_ops' value { list: " FUSED_OPS " } }" \
|
||||
" attr { key: 'epsilon' value { f: 0.001 }}" \
|
||||
" input: ['A', 'B', 'C']}" \
|
||||
"node { name: 'E' op: 'Zeta'" \
|
||||
"attr { key: 'T' value { type: " #T " } }" \
|
||||
" input: ['D', 'C'] }"); \
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
|
||||
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
|
||||
"D(_MklFusedDepthwiseConv2dNative);" \
|
||||
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
|
||||
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
|
||||
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
|
||||
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
|
||||
}
|
||||
|
||||
// BiasAdd fusion
|
||||
#define FUSED_OPS "{s: 'BiasAdd'}"
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
|
||||
|
||||
// BiasAdd + Relu fusion
|
||||
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu'}"
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive2);
|
||||
|
||||
// BiasAdd + Relu6 fusion
|
||||
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu6'}"
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
|
||||
|
||||
// BiasAdd + Elu fusion
|
||||
#define FUSED_OPS "{s: 'BiasAdd', s: 'Elu'}"
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
|
||||
|
||||
#undef FUSED_OPS
|
||||
#undef REGISTER_TEST
|
||||
|
||||
// Rewrite test for _FusedConv2D Op with unsupported fusion
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
@ -1818,6 +1868,36 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
// Rewrite test for _FusedDepthwiseConv2dNative with unsupported fusion
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
InitGraph( \
|
||||
"node { name: 'A' op: '" #INPUT "'}" \
|
||||
"node { name: 'B' op: '" #INPUT "'}" \
|
||||
"node { name: 'C' op: '" #INPUT "'}" \
|
||||
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
|
||||
" attr { key: 'T' value { type: " #T " } }" \
|
||||
" attr { key: 'num_args' value { i: 1 } }" \
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }" \
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||
"} }" \
|
||||
" attr { key: 'padding' value { s: 'SAME' } }" \
|
||||
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||
"} }" \
|
||||
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }" \
|
||||
" attr { key: 'epsilon' value { f: 0.001 }}" \
|
||||
" input: ['A', 'B', 'C']}" \
|
||||
"node { name: 'E' op: 'Zeta'" \
|
||||
"attr { key: 'T' value { type: " #T " } }" \
|
||||
" input: ['D', 'C'] }"); \
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
|
||||
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
|
||||
"D(_FusedDepthwiseConv2dNative);" \
|
||||
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Negative1);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
// Rewrite test for _FusedConv2D Op with unsupported type
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
@ -1847,6 +1927,37 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
|
||||
REGISTER_TEST(NodeRewrite_FusedConv2D_Negative2, DT_DOUBLE, DoubleInput);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
// Rewrite test for _FusedDepthwiseConv2dNativeOp with unsupported type
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
InitGraph( \
|
||||
"node { name: 'A' op: '" #INPUT "'}" \
|
||||
"node { name: 'B' op: '" #INPUT "'}" \
|
||||
"node { name: 'C' op: '" #INPUT "'}" \
|
||||
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
|
||||
" attr { key: 'T' value { type:" #T "} }" \
|
||||
" attr { key: 'num_args' value { i: 1 } }" \
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }" \
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||
"} }" \
|
||||
" attr { key: 'padding' value { s: 'SAME' } }" \
|
||||
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
|
||||
"} }" \
|
||||
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
|
||||
" attr { key: 'epsilon' value { f: 0.001 }}" \
|
||||
" input: ['A', 'B', 'C']}" \
|
||||
"node { name: 'E' op: 'Zeta'" \
|
||||
"attr { key: 'T' value { type: " #T "} }" \
|
||||
" input: ['D', 'C'] }"); \
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
|
||||
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
|
||||
"D(_FusedDepthwiseConv2dNative);" \
|
||||
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
|
||||
}
|
||||
REGISTER_TEST(NodeRewrite_FusedDepthwiseConv2dNative_Negative2,
|
||||
DT_DOUBLE, DoubleInput);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
@ -4240,6 +4351,33 @@ TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Positive) {
|
||||
"_MklFusedConv2D"));
|
||||
}
|
||||
|
||||
// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is a constant.
|
||||
TEST_F(MklLayoutPassTest,
|
||||
FusedDepthwiseConv2dNativeWithBias_FilterCaching_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Const'" // Filter
|
||||
" attr { key: 'dtype' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'num_args' value { i: 1 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
|
||||
" attr { key: 'epsilon' value { f: 0.001 }}"
|
||||
" input: ['A', 'B', 'C']}"
|
||||
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['D', 'C'] }");
|
||||
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
|
||||
"is_filter_const", "_MklFusedDepthwiseConv2dNative"));
|
||||
}
|
||||
|
||||
// _FusedConv2D + BiasAdd fusion where filter is NOT a constant.
|
||||
TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
|
||||
InitGraph(
|
||||
@ -4262,6 +4400,28 @@ TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
|
||||
"_MklFusedConv2D"));
|
||||
}
|
||||
|
||||
// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is NOT a constant.
|
||||
TEST_F(MklLayoutPassTest,
|
||||
FusedDepthwiseConv2dNativeWithBias_FilterCaching_Negative) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}" // Filter
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'num_args' value { i: 1 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
|
||||
" attr { key: 'epsilon' value { f: 0.001 }}"
|
||||
" input: ['A', 'B', 'C']}"
|
||||
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['D', 'C'] }");
|
||||
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
|
||||
"is_filter_const", "_MklFusedDepthwiseConv2dNative"));
|
||||
}
|
||||
// Depthwise Conv2D op where filter is a constant.
|
||||
TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Positive) {
|
||||
InitGraph(
|
||||
|
@ -206,6 +206,94 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);
|
||||
#undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
|
||||
#undef CREATE_CONV2DFUSION_TEST
|
||||
|
||||
// Checks that the remapper fuses DepthwiseConv2dNative + BiasAdd, optionally
// followed by Relu/Relu6/Elu, into a single _FusedDepthwiseConv2dNative node
// and that the optimized graph produces the same output as the original.
TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
  using ::tensorflow::ops::Placeholder;

  for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
                             Placeholder::Shape({8, 32, 32, 3}));
    auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT,
                              Placeholder::Shape({1, 1, 3, 1}));
    auto bias =
        Placeholder(s.WithOpName("bias"), DT_FLOAT, Placeholder::Shape({3}));

    const std::vector<int> strides = {1, 1, 1, 1};
    auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
                                           input, filter, strides, "SAME");
    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);

    // Append the requested activation (if any) and terminate with "fetch".
    ops::Identity fetch = [&]() -> ops::Identity {
      auto activation_scope = s.WithOpName("activation");
      auto fetch_scope = s.WithOpName("fetch");

      if (activation == "Relu") {
        return ops::Identity(fetch_scope,
                             ops::Relu(activation_scope, bias_add));
      }
      if (activation == "Relu6") {
        return ops::Identity(fetch_scope,
                             ops::Relu6(activation_scope, bias_add));
      }
      if (activation == "Elu") {
        return ops::Identity(fetch_scope, ops::Elu(activation_scope, bias_add));
      }

      DCHECK(activation == "None");
      return ops::Identity(fetch_scope, bias_add);
    }();

    auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
    auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});
    auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (NodeDef& graph_node : *item.graph.mutable_node()) {
      graph_node.set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::ON);
    GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));

    // The fused node keeps the name of the last op of the matched pattern:
    // "bias_add" without an activation, "activation" with one.
    int num_fused_nodes = 0;
    for (const NodeDef& node : output.node()) {
      const bool is_bias_add = node.name() == "bias_add";
      const bool is_activation = node.name() == "activation";
      if (!is_bias_add && !is_activation) continue;

      EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "input");
      EXPECT_EQ(node.input(1), "filter");

      EXPECT_EQ(node.attr().at("num_args").i(), 1);
      EXPECT_EQ(node.input(2), "bias");

      const auto& fused_ops = node.attr().at("fused_ops").list().s();
      if (is_bias_add) {
        ASSERT_EQ(fused_ops.size(), 1);
        EXPECT_EQ(fused_ops[0], "BiasAdd");
        ++num_fused_nodes;
      }
      if (is_activation) {
        ASSERT_EQ(fused_ops.size(), 2);
        EXPECT_EQ(fused_ops[0], "BiasAdd");
        EXPECT_EQ(fused_ops[1], activation);
        ++num_fused_nodes;
      }
    }
    EXPECT_EQ(num_fused_nodes, 1);

    // The optimized graph must compute the same values as the original one.
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_expected.size(), 1);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
#endif // INTEL_MKL
|
||||
|
@ -57,6 +57,7 @@ namespace {
|
||||
|
||||
constexpr char kFusedConv2D[] = "_FusedConv2D";
|
||||
constexpr char kFusedMatMul[] = "_FusedMatMul";
|
||||
constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
|
||||
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
|
||||
|
||||
constexpr char kDataFormat[] = "data_format";
|
||||
@ -279,12 +280,24 @@ bool IsCpuCompatibleMatMul(const NodeDef* matmul) {
|
||||
return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
|
||||
}
|
||||
|
||||
bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
|
||||
DCHECK(IsDepthwiseConv2dNative(*dw_conv2d))
|
||||
<< "Expected DepthwiseConv2dNative op";
|
||||
return NodeIsOnCpu(dw_conv2d) && IsCpuCompatibleDataType(dw_conv2d);
|
||||
}
|
||||
|
||||
// Checks if we can rewrite a pattern to the
// `_Fused{Conv2D,DepthwiseConv2dNative,MatMul}` on CPU.
|
||||
template <typename Pattern>
|
||||
bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
|
||||
const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
|
||||
if (IsConv2D(node)) {
|
||||
return IsCpuCompatibleConv2D(&node);
|
||||
} else if (IsDepthwiseConv2dNative(node)) {
|
||||
#ifdef INTEL_MKL
|
||||
return IsCpuCompatibleDepthwiseConv2dNative(&node);
|
||||
#else
|
||||
return false;
|
||||
#endif // INTEL_MKL
|
||||
} else if (IsMatMul(node)) {
|
||||
return IsCpuCompatibleMatMul(&node);
|
||||
} else {
|
||||
@ -381,11 +394,12 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
|
||||
const auto* contraction_node_view = regular_fanin_0.node_view();
|
||||
const auto* contraction_node_def = contraction_node_view->node();
|
||||
|
||||
bool is_conv2d_or_matmul =
|
||||
IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def);
|
||||
// Conv2D, MatMul or DepthwiseConv2D
|
||||
bool is_contraction = IsConv2D(*contraction_node_def) ||
|
||||
IsMatMul(*contraction_node_def) ||
|
||||
IsDepthwiseConv2dNative(*contraction_node_def);
|
||||
|
||||
if (!is_conv2d_or_matmul ||
|
||||
!HaveSameDataType(node_def, contraction_node_def) ||
|
||||
if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
|
||||
HasControlFaninOrFanout(*contraction_node_view) ||
|
||||
!HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
|
||||
IsInPreserveSet(ctx, contraction_node_def))
|
||||
@ -902,6 +916,21 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
|
||||
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
|
||||
}
|
||||
|
||||
void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
|
||||
NodeDef* fused_dw_conv2d) {
|
||||
DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
|
||||
<< "Input node must be a DepthwiseConv2dNative";
|
||||
|
||||
auto* attr = fused_dw_conv2d->mutable_attr();
|
||||
auto& src_attr = dw_conv2d.attr();
|
||||
|
||||
(*attr)["T"] = src_attr.at("T");
|
||||
(*attr)["strides"] = src_attr.at("strides");
|
||||
(*attr)["padding"] = src_attr.at("padding");
|
||||
(*attr)["dilations"] = src_attr.at("dilations");
|
||||
(*attr)["data_format"] = src_attr.at("data_format");
|
||||
}
|
||||
|
||||
void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
|
||||
NodeDef* fused_batch_norm_ex) {
|
||||
DCHECK(IsFusedBatchNorm(fused_batch_norm))
|
||||
@ -966,6 +995,9 @@ Status AddFusedContractionNode(RemapperContext* ctx,
|
||||
if (IsConv2D(contraction)) {
|
||||
fused_op.set_op(kFusedConv2D);
|
||||
CopyConv2DAttributes(contraction, &fused_op);
|
||||
} else if (IsDepthwiseConv2dNative(contraction)) {
|
||||
fused_op.set_op(kFusedDepthwiseConv2dNative);
|
||||
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
|
||||
} else if (IsMatMul(contraction)) {
|
||||
fused_op.set_op(kFusedMatMul);
|
||||
CopyMatMulAttributes(contraction, &fused_op);
|
||||
@ -1010,6 +1042,9 @@ Status AddFusedContractionNode(
|
||||
if (IsConv2D(contraction)) {
|
||||
fused_op.set_op(kFusedConv2D);
|
||||
CopyConv2DAttributes(contraction, &fused_op);
|
||||
} else if (IsDepthwiseConv2dNative(contraction)) {
|
||||
fused_op.set_op(kFusedDepthwiseConv2dNative);
|
||||
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
|
||||
} else if (IsMatMul(contraction)) {
|
||||
fused_op.set_op(kFusedMatMul);
|
||||
CopyMatMulAttributes(contraction, &fused_op);
|
||||
@ -1660,7 +1695,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
#endif //! INTEL_MKL
|
||||
|
||||
// Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
|
||||
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
|
||||
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
|
||||
ContractionWithBiasAdd contract_with_bias;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindContractionWithBias(ctx, i, &contract_with_bias)) {
|
||||
@ -1669,7 +1705,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
continue;
|
||||
}
|
||||
|
||||
// Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
|
||||
// Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd+Activation into the
|
||||
// _Fused{Conv2D,DepthwiseConv2dNative,MatMul}.
|
||||
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindContractionWithBiasAndActivation(
|
||||
|
@ -1363,6 +1363,58 @@ class MklFusedConvOp
|
||||
virtual ~MklFusedConvOp() {}
|
||||
};
|
||||
|
||||
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
|
||||
typename Toutput, typename Ttemp_output, typename Tpadding,
|
||||
bool pad_enabled, bool bias_enabled, bool is_depthwise>
|
||||
class MklFusedDepthwiseConvOp
|
||||
: public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
|
||||
Tpadding, bias_enabled, false, is_depthwise, false> {
|
||||
public:
|
||||
explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context)
|
||||
: MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
|
||||
Tpadding, bias_enabled, false, is_depthwise, false>(context) {
|
||||
// Since we came here through the registration of
|
||||
// _MklFusedDepthwiseConv2dNative, get all
|
||||
// information from 'fused_ops' and 'num_args'
|
||||
std::vector<string> fused_ops;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
|
||||
|
||||
int num_args;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
|
||||
OP_REQUIRES(context, !fused_ops.empty(),
|
||||
errors::InvalidArgument(
|
||||
"Fused DepthwiseConv2D must have at least one fused op."));
|
||||
|
||||
if (fused_ops == std::vector<string>{"BiasAdd"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
|
||||
} else {
|
||||
OP_REQUIRES(context, false,
|
||||
errors::Unimplemented("Fusion is not implemented: [",
|
||||
absl::StrJoin(fused_ops, ","), "]"));
|
||||
}
|
||||
|
||||
OP_REQUIRES(
|
||||
context, num_args == 1,
|
||||
errors::InvalidArgument(
|
||||
"Fused DepthwiseConv2D must have one extra argument: bias."));
|
||||
|
||||
if (pad_enabled) {
|
||||
this->set_fuse_pad(true);
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~MklFusedDepthwiseConvOp() {}
|
||||
};
|
||||
|
||||
// We create new class for each version of Quantized Convolution and inherit
|
||||
// from the FP32 version of the base class
|
||||
template <typename Device, typename Tinput, typename Tbias, typename Toutput,
|
||||
@ -2253,6 +2305,11 @@ REGISTER_KERNEL_BUILDER(
|
||||
.TypeConstraint<quint8>("out_type"),
|
||||
NoOp);
|
||||
|
||||
// Registers a NoOp kernel for _FusedDepthwiseConv2dNative on CPU for float.
// NOTE(review): presumably this only makes the op name resolvable until the
// MKL layout pass rewrites the node to _MklFusedDepthwiseConv2dNative --
// confirm against the other NoOp registrations in this file.
REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<float>("T"),
|
||||
NoOp);
|
||||
|
||||
// Register templatized MKL kernels for non-fused and fused-versions of
|
||||
// QuantizedDepthwiseConv2D.
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
|
||||
@ -2306,6 +2363,14 @@ REGISTER_KERNEL_BUILDER(
|
||||
MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
|
||||
true>);
|
||||
|
||||
// Registers the float CPU kernel for _MklFusedDepthwiseConv2dNative under the
// MKL layout-dependent op label. The trailing template flags are, in order,
// pad_enabled = false, bias_enabled = true, is_depthwise = true.
REGISTER_KERNEL_BUILDER(
|
||||
Name("_MklFusedDepthwiseConv2dNative")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel),
|
||||
MklFusedDepthwiseConvOp<CPUDevice, float, float, float, float, float, int32,
|
||||
false, true, true>);
|
||||
|
||||
// Register 2D operations
|
||||
#define REGISTER_MKL_CPU_2D(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
|
@ -134,6 +134,7 @@ class CommonTestUtilities : public OpsTestBase {
|
||||
static void VerifyFusedTensorsClose(int depth, int image_width,
|
||||
int image_height, int image_batch_count,
|
||||
int filter_size, int filter_count,
|
||||
int bias_size,
|
||||
const std::vector<string>& fused_ops,
|
||||
const FusedGraphRunner& run_default,
|
||||
const FusedGraphRunner& run_fused) {
|
||||
@ -145,7 +146,6 @@ class CommonTestUtilities : public OpsTestBase {
|
||||
Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
|
||||
filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();
|
||||
|
||||
const int bias_size = filter_count;
|
||||
Tensor bias(dtype, {bias_size});
|
||||
bias.flat<T>() = bias.flat<T>().template setRandom<random_gen_>();
|
||||
|
||||
@ -321,9 +321,10 @@ class MklFusedConv2DOpTest : public OpsTestBase {
|
||||
out);
|
||||
};
|
||||
|
||||
const int bias_size = filter_count;
|
||||
CommonTestUtilities<T>::VerifyFusedTensorsClose(
|
||||
depth, image_width, image_height, image_batch_count, filter_size,
|
||||
filter_count, fused_ops, run_default, run_fused);
|
||||
filter_count, bias_size, fused_ops, run_default, run_fused);
|
||||
}
|
||||
};
|
||||
|
||||
@ -449,6 +450,223 @@ REGISTER_TYPED_TEST_CASE_P(
|
||||
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
|
||||
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
|
||||
MklFusedBiasAddDataTypes);
|
||||
|
||||
// Testing MKL's fused depthwise convolution ops
|
||||
template <typename T>
|
||||
class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
|
||||
protected:
|
||||
static constexpr int kDepth = 3;
|
||||
static constexpr int kImageWidth = 32;
|
||||
static constexpr int kImageHeight = 32;
|
||||
static constexpr int kImageBatchCount = 8;
|
||||
|
||||
void RunDepthwiseConv2DUnfused(const Tensor& input_data,
|
||||
const Tensor& filter_data,
|
||||
const Tensor& bias_data,
|
||||
const std::vector<string>& fused_ops,
|
||||
Tensor* output, int stride = 1) {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
auto input_data_op =
|
||||
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
|
||||
Output next_op = ops::DepthwiseConv2dNative(
|
||||
root.WithOpName("depthwise_conv"), input_data_op,
|
||||
ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
|
||||
{1, stride, stride, 1}, "SAME");
|
||||
|
||||
string last_op = "";
|
||||
if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
|
||||
fused_ops.end()) {
|
||||
last_op = "with_bias";
|
||||
next_op = ops::BiasAdd(
|
||||
root.WithOpName(last_op), next_op,
|
||||
ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
|
||||
}
|
||||
|
||||
if (std::find(fused_ops.begin(), fused_ops.end(), "Relu") !=
|
||||
fused_ops.end()) {
|
||||
last_op = "with_relu";
|
||||
next_op = ops::Relu(root.WithOpName(last_op), next_op);
|
||||
}
|
||||
|
||||
if (std::find(fused_ops.begin(), fused_ops.end(), "Relu6") !=
|
||||
fused_ops.end()) {
|
||||
last_op = "with_relu6";
|
||||
next_op = ops::Relu6(root.WithOpName(last_op), next_op);
|
||||
}
|
||||
|
||||
if (std::find(fused_ops.begin(), fused_ops.end(), "Elu") !=
|
||||
fused_ops.end()) {
|
||||
last_op = "with_elu";
|
||||
next_op = ops::Elu(root.WithOpName(last_op), next_op);
|
||||
}
|
||||
|
||||
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
|
||||
}
|
||||
|
||||
void RunMklFusedDepthwiseConv2DOp(const Tensor& image, const Tensor& filter,
|
||||
const std::vector<Tensor>& args,
|
||||
const std::vector<string>& fused_ops,
|
||||
Tensor* output, int stride = 1) {
|
||||
DataType dtype = DataTypeToEnum<T>::v();
|
||||
int num_args = static_cast<int>(args.size());
|
||||
|
||||
TF_EXPECT_OK(NodeDefBuilder("fused_depthwise_conv_op",
|
||||
"_MklFusedDepthwiseConv2dNative")
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(num_args, dtype))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(num_args, DT_UINT8))
|
||||
.Attr("T", dtype)
|
||||
.Attr("num_args", num_args)
|
||||
.Attr("strides", {1, stride, stride, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Attr("fused_ops", fused_ops)
|
||||
.Attr("_kernel", "MklLayoutDependentOp")
|
||||
.Finalize(node_def()));
|
||||
|
||||
TF_EXPECT_OK(InitOp());
|
||||
|
||||
AddInputFromArray<T>(image.shape(), image.flat<T>());
|
||||
AddInputFromArray<T>(filter.shape(), filter.flat<T>());
|
||||
for (const Tensor& arg : args)
|
||||
AddInputFromArray<T>(arg.shape(), arg.flat<T>());
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
for (const Tensor& arg : args)
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Compare output to expected results
|
||||
const Tensor& output_tensor = *GetOutput(0);
|
||||
// Index 2 will need to be changed if the number of outputs produced
|
||||
// by MklDepthwiseConv2D change.
|
||||
const Tensor& output_meta_tensor = *GetOutput(2);
|
||||
CommonTestUtilities<T> test_util;
|
||||
test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
|
||||
output);
|
||||
}
|
||||
|
||||
// Verifies computing unfused ops in a graph is identical to
|
||||
// FusedDepthwiseConv2D.
|
||||
void VerifyFusedDepthwiseConv2D(int filter_size, int filter_count,
|
||||
int bias_size,
|
||||
const std::vector<string>& fused_ops,
|
||||
int depth = kDepth,
|
||||
int image_width = kImageWidth,
|
||||
int image_height = kImageHeight,
|
||||
int image_batch_count = kImageBatchCount) {
|
||||
const FusedGraphRunner run_default =
|
||||
[this](const Tensor& input_data, const Tensor& filter_data,
|
||||
const Tensor& bias_data, const std::vector<string>& fused_ops,
|
||||
Tensor* out) {
|
||||
RunDepthwiseConv2DUnfused(input_data, filter_data, bias_data,
|
||||
fused_ops, out);
|
||||
};
|
||||
|
||||
const FusedGraphRunner run_fused =
|
||||
[this](const Tensor& input_data, const Tensor& filter_data,
|
||||
const Tensor& bias_data, const std::vector<string>& fused_ops,
|
||||
Tensor* out) {
|
||||
std::vector<Tensor> fused_input = {bias_data};
|
||||
RunMklFusedDepthwiseConv2DOp(input_data, filter_data, fused_input,
|
||||
fused_ops, out);
|
||||
};
|
||||
|
||||
CommonTestUtilities<T>::VerifyFusedTensorsClose(
|
||||
depth, image_width, image_height, image_batch_count, filter_size,
|
||||
filter_count, bias_size, fused_ops, run_default, run_fused);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class MklFusedDepthwiseConv2DWithBiasOpTest
|
||||
: public MklFusedDepthwiseConv2DOpTest<T> {};
|
||||
|
||||
TYPED_TEST_SUITE_P(MklFusedDepthwiseConv2DWithBiasOpTest);
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
// DepthwiseConv2D + BiasAdd + {Activation} //
|
||||
// -------------------------------------------------------------------------- //
|
||||
|
||||
// Filter size 1 with only BiasAdd fused.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution) {
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/1, /*filter_count=*/1,
                                   /*bias_size=*/3, {"BiasAdd"});
}
|
||||
|
||||
// Filter size 3 with only BiasAdd fused.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolution) {
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/3, /*filter_count=*/1,
                                   /*bias_size=*/3, {"BiasAdd"});
}
|
||||
|
||||
// Filter size 1 with BiasAdd followed by Relu fused.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             OneByOneConvolutionAndRelu) {
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/1, /*filter_count=*/1,
                                   /*bias_size=*/3, {"BiasAdd", "Relu"});
}
|
||||
|
||||
// Filter size 3 with BiasAdd followed by Relu fused.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/3, /*filter_count=*/1,
                                   /*bias_size=*/3, {"BiasAdd", "Relu"});
}
|
||||
|
||||
// Filter size 1 with BiasAdd followed by Relu6 fused.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             OneByOneConvolutionAndRelu6) {
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/1, /*filter_count=*/1,
                                   /*bias_size=*/3, {"BiasAdd", "Relu6"});
}
|
||||
|
||||
// Filter size 3 with BiasAdd followed by Relu6 fused.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             SpatialConvolutionAndRelu6) {
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/3, /*filter_count=*/1,
                                   /*bias_size=*/3, {"BiasAdd", "Relu6"});
}
|
||||
|
||||
// Filter size 1 with BiasAdd followed by Elu fused.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolutionAndElu) {
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/1, /*filter_count=*/1,
                                   /*bias_size=*/3, {"BiasAdd", "Elu"});
}
|
||||
|
||||
// Filter size 3 with BiasAdd followed by Elu fused.
TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndElu) {
  this->VerifyFusedDepthwiseConv2D(/*filter_size=*/3, /*filter_count=*/1,
                                   /*bias_size=*/3, {"BiasAdd", "Elu"});
}
|
||||
|
||||
// Every TYPED_TEST_P defined for this suite must be listed here. Each row
// pairs the filter-size-1 and filter-size-3 variant of one fusion pattern.
REGISTER_TYPED_TEST_SUITE_P(MklFusedDepthwiseConv2DWithBiasOpTest,
                            OneByOneConvolution, SpatialConvolution,
                            OneByOneConvolutionAndRelu,
                            SpatialConvolutionAndRelu,
                            OneByOneConvolutionAndRelu6,
                            SpatialConvolutionAndRelu6,
                            OneByOneConvolutionAndElu,
                            SpatialConvolutionAndElu);
|
||||
|
||||
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
|
||||
MklFusedBiasAddDataTypes);
|
||||
|
||||
// Testing fusion of pad and convolution
|
||||
|
||||
class FusedPadConvOpTest : public OpsTestBase {
|
||||
|
@ -61,6 +61,30 @@ REGISTER_OP("_MklFusedConv2D")
|
||||
is expected to create these operators.
|
||||
)doc");
|
||||
|
||||
// MKL variant of the fused depthwise convolution op. Every data input
// (input, filter, args) has a uint8 `mkl_*` companion input, and each data
// output has a uint8 `mkl_*` companion output.
// NOTE(review): the companions presumably carry MKL layout metadata --
// confirm against the kernel implementation.
REGISTER_OP("_MklFusedDepthwiseConv2dNative")
    // Data inputs.
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    // uint8 companions, one per data input.
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_args: num_args * uint8")
    // Data outputs followed by their uint8 companions.
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // ---------------------------------------------------------------------- //
    // Shape inference is shared with the plain depthwise convolution.
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
|
||||
|
||||
REGISTER_OP("_MklFusedMatMul")
|
||||
.Input("a: T")
|
||||
.Input("b: T")
|
||||
|
@ -596,6 +596,23 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Depthwise convolution with a list of fused follow-up ops. `args` holds
// `num_args` extra tensors of type T, and `fused_ops` names the fused ops.
REGISTER_OP("_FusedDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // ---------------------------------------------------------------------- //
    // Shape inference is shared with the plain depthwise convolution.
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("Conv3D")
|
||||
.Input("input: T")
|
||||
|
Loading…
Reference in New Issue
Block a user