Merge test cases and fix code style
This commit is contained in:
parent 0c017ff755
commit ed7693574a
Changed paths: tensorflow/core/graph, tensorflow/core/grappler/optimizers, tensorflow/core/kernels
tensorflow/core/graph/mkl_layout_pass.cc

@@ -482,8 +482,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       CopyAttrsFusedConv2D, FusedConv2DRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
-                      csinfo_.mkl_fused_depthwise_conv2d,
-                      CopyAttrsFusedDepthwiseConv2D, FusedDepthwiseConv2DRewrite,
+                      csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D,
+                      FusedDepthwiseConv2DRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
                       CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
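Each entry above is one row of the pass's rewrite table: the op to match, the MKL op that replaces it, a callback that copies attributes onto the new node, and a predicate that gates the rewrite. Since a later hunk in this file deletes CopyAttrsFusedDepthwiseConv2D (a thin wrapper over CopyAttrsFusedConv2D), the depthwise entry now names CopyAttrsFusedConv2D directly. A rough sketch of the entry's shape, with illustrative field names rather than the exact declaration:

    #include <functional>
    #include <string>

    class Node;         // TF graph node, declared elsewhere in TensorFlow
    class NodeBuilder;  // TF helper for building the replacement node
    enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };

    // Illustrative only; the real record lives in mkl_layout_pass.cc.
    struct RewriteInfo {
      std::string name;      // op to match, e.g. csinfo_.fused_depthwise_conv2d
      std::string new_name;  // MKL replacement op
      std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
      std::function<bool(const Node*)> rewrite_rule;  // e.g. FusedDepthwiseConv2DRewrite
      RewriteCause rewrite_cause;  // e.g. kRewriteForLayoutPropagation
    };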
@@ -1683,7 +1683,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
             fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"});
   }

-  static bool FusedDepthwiseConv2DRewrite(const Node* n) {
+  static bool FusedDepthwiseConv2DRewrite(const Node* n) {
     // MKL DNN currently doesn't support all fusions that grappler fuses
     // together with DepthwiseConv2D (ex. batchnorm). We rewrite
     // _FusedDepthwiseConv2DNative only if it includes those we support.
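The hunk shows only the head of this predicate. As a hypothetical sketch of the gating the comment describes (GetNodeAttr is the real helper from node_def_util.h, but the body below is illustrative, not the actual implementation):

    static bool FusedDepthwiseConv2DRewriteSketch(const Node* n) {
      std::vector<string> fused_ops;
      if (!GetNodeAttr(n->def(), "fused_ops", &fused_ops).ok()) return false;
      // Accept only the fusions MKL DNN supports.
      return fused_ops == std::vector<string>{"BiasAdd"} ||
             fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
             fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
             fused_ops == std::vector<string>{"BiasAdd", "Elu"};
    }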
@@ -1913,8 +1913,6 @@ static bool FusedDepthwiseConv2DRewrite(const Node* n) {
                                    bool change_format = false);
   static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
                                    bool change_format = false);
-  static void CopyAttrsFusedDepthwiseConv2D(const Node* orig_node, NodeBuilder* nb,
-                                            bool change_format = false);
   static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
                                      bool change_format = false);
   static void CopyAttrsPadWithFusedConv2D(const Node* orig_node,
@@ -2874,13 +2872,6 @@ void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
   nb->Attr("epsilon", epsilon);
 }

-void MklLayoutRewritePass::CopyAttrsFusedDepthwiseConv2D(const Node* orig_node,
-                                                         NodeBuilder* nb,
-                                                         bool change_format) {
-  MklLayoutRewritePass::CopyAttrsFusedConv2D(orig_node, nb, change_format);
-}
-
 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
                                             NodeBuilder* nb,
                                             bool change_format) {
tensorflow/core/graph/mkl_layout_pass_test.cc

@@ -1593,39 +1593,6 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Dequantize_Negative_Non_SCALED_Mode) {
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive1);
 #undef REGISTER_TEST

-// Rewrite test for _FusedDepthwiseConv2dNative Op with BiasAdd fusion
-#define REGISTER_TEST(NAME, T, INPUT) \
-  TEST_F(MklLayoutPassTest, NAME##_##T) { \
-    InitGraph( \
-        "node { name: 'A' op: '" #INPUT "'}" \
-        "node { name: 'B' op: '" #INPUT "'}" \
-        "node { name: 'C' op: '" #INPUT "'}" \
-        "node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
-        " attr { key: 'T' value { type: " #T " } }" \
-        " attr { key: 'num_args' value { i: 1 } }" \
-        " attr { key: 'data_format' value { s: 'NCHW' } }" \
-        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
-        "} }" \
-        " attr { key: 'padding' value { s: 'SAME' } }" \
-        " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
-        "} }" \
-        " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
-        " attr { key: 'epsilon' value { f: 0.001 }}" \
-        " input: ['A', 'B', 'C']}" \
-        "node { name: 'E' op: 'Zeta'" \
-        "attr { key: 'T' value { type: " #T " } }" \
-        " input: ['D', 'C'] }"); \
-    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
-              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
-              "D(_MklFusedDepthwiseConv2dNative);" \
-              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
-              "A:control->DMT/_0:control;A:control->DMT/_1:control;" \
-              "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
-              "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
-  }
-REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
-#undef REGISTER_TEST
-
 // Rewrite test for _FusedConv2D Op with Relu fusion
 #define REGISTER_TEST(NAME, T, INPUT) \
   TEST_F(MklLayoutPassTest, NAME##_##T) { \
@@ -1689,39 +1656,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive2);
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive3);
 #undef REGISTER_TEST

-// Rewrite test for _FusedDepthwiseConv2dNative Op with BiasAdd+Relu fusion
-#define REGISTER_TEST(NAME, T, INPUT) \
-  TEST_F(MklLayoutPassTest, NAME##_##T) { \
-    InitGraph("node { name: 'A' op: '" #INPUT "'}" \
-              "node { name: 'B' op: '" #INPUT "'}" \
-              "node { name: 'C' op: '" #INPUT "'}" \
-              "node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
-              " attr { key: 'T' value { type: " #T " } }" \
-              " attr { key: 'num_args' value { i: 1 } }" \
-              " attr { key: 'data_format' value { s: 'NCHW' } }" \
-              " attr { key: 'strides' value { list: {i: 1, i:1, " \
-              "i:1, i:1} } }" \
-              " attr { key: 'padding' value { s: 'SAME' } }" \
-              " attr { key: 'dilations' value { list: {i: 1, i:1, " \
-              "i:1, i:1} } }" \
-              " attr { key: 'fused_ops'" \
-              " value { list: {s: 'BiasAdd', s: 'Relu'} } }" \
-              " attr { key: 'epsilon' value { f: 0.001 }}" \
-              " input: ['A', 'B', 'C']}" \
-              "node { name: 'E' op: 'Zeta'" \
-              "attr { key: 'T' value { type: " #T " } }" \
-              " input: ['D', 'C'] }"); \
-    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
-              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
-              "D(_MklFusedDepthwiseConv2dNative);" \
-              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
-              "A:control->DMT/_0:control;A:control->DMT/_1:control;" \
-              "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
-              "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
-  }
-REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
-#undef REGISTER_TEST
-
 // Rewrite test for _FusedConv2D Op with BiasAdd+Relu6 fusion
 #define REGISTER_TEST(NAME, T, INPUT) \
   TEST_F(MklLayoutPassTest, NAME##_##T) { \
@@ -1754,39 +1688,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive4);
 #undef REGISTER_TEST

-// Rewrite test for _FusedDepthwiseConv2dNative Op with BiasAdd+Relu6 fusion
-#define REGISTER_TEST(NAME, T, INPUT) \
-  TEST_F(MklLayoutPassTest, NAME##_##T) { \
-    InitGraph("node { name: 'A' op: '" #INPUT "'}" \
-              "node { name: 'B' op: '" #INPUT "'}" \
-              "node { name: 'C' op: '" #INPUT "'}" \
-              "node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
-              " attr { key: 'T' value { type: " #T " } }" \
-              " attr { key: 'num_args' value { i: 1 } }" \
-              " attr { key: 'data_format' value { s: 'NCHW' } }" \
-              " attr { key: 'strides' value { list: {i: 1, i:1, " \
-              "i:1, i:1} } }" \
-              " attr { key: 'padding' value { s: 'SAME' } }" \
-              " attr { key: 'dilations' value { list: {i: 1, i:1, " \
-              "i:1, i:1} } }" \
-              " attr { key: 'fused_ops'" \
-              " value { list: {s: 'BiasAdd', s: 'Relu6'} } }" \
-              " attr { key: 'epsilon' value { f: 0.001 }}" \
-              " input: ['A', 'B', 'C']}" \
-              "node { name: 'E' op: 'Zeta'" \
-              "attr { key: 'T' value { type: " #T " } }" \
-              " input: ['D', 'C'] }"); \
-    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
-              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
-              "D(_MklFusedDepthwiseConv2dNative);" \
-              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
-              "A:control->DMT/_0:control;A:control->DMT/_1:control;" \
-              "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
-              "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
-  }
-REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
-#undef REGISTER_TEST
-
 // Rewrite test for _FusedConv2D Op with BiasAdd+Elu fusion
 #define REGISTER_TEST(NAME, T, INPUT) \
   TEST_F(MklLayoutPassTest, NAME##_##T) { \
@@ -1819,39 +1720,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive5);
 #undef REGISTER_TEST

-// Rewrite test for _FusedDepthwiseConv2dNative Op with BiasAdd+Elu fusion
-#define REGISTER_TEST(NAME, T, INPUT) \
-  TEST_F(MklLayoutPassTest, NAME##_##T) { \
-    InitGraph("node { name: 'A' op: '" #INPUT "'}" \
-              "node { name: 'B' op: '" #INPUT "'}" \
-              "node { name: 'C' op: '" #INPUT "'}" \
-              "node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
-              " attr { key: 'T' value { type: " #T " } }" \
-              " attr { key: 'num_args' value { i: 1 } }" \
-              " attr { key: 'data_format' value { s: 'NCHW' } }" \
-              " attr { key: 'strides' value { list: {i: 1, i:1, " \
-              "i:1, i:1} } }" \
-              " attr { key: 'padding' value { s: 'SAME' } }" \
-              " attr { key: 'dilations' value { list: {i: 1, i:1, " \
-              "i:1, i:1} } }" \
-              " attr { key: 'fused_ops'" \
-              " value { list: {s: 'BiasAdd', s: 'Elu'} } }" \
-              " attr { key: 'epsilon' value { f: 0.001 }}" \
-              " input: ['A', 'B', 'C']}" \
-              "node { name: 'E' op: 'Zeta'" \
-              "attr { key: 'T' value { type: " #T " } }" \
-              " input: ['D', 'C'] }"); \
-    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
-              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
-              "D(_MklFusedDepthwiseConv2dNative);" \
-              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
-              "A:control->DMT/_0:control;A:control->DMT/_1:control;" \
-              "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
-              "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
-  }
-REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive5);
-#undef REGISTER_TEST
-
 // Rewrite test for _FusedConv2D Op with BiasAdd+Add fusion
 #define REGISTER_TEST(NAME, T, INPUT) \
   TEST_F(MklLayoutPassTest, NAME##_##T) { \
@@ -1921,6 +1789,56 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive6);
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
 #undef REGISTER_TEST

+// Rewrite test for _FusedDepthwiseConv2dNative Op fusion
+#define REGISTER_TEST(NAME, T, INPUT) \
+  TEST_F(MklLayoutPassTest, NAME##_##T) { \
+    InitGraph( \
+        "node { name: 'A' op: '" #INPUT "'}" \
+        "node { name: 'B' op: '" #INPUT "'}" \
+        "node { name: 'C' op: '" #INPUT "'}" \
+        "node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
+        " attr { key: 'T' value { type: " #T " } }" \
+        " attr { key: 'num_args' value { i: 1 } }" \
+        " attr { key: 'data_format' value { s: 'NCHW' } }" \
+        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
+        "} }" \
+        " attr { key: 'padding' value { s: 'SAME' } }" \
+        " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
+        "} }" \
+        " attr { key: 'fused_ops' value { list: " FUSED_OPS " } }" \
+        " attr { key: 'epsilon' value { f: 0.001 }}" \
+        " input: ['A', 'B', 'C']}" \
+        "node { name: 'E' op: 'Zeta'" \
+        "attr { key: 'T' value { type: " #T " } }" \
+        " input: ['D', 'C'] }"); \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
+              "D(_MklFusedDepthwiseConv2dNative);" \
+              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
+              "A:control->DMT/_0:control;A:control->DMT/_1:control;" \
+              "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
+              "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
+  }
+
+// BiasAdd fusion
+#define FUSED_OPS "{s: 'BiasAdd'}"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
+#undef FUSED_OPS
+
+// BiasAdd + Relu fusion
+#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu'}"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive2);
+#undef FUSED_OPS
+
+// BiasAdd + Relu6 fusion
+#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu6'}"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
+#undef FUSED_OPS
+
+// BiasAdd + Elu fusion
+#define FUSED_OPS "{s: 'BiasAdd', s: 'Elu'}"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
+#undef FUSED_OPS
+
+#undef REGISTER_TEST

 // Rewrite test for _FusedConv2D Op with unsupported fusion
 #define REGISTER_TEST(NAME, T, INPUT) \
   TEST_F(MklLayoutPassTest, NAME##_##T) { \
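The four deleted per-fusion tests collapse into the single macro above because REGISTER_TEST merely mentions FUSED_OPS: a macro named in a replacement list is resolved at the use site, so each REGISTER_TEST_ALL_TYPES call picks up whatever FUSED_OPS means at that point (the #undef between blocks avoids redefinition warnings). A standalone sketch of the pattern:

    #include <cstdio>

    // NAME's body mentions FUSED_OPS, which is resolved where MAKE_TEST is used.
    #define MAKE_TEST(NAME) \
      void NAME() { std::printf("%s: fused_ops = %s\n", #NAME, FUSED_OPS); }

    #define FUSED_OPS "{s: 'BiasAdd'}"
    MAKE_TEST(BiasAddOnly)
    #undef FUSED_OPS

    #define FUSED_OPS "{s: 'BiasAdd', s: 'Relu'}"
    MAKE_TEST(BiasAddRelu)
    #undef FUSED_OPS

    int main() {
      BiasAddOnly();  // prints: BiasAddOnly: fused_ops = {s: 'BiasAdd'}
      BiasAddRelu();  // prints: BiasAddRelu: fused_ops = {s: 'BiasAdd', s: 'Relu'}
      return 0;
    }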
tensorflow/core/grappler/optimizers/mkl_remapper_test.cc

@@ -173,72 +173,10 @@ TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddNRelu) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }

-TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBias) {
-  using ::tensorflow::ops::Placeholder;
-
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-  auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
-  auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 1});
-  auto bias_shape = ops::Placeholder::Shape({3});
-
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
-  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
-  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
-
-  std::vector<int> strides = {1, 1, 1, 1};
-  auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"), input, filter, strides, "SAME");
-  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);
-
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
-  auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});
-  auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});
-
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
-  // Place all nodes on CPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:CPU:0");
-  }
-
-  Remapper optimizer(RewriterConfig::ON);
-  GraphDef output;
-  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
-
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "bias_add") {
-      EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
-      ASSERT_GE(node.input_size(), 3);
-      EXPECT_EQ(node.input(0), "input");
-      EXPECT_EQ(node.input(1), "filter");
-
-      EXPECT_EQ(node.attr().at("num_args").i(), 1);
-      EXPECT_EQ(node.input(2), "bias");
-
-      const auto fused_ops = node.attr().at("fused_ops").list().s();
-      ASSERT_EQ(fused_ops.size(), 1);
-      EXPECT_EQ(fused_ops[0], "BiasAdd");
-      found++;
-    }
-  }
-  EXPECT_EQ(found, 1);
-
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-  ASSERT_EQ(tensors_expected.size(), 1);
-  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-  ASSERT_EQ(tensors.size(), 1);
-  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
-}
-
 TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
   using ::tensorflow::ops::Placeholder;

-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope();

     auto input_shape = Placeholder::Shape({8, 32, 32, 3});
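The deleted FuseDepthwiseConv2DWithBias test is folded into the activation loop via the "None" sentinel: that iteration builds a graph ending at BiasAdd, so the fused node keeps the name bias_add and fused_ops stays {"BiasAdd"}. A trivial standalone sketch of the control flow:

    #include <iostream>
    #include <string>

    int main() {
      for (const std::string& activation : {"Relu", "Relu6", "Elu", "None"}) {
        // "None" appends no activation op, so only BiasAdd is fused.
        const int expected_fused_ops = (activation == "None") ? 1 : 2;
        std::cout << activation << " -> fused_ops.size() == "
                  << expected_fused_ops << "\n";
      }
      return 0;
    }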
@@ -251,7 +189,7 @@ TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {

     std::vector<int> strides = {1, 1, 1, 1};
     auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
-                                           input, filter, strides, "SAME");
+                                           input, filter, strides, "SAME");
     auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);

     ops::Identity fetch = [&]() -> ops::Identity {
@@ -264,6 +202,8 @@ TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
         return ops::Identity(fetch, ops::Relu6(activate, bias_add));
       } else if (activation == "Elu") {
         return ops::Identity(fetch, ops::Elu(activate, bias_add));
+      } else if (activation == "None") {
+        return ops::Identity(s.WithOpName("fetch"), bias_add);
       }

       return ops::Identity(fetch, bias);
@@ -289,16 +229,23 @@ TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {

     int found = 0;
     for (const NodeDef& node : output.node()) {
-      if (node.name() != "bias_add" && node.name() != "activation") continue;
-
-      EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
-      ASSERT_EQ(node.input_size(), 3);
-      EXPECT_EQ(node.input(0), "input");
-      EXPECT_EQ(node.input(1), "filter");
-
-      EXPECT_EQ(node.attr().at("num_args").i(), 1);
-      EXPECT_EQ(node.input(2), "bias");
-
-      const auto fused_ops = node.attr().at("fused_ops").list().s();
+      if (node.name() == "bias_add") {
+        const auto fused_ops = node.attr().at("fused_ops").list().s();
+        ASSERT_EQ(fused_ops.size(), 1);
+        EXPECT_EQ(fused_ops[0], "BiasAdd");
+        found++;
+      }
+      if (node.name() == "activation") {
+        EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
+        ASSERT_GE(node.input_size(), 3);
+        EXPECT_EQ(node.input(0), "input");
+        EXPECT_EQ(node.input(1), "filter");
+
+        EXPECT_EQ(node.attr().at("num_args").i(), 1);
+        EXPECT_EQ(node.input(2), "bias");
+
+        const auto fused_ops = node.attr().at("fused_ops").list().s();
+        ASSERT_EQ(fused_ops.size(), 2);
+        EXPECT_EQ(fused_ops[0], "BiasAdd");
+        EXPECT_EQ(fused_ops[1], activation);
tensorflow/core/grappler/optimizers/remapper.cc

@@ -222,7 +222,11 @@ bool IsCpuCompatibleDataType(const NodeDef* contraction,
   if (IsConv2D(*contraction)) {
     return dtype == DT_FLOAT || dtype == DT_DOUBLE;
   } else if (IsDepthwiseConv2dNative(*contraction)) {
+#ifdef INTEL_MKL
     return dtype == DT_FLOAT;
+#else
+    return false;
+#endif  // INTEL_MKL
   } else if (IsMatMul(*contraction)) {
     return dtype == DT_FLOAT;
   } else {
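The guard means depthwise fusion is reported CPU-compatible only in MKL builds; elsewhere the branch compiles to false and the remapper leaves the DepthwiseConv2D + BiasAdd pattern alone. A minimal sketch of the same compile-time gate (DepthwiseFusionSupported is an illustrative name, not a function in remapper.cc):

    bool DepthwiseFusionSupported(DataType dtype) {
    #ifdef INTEL_MKL
      return dtype == DT_FLOAT;  // MKL builds: float only
    #else
      (void)dtype;
      return false;  // stock builds: never report compatible
    #endif  // INTEL_MKL
    }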
@@ -384,12 +388,11 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
   const auto* contraction_node_def = contraction_node_view->node();

   // Conv2D, MatMul or DepthwiseConv2D
-  bool is_required_contraction = IsConv2D(*contraction_node_def) ||
-                                 IsMatMul(*contraction_node_def) ||
-                                 IsDepthwiseConv2dNative(*contraction_node_def);
+  bool is_contraction = IsConv2D(*contraction_node_def) ||
+                        IsMatMul(*contraction_node_def) ||
+                        IsDepthwiseConv2dNative(*contraction_node_def);

-  if (!is_required_contraction ||
-      !HaveSameDataType(node_def, contraction_node_def) ||
+  if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
       HasControlFaninOrFanout(*contraction_node_view) ||
       !HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
       IsInPreserveSet(ctx, contraction_node_def))
tensorflow/core/kernels/mkl_conv_ops.cc

@@ -1387,37 +1387,26 @@ class MklFusedDepthwiseConvOp

     if (fused_ops == std::vector<string>{"BiasAdd"}) {
       this->set_fuse_biasadd(true);
-      OP_REQUIRES(
-          context, num_args == 1,
-          errors::InvalidArgument(
-              "Fused DepthwiseConv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
-      OP_REQUIRES(
-          context, num_args == 1,
-          errors::InvalidArgument(
-              "Fused DepthwiseConv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
-      OP_REQUIRES(
-          context, num_args == 1,
-          errors::InvalidArgument(
-              "Fused DepthwiseConv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
-      OP_REQUIRES(
-          context, num_args == 1,
-          errors::InvalidArgument(
-              "Fused DepthwiseConv2D must have one extra argument: bias."));
     } else {
       OP_REQUIRES(context, false,
                   errors::Unimplemented("Fusion is not implemented: [",
                                         absl::StrJoin(fused_ops, ","), "]"));
     }

+    OP_REQUIRES(
+        context, num_args == 1,
+        errors::InvalidArgument(
+            "Fused DepthwiseConv2D must have one extra argument: bias."));
+
     if (pad_enabled) {
       this->set_fuse_pad(true);
     }
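Every supported fusion takes exactly one extra input (the bias), so the num_args check is hoisted out of the branches and performed once after the dispatch. A plain-C++ sketch of the refactor's shape (status-style returns stand in for the TF OP_REQUIRES machinery; names are illustrative):

    #include <string>
    #include <vector>

    bool ConfigureFusion(const std::vector<std::string>& fused_ops, int num_args,
                         std::string* error) {
      if (fused_ops == std::vector<std::string>{"BiasAdd"}) {
        // set_fuse_biasadd(true);
      } else if (fused_ops == std::vector<std::string>{"BiasAdd", "Relu"}) {
        // set_fuse_biasadd(true); set_fuse_activation(eltwise_relu);
      } else {
        *error = "Fusion is not implemented";
        return false;
      }
      // Hoisted: shared by every supported fusion, checked exactly once.
      if (num_args != 1) {
        *error = "Fused DepthwiseConv2D must have one extra argument: bias.";
        return false;
      }
      return true;
    }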
tensorflow/core/kernels/mkl_conv_ops_test.cc

@@ -452,7 +452,6 @@ INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
                               MklFusedBiasAddDataTypes);

 // Testing MKL's fused depthwise convolution ops
-//
 template <typename T>
 class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
  protected:
@@ -571,10 +570,6 @@ class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
       const Tensor& bias_data, const std::vector<string>& fused_ops,
       Tensor* out) {
     std::vector<Tensor> fused_input = {bias_data};
-    if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
-        fused_ops.end()) {
-      fused_input.push_back(input_data);
-    }
     RunMklFusedDepthwiseConv2DOp(input_data, filter_data, fused_input,
                                  fused_ops, out);
   };
@@ -589,7 +584,7 @@ template <typename T>
 class MklFusedDepthwiseConv2DWithBiasOpTest
     : public MklFusedDepthwiseConv2DOpTest<T> {};

-TYPED_TEST_CASE_P(MklFusedDepthwiseConv2DWithBiasOpTest);
+TYPED_TEST_SUITE_P(MklFusedDepthwiseConv2DWithBiasOpTest);

 // -------------------------------------------------------------------------- //
 // DepthwiseConv2D + BiasAdd + {Activation}                                    //
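TYPED_TEST_CASE_P and its REGISTER/INSTANTIATE companions are the older googletest spellings; the _SUITE_P forms this change migrates to are the current names for the same macros. A minimal self-contained example of the new spelling:

    #include <gtest/gtest.h>

    template <typename T>
    class ArithmeticTest : public ::testing::Test {};

    TYPED_TEST_SUITE_P(ArithmeticTest);  // was: TYPED_TEST_CASE_P

    TYPED_TEST_P(ArithmeticTest, AddIsCommutative) {
      TypeParam a = 2, b = 3;
      EXPECT_EQ(a + b, b + a);
    }

    REGISTER_TYPED_TEST_SUITE_P(ArithmeticTest, AddIsCommutative);

    using MyTypes = ::testing::Types<int, float>;
    INSTANTIATE_TYPED_TEST_SUITE_P(Basic, ArithmeticTest, MyTypes);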
@@ -662,17 +657,15 @@ TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndElu) {
                                {"BiasAdd", "Elu"});
 }

-REGISTER_TYPED_TEST_CASE_P(MklFusedDepthwiseConv2DWithBiasOpTest,
-                           OneByOneConvolution, SpatialConvolution,
-                           OneByOneConvolutionAndRelu,
-                           SpatialConvolutionAndRelu,
-                           OneByOneConvolutionAndRelu6,
-                           SpatialConvolutionAndRelu6,
-                           OneByOneConvolutionAndElu, SpatialConvolutionAndElu);
+REGISTER_TYPED_TEST_SUITE_P(
+    MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution,
+    SpatialConvolution, OneByOneConvolutionAndRelu, SpatialConvolutionAndRelu,
+    OneByOneConvolutionAndRelu6, SpatialConvolutionAndRelu6,
+    OneByOneConvolutionAndElu, SpatialConvolutionAndElu);

 using MklFusedBiasAddDataTypes = ::testing::Types<float>;
-INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
-                              MklFusedBiasAddDataTypes);
+INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
+                               MklFusedBiasAddDataTypes);

 // Testing fusion of pad and convolution