Skip FusedBatchNorm rewrite when input is 5D tensor

Mahmoud Abuzaina 2020-10-14 18:08:00 -07:00
parent 412e03a30f
commit 646bc6ecee
5 changed files with 121 additions and 4 deletions


@@ -54,6 +54,9 @@ class MklEagerOpRewrite : public EagerOpRewrite {
  // Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter.
  static bool RewriteConv2D(EagerOperation* op);
  // Rewrite rule for FusedBatchNormV3 and FusedBatchNormGradV3.
  static bool RewriteFusedBatchNormV3(EagerOperation* op);
  // Calls op-specific rewrite function to create new MKL op.
  Status RewriteToMklOp(EagerOperation* orig_op,
                        std::unique_ptr<EagerOperation>* mkl_op);
@@ -110,9 +113,10 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
  InsertMKLEagerOps(
      {"FusedBatchNormGradV2", AlwaysRewrite, CreateGenericMklOp});
  InsertMKLEagerOps(
-     {"FusedBatchNormGradV3", AlwaysRewrite, CreateGenericMklOp});
+     {"FusedBatchNormGradV3", RewriteFusedBatchNormV3, CreateGenericMklOp});
  InsertMKLEagerOps({"FusedBatchNormV2", AlwaysRewrite, CreateGenericMklOp});
- InsertMKLEagerOps({"FusedBatchNormV3", AlwaysRewrite, CreateGenericMklOp});
+ InsertMKLEagerOps(
+     {"FusedBatchNormV3", RewriteFusedBatchNormV3, CreateGenericMklOp});
  InsertMKLEagerOps({"MatMul", AlwaysRewrite, CreateGenericMklOp});
};
@@ -246,5 +250,15 @@ bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) {
  return (padding != "EXPLICIT");
}

bool MklEagerOpRewrite::RewriteFusedBatchNormV3(EagerOperation* op) {
  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
  if (Check5DFormat(ndef)) {
    VLOG(1) << "Eager Op Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
            << "support 5D tensors.";
    return false;
  }
  return true;
}
} // namespace tensorflow
#endif // INTEL_MKL
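
Aside: the decision the new eager-mode predicate encodes can be mirrored in a few lines of standalone C++. The sketch below is illustrative only; it does not use the TensorFlow API, and IsFiveDFormat and ShouldRewriteFusedBatchNormV3 are hypothetical names:

// Standalone sketch mirroring MklEagerOpRewrite::RewriteFusedBatchNormV3:
// rewrite to the MKL op only when data_format is not a 5D layout.
#include <iostream>
#include <string>

// Mirrors Check5DFormat: NCDHW and NDHWC are the 5D layouts.
bool IsFiveDFormat(const std::string& data_format) {
  return data_format == "NCDHW" || data_format == "NDHWC";
}

bool ShouldRewriteFusedBatchNormV3(const std::string& data_format) {
  // Matches the skip path added by this commit: the MKL kernel does not
  // support 5D tensors yet, so fall back to the default implementation.
  return !IsFiveDFormat(data_format);
}

int main() {
  std::cout << ShouldRewriteFusedBatchNormV3("NHWC") << "\n";   // 1: rewrite
  std::cout << ShouldRewriteFusedBatchNormV3("NCDHW") << "\n";  // 0: skip
}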


@@ -130,6 +130,26 @@ REGISTER_TEST_ALL_TYPES(ConvOpsExplicitPadding_Negative);
REGISTER_TEST_ALL_TYPES(MostOps_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT)                                 \
  TEST_F(EagerOpRewriteTest, NAME##_##T) {                            \
    std::vector<string> Fused_BN_ops = {"FusedBatchNormV3",           \
                                        "FusedBatchNormGradV3"};      \
    for (int i = 0; i < Fused_BN_ops.size(); ++i) {                   \
      auto orig_op = CreateOp(Fused_BN_ops[i]);                       \
      orig_op->MutableAttrs()->Set("T", T);                           \
      orig_op->MutableAttrs()->Set("data_format", "" DATA_FORMAT ""); \
      CheckRewrite(orig_op.get(), Fused_BN_ops[i]);                   \
    }                                                                 \
  }
#define DATA_FORMAT "NCDHW"
REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_1);
#undef DATA_FORMAT
#define DATA_FORMAT "NDHWC"
REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_2);
#undef DATA_FORMAT
#undef REGISTER_TEST
} // namespace tensorflow
#endif // INTEL_MKL
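
Aside: the DATA_FORMAT trick above relies on C++ adjacent string-literal concatenation, so redefining the macro between REGISTER_TEST_ALL_TYPES invocations splices a different layout string into each generated test. A minimal standalone sketch of the mechanism (kFirstFormat and kSecondFormat are illustrative names):

// "" DATA_FORMAT "" concatenates into one string literal at compile time,
// and fails to compile if DATA_FORMAT is not itself a string literal.
#include <cstdio>

#define DATA_FORMAT "NCDHW"
static const char* kFirstFormat = "" DATA_FORMAT "";  // "NCDHW"
#undef DATA_FORMAT

#define DATA_FORMAT "NDHWC"
static const char* kSecondFormat = "" DATA_FORMAT "";  // "NDHWC"
#undef DATA_FORMAT

int main() { std::printf("%s %s\n", kFirstFormat, kSecondFormat); }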


@@ -475,11 +475,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_v3,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
-        CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
+        CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad_v3,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
-        CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
+        CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
    rinfo_.push_back({csinfo_.fused_batch_norm_ex,
                      native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
@@ -1705,6 +1705,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    return do_rewrite;
  }

  static bool FusedBatchNormV3Rewrite(const Node* n) {
    DCHECK(n);
    if (Check5DFormat(n->def())) {
      VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
              << "support 5D tensors.";
      return false;
    }
    return true;
  }

  static bool FusedBatchNormExRewrite(const Node* n) {
    DCHECK(n);
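
Aside: MklLayoutRewritePass drives rewrites from the rinfo_ table, which pairs each op with a rewrite predicate; when the predicate returns false, as FusedBatchNormV3Rewrite now does for 5D layouts, the node is left unrewritten. Below is a standalone sketch of that table-plus-predicate pattern with simplified types; Node, RewriteInfo, and the "_Mkl" op name here are illustrative, not the TensorFlow definitions:

// Simplified model of the rinfo_ rewrite table: a node is renamed to its
// MKL variant only when the registered predicate approves the rewrite.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Node { std::string op; std::string data_format; };

struct RewriteInfo {
  std::string name;                          // op to match
  std::string new_name;                      // MKL replacement op
  std::function<bool(const Node&)> rewrite;  // predicate, e.g. skip 5D
};

int main() {
  auto not_5d = [](const Node& n) {
    return n.data_format != "NCDHW" && n.data_format != "NDHWC";
  };
  std::vector<RewriteInfo> rinfo = {
      {"FusedBatchNormV3", "_MklFusedBatchNormV3", not_5d},
  };
  Node n{"FusedBatchNormV3", "NCDHW"};
  for (const auto& ri : rinfo) {
    if (ri.name == n.op && ri.rewrite(n)) n.op = ri.new_name;
  }
  std::cout << n.op << "\n";  // prints FusedBatchNormV3: rewrite skipped
}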


@@ -3394,6 +3394,37 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_Negative) {
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT)                                          \
  TEST_F(MklLayoutPassTest, NAME##_##T) {                                      \
    InitGraph(                                                                 \
        "node { name: 'A' op: '" #INPUT "'}"                                   \
        "node { name: 'B' op: 'Float32Input'}"                                 \
        "node { name: 'C' op: 'Float32Input'}"                                 \
        "node { name: 'D' op: 'Float32Input'}"                                 \
        "node { name: 'E' op: 'Float32Input'}"                                 \
        "node { name: 'F' op: 'FusedBatchNormV3'"                              \
        " attr { key: 'T' value { type: " #T " } }"                            \
        " attr { key: 'U' value { type: DT_FLOAT } }"                          \
        " attr { key: 'data_format' value { s: " DATA_FORMAT " } }"            \
        " attr { key: 'epsilon' value { f: 0.0001 } }"                         \
        " attr { key: 'is_training' value { b: true } }"                       \
        " input: ['A', 'B', 'C', 'D', 'E'] }"                                  \
        "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
        " input: ['A', 'F'] }");                                               \
    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                   \
              "A(" #INPUT ");B(Float32Input);C(Float32Input);"                 \
              "D(Float32Input);E(Float32Input);F(FusedBatchNormV3);G(Zeta)"    \
              "|A->F;A->G;B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");                \
  }
#define DATA_FORMAT "'NCDHW'"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_1);
#undef DATA_FORMAT
#define DATA_FORMAT "'NDHWC'"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_2);
#undef DATA_FORMAT
#undef REGISTER_TEST
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
  DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
  InitGraph(
@@ -3417,6 +3448,38 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
            "B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}
#define REGISTER_TEST(NAME, T, INPUT)                                            \
  TEST_F(MklLayoutPassTest, NAME##_##T) {                                        \
    InitGraph(                                                                   \
        "node { name: 'A' op: '" #INPUT "'}"                                     \
        "node { name: 'B' op: '" #INPUT "'}"                                     \
        "node { name: 'C' op: 'Float32Input'}"                                   \
        "node { name: 'D' op: 'Float32Input'}"                                   \
        "node { name: 'E' op: 'Float32Input'}"                                   \
        "node { name: 'F' op: 'Float32Input'}"                                   \
        "node { name: 'G' op: 'FusedBatchNormGradV3'"                            \
        " attr { key: 'T' value { type: " #T " } }"                              \
        " attr { key: 'U' value { type: DT_FLOAT } }"                            \
        " attr { key: 'data_format' value { s: " DATA_FORMAT " } }"              \
        " attr { key: 'epsilon' value { f: 0.0001 } }"                           \
        " attr { key: 'is_training' value { b: true } }"                         \
        " input: ['A', 'B', 'C', 'D', 'E', 'F'] }"                               \
        "node { name: 'H' op: 'Zeta' attr { key: 'T' value { type: " #T " } }"   \
        " input: ['A', 'G'] }");                                                 \
    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                     \
              "A(" #INPUT ");B(" #INPUT ");C(Float32Input);D(Float32Input);"     \
              "E(Float32Input);F(Float32Input);G(FusedBatchNormGradV3);H(Zeta)"  \
              "|A->G;A->H;B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1");           \
  }
#define DATA_FORMAT "'NCDHW'"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_1);
#undef DATA_FORMAT
#define DATA_FORMAT "'NDHWC'"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2);
#undef DATA_FORMAT
#undef REGISTER_TEST
#ifdef ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \
  TEST_F(MklLayoutPassTest, NAME##_##T) { \


@@ -102,6 +102,16 @@ bool inline NativeFormatEnabled() {
  return native_fmt_enabled;
}

// Check if the data_format attribute in the node def represents a 5D tensor.
bool inline Check5DFormat(const NodeDef& ndef) {
  string data_format;
  TF_CHECK_OK(GetNodeAttr(ndef, "data_format", &data_format));
  if (data_format.compare("NCDHW") == 0 || data_format.compare("NDHWC") == 0) {
    return true;
  }
  return false;
}
namespace mkl_op_registry {
// MKL operators whose kernels are registered with 'MklLayoutDependentOp' label
// (e.g., MklConv2D) understand input tensors in MKL layout. These operators
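
Aside: a sketch of exercising Check5DFormat directly, assuming a TensorFlow source tree built with INTEL_MKL defined; the FormatIs5D helper and the include paths are illustrative, not part of this commit:

// Builds a NodeDef whose data_format attr is set on the protobuf directly,
// then asks Check5DFormat whether it names a 5D layout.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

bool FormatIs5D(const string& fmt) {
  NodeDef ndef;
  ndef.set_op("FusedBatchNormV3");
  // data_format is a string attr on the node.
  (*ndef.mutable_attr())["data_format"].set_s(fmt);
  return Check5DFormat(ndef);  // true for "NCDHW"/"NDHWC", false otherwise
}

}  // namespace tensorflow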