Merge test cases and fix code style

This commit is contained in:
CuiYifeng 2020-04-15 22:55:44 +08:00
parent 0c017ff755
commit ed7693574a
6 changed files with 94 additions and 253 deletions

View File

@ -482,8 +482,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsFusedConv2D, FusedConv2DRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
csinfo_.mkl_fused_depthwise_conv2d,
CopyAttrsFusedDepthwiseConv2D, FusedDepthwiseConv2DRewrite,
csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D,
FusedDepthwiseConv2DRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
@ -1683,7 +1683,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"});
}
static bool FusedDepthwiseConv2DRewrite(const Node* n) {
static bool FusedDepthwiseConv2DRewrite(const Node* n) {
// MKL DNN currently doesn't support all fusions that grappler fuses
// together with DepthwiseConv2D (ex. batchnorm). We rewrite
// _FusedDepthwiseConv2DNative only if it includes those we support.
@ -1913,8 +1913,6 @@ static bool FusedDepthwiseConv2DRewrite(const Node* n) {
bool change_format = false);
static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsFusedDepthwiseConv2D(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsPadWithFusedConv2D(const Node* orig_node,
@ -2874,13 +2872,6 @@ void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
nb->Attr("epsilon", epsilon);
}
void MklLayoutRewritePass::CopyAttrsFusedDepthwiseConv2D(const Node* orig_node,
NodeBuilder* nb,
bool change_format) {
MklLayoutRewritePass::CopyAttrsFusedConv2D(orig_node, nb, change_format);
}
void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
NodeBuilder* nb,
bool change_format) {

View File

@ -1593,39 +1593,6 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Dequantize_Negative_Non_SCALED_Mode) {
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive1);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative Op with BiasAdd fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_MklFusedDepthwiseConv2dNative);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with Relu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
@ -1689,39 +1656,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive2);
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive3);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative Op with BiasAdd+Relu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Relu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_MklFusedDepthwiseConv2dNative);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+Relu6 fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
@ -1754,39 +1688,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive4);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative Op with BiasAdd+Relu6 fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Relu6'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_MklFusedDepthwiseConv2dNative);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+Elu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
@ -1819,39 +1720,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive5);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative Op with BiasAdd+Elu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Elu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_MklFusedDepthwiseConv2dNative);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive5);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+Add fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
@ -1921,6 +1789,56 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive6);
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative Op fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: " FUSED_OPS" } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_MklFusedDepthwiseConv2dNative);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
// BiasAdd fusion
#define FUSED_OPS "{s: 'BiasAdd'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
// BiasAdd + Relu fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive2);
// BiasAdd + Relu6 fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu6'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
// BiasAdd + Elu fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Elu'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
#undef FUSED_OPS
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with unsupported fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \

View File

@ -173,72 +173,10 @@ TEST_F(MklRemapperTest, FuseConv2DWithBiasAndAddNRelu) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBias) {
using ::tensorflow::ops::Placeholder;
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 1});
auto bias_shape = ops::Placeholder::Shape({3});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
std::vector<int> strides = {1, 1, 1, 1};
auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"), input, filter, strides, "SAME");
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});
auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
Remapper optimizer(RewriterConfig::ON);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "bias_add") {
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
ASSERT_GE(node.input_size(), 3);
EXPECT_EQ(node.input(0), "input");
EXPECT_EQ(node.input(1), "filter");
EXPECT_EQ(node.attr().at("num_args").i(), 1);
EXPECT_EQ(node.input(2), "bias");
const auto fused_ops = node.attr().at("fused_ops").list().s();
ASSERT_EQ(fused_ops.size(), 1);
EXPECT_EQ(fused_ops[0], "BiasAdd");
found++;
}
}
EXPECT_EQ(found, 1);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
ASSERT_EQ(tensors_expected.size(), 1);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
ASSERT_EQ(tensors.size(), 1);
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
using ::tensorflow::ops::Placeholder;
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
@ -251,7 +189,7 @@ TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
std::vector<int> strides = {1, 1, 1, 1};
auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
input, filter, strides, "SAME");
input, filter, strides, "SAME");
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
ops::Identity fetch = [&]() -> ops::Identity {
@ -264,6 +202,8 @@ TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
} else if (activation == "Elu") {
return ops::Identity(fetch, ops::Elu(activate, bias_add));
} else if (activation == "None") {
return ops::Identity(s.WithOpName("fetch"), bias_add);
}
return ops::Identity(fetch, bias);
@ -289,16 +229,23 @@ TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() != "bias_add" && node.name() != "activation") continue;
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
ASSERT_EQ(node.input_size(), 3);
EXPECT_EQ(node.input(0), "input");
EXPECT_EQ(node.input(1), "filter");
EXPECT_EQ(node.attr().at("num_args").i(), 1);
EXPECT_EQ(node.input(2), "bias");
const auto fused_ops = node.attr().at("fused_ops").list().s();
if (node.name() == "bias_add") {
ASSERT_EQ(fused_ops.size(), 1);
EXPECT_EQ(fused_ops[0], "BiasAdd");
found++;
}
if (node.name() == "activation") {
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
ASSERT_GE(node.input_size(), 3);
EXPECT_EQ(node.input(0), "input");
EXPECT_EQ(node.input(1), "filter");
EXPECT_EQ(node.attr().at("num_args").i(), 1);
EXPECT_EQ(node.input(2), "bias");
const auto fused_ops = node.attr().at("fused_ops").list().s();
ASSERT_EQ(fused_ops.size(), 2);
EXPECT_EQ(fused_ops[0], "BiasAdd");
EXPECT_EQ(fused_ops[1], activation);

View File

@ -222,7 +222,11 @@ bool IsCpuCompatibleDataType(const NodeDef* contraction,
if (IsConv2D(*contraction)) {
return dtype == DT_FLOAT || dtype == DT_DOUBLE;
} else if (IsDepthwiseConv2dNative(*contraction)) {
#ifdef INTEL_MKL
return dtype == DT_FLOAT;
#else
return false;
#endif // INTEL_MKL
} else if (IsMatMul(*contraction)) {
return dtype == DT_FLOAT;
} else {
@ -384,12 +388,11 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
const auto* contraction_node_def = contraction_node_view->node();
// Conv2D, MatMul or DepthwiseConv2D
bool is_required_contraction = IsConv2D(*contraction_node_def) ||
IsMatMul(*contraction_node_def) ||
IsDepthwiseConv2dNative(*contraction_node_def);
bool is_contraction = IsConv2D(*contraction_node_def) ||
IsMatMul(*contraction_node_def) ||
IsDepthwiseConv2dNative(*contraction_node_def);
if (!is_required_contraction ||
!HaveSameDataType(node_def, contraction_node_def) ||
if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
HasControlFaninOrFanout(*contraction_node_view) ||
!HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
IsInPreserveSet(ctx, contraction_node_def))

View File

@ -1387,37 +1387,26 @@ class MklFusedDepthwiseConvOp
if (fused_ops == std::vector<string>{"BiasAdd"}) {
this->set_fuse_biasadd(true);
OP_REQUIRES(
context, num_args == 1,
errors::InvalidArgument(
"Fused DepthwiseConv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
OP_REQUIRES(
context, num_args == 1,
errors::InvalidArgument(
"Fused DepthwiseConv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
OP_REQUIRES(
context, num_args == 1,
errors::InvalidArgument(
"Fused DepthwiseConv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
OP_REQUIRES(
context, num_args == 1,
errors::InvalidArgument(
"Fused DepthwiseConv2D must have one extra argument: bias."));
} else {
OP_REQUIRES(context, false,
errors::Unimplemented("Fusion is not implemented: [",
absl::StrJoin(fused_ops, ","), "]"));
}
OP_REQUIRES(
context, num_args == 1,
errors::InvalidArgument(
"Fused DepthwiseConv2D must have one extra argument: bias."));
if (pad_enabled) {
this->set_fuse_pad(true);
}

View File

@ -452,7 +452,6 @@ INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
MklFusedBiasAddDataTypes);
// Testing MKL's fused depthwise convolution ops
//
template <typename T>
class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
protected:
@ -571,10 +570,6 @@ class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
const Tensor& bias_data, const std::vector<string>& fused_ops,
Tensor* out) {
std::vector<Tensor> fused_input = {bias_data};
if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
fused_ops.end()) {
fused_input.push_back(input_data);
}
RunMklFusedDepthwiseConv2DOp(input_data, filter_data, fused_input,
fused_ops, out);
};
@ -589,7 +584,7 @@ template <typename T>
class MklFusedDepthwiseConv2DWithBiasOpTest
: public MklFusedDepthwiseConv2DOpTest<T> {};
TYPED_TEST_CASE_P(MklFusedDepthwiseConv2DWithBiasOpTest);
TYPED_TEST_SUITE_P(MklFusedDepthwiseConv2DWithBiasOpTest);
// -------------------------------------------------------------------------- //
// DepthwiseConv2D + BiasAdd + {Activation} //
@ -662,17 +657,15 @@ TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndElu) {
{"BiasAdd", "Elu"});
}
REGISTER_TYPED_TEST_CASE_P(MklFusedDepthwiseConv2DWithBiasOpTest,
OneByOneConvolution, SpatialConvolution,
OneByOneConvolutionAndRelu,
SpatialConvolutionAndRelu,
OneByOneConvolutionAndRelu6,
SpatialConvolutionAndRelu6,
OneByOneConvolutionAndElu, SpatialConvolutionAndElu);
REGISTER_TYPED_TEST_SUITE_P(
MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution,
SpatialConvolution, OneByOneConvolutionAndRelu, SpatialConvolutionAndRelu,
OneByOneConvolutionAndRelu6, SpatialConvolutionAndRelu6,
OneByOneConvolutionAndElu, SpatialConvolutionAndElu);
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
MklFusedBiasAddDataTypes);
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
MklFusedBiasAddDataTypes);
// Testing fusion of pad and convolution