Skip FusedBatchNorm rewrite when input is 5D tensor
This commit is contained in:
parent
412e03a30f
commit
646bc6ecee
tensorflow/core
@ -54,6 +54,9 @@ class MklEagerOpRewrite : public EagerOpRewrite {
|
||||
// Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter.
|
||||
static bool RewriteConv2D(EagerOperation* op);
|
||||
|
||||
// Rewrite rule for FusedBatchNormV3 and FusedBatchNormGradV3
|
||||
static bool RewriteFusedBatchNormV3(EagerOperation* op);
|
||||
|
||||
// Calls op-specific rewrite function to create new MKL op.
|
||||
Status RewriteToMklOp(EagerOperation* orig_op,
|
||||
std::unique_ptr<EagerOperation>* mkl_op);
|
||||
@ -110,9 +113,10 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
|
||||
InsertMKLEagerOps(
|
||||
{"FusedBatchNormGradV2", AlwaysRewrite, CreateGenericMklOp});
|
||||
InsertMKLEagerOps(
|
||||
{"FusedBatchNormGradV3", AlwaysRewrite, CreateGenericMklOp});
|
||||
{"FusedBatchNormGradV3", RewriteFusedBatchNormV3, CreateGenericMklOp});
|
||||
InsertMKLEagerOps({"FusedBatchNormV2", AlwaysRewrite, CreateGenericMklOp});
|
||||
InsertMKLEagerOps({"FusedBatchNormV3", AlwaysRewrite, CreateGenericMklOp});
|
||||
InsertMKLEagerOps(
|
||||
{"FusedBatchNormV3", RewriteFusedBatchNormV3, CreateGenericMklOp});
|
||||
InsertMKLEagerOps({"MatMul", AlwaysRewrite, CreateGenericMklOp});
|
||||
};
|
||||
|
||||
@ -246,5 +250,15 @@ bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) {
|
||||
return (padding != "EXPLICIT");
|
||||
}
|
||||
|
||||
bool MklEagerOpRewrite::RewriteFusedBatchNormV3(EagerOperation* op) {
|
||||
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
|
||||
if (Check5DFormat(ndef)) {
|
||||
VLOG(1) << "Eager Op Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
|
||||
<< "support 5D tensors.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // INTEL_MKL
|
||||
|
@ -130,6 +130,26 @@ REGISTER_TEST_ALL_TYPES(ConvOpsExplicitPadding_Negative);
|
||||
REGISTER_TEST_ALL_TYPES(MostOps_Positive);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
|
||||
std::vector<string> Fused_BN_ops = {"FusedBatchNormV3", \
|
||||
"FusedBatchNormGradV3"}; \
|
||||
for (int i = 0; i < Fused_BN_ops.size(); ++i) { \
|
||||
auto orig_op = CreateOp(Fused_BN_ops[i]); \
|
||||
orig_op->MutableAttrs()->Set("T", T); \
|
||||
orig_op->MutableAttrs()->Set("data_format", "" DATA_FORMAT ""); \
|
||||
CheckRewrite(orig_op.get(), Fused_BN_ops[i]); \
|
||||
} \
|
||||
}
|
||||
#define DATA_FORMAT "NCDHW"
|
||||
REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_1);
|
||||
|
||||
#define DATA_FORMAT "NDHWC"
|
||||
REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_2);
|
||||
|
||||
#undef DATA_FORMAT
|
||||
#undef REGISTER_TEST
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // INTEL_MKL
|
||||
|
@ -475,11 +475,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
rinfo_.push_back(
|
||||
{csinfo_.fused_batch_norm_v3,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
|
||||
rinfo_.push_back(
|
||||
{csinfo_.fused_batch_norm_grad_v3,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back({csinfo_.fused_batch_norm_ex,
|
||||
native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
|
||||
@ -1705,6 +1705,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
return do_rewrite;
|
||||
}
|
||||
|
||||
static bool FusedBatchNormV3Rewrite(const Node* n) {
|
||||
DCHECK(n);
|
||||
if (Check5DFormat(n->def())) {
|
||||
VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
|
||||
<< "support 5D tensors.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool FusedBatchNormExRewrite(const Node* n) {
|
||||
DCHECK(n);
|
||||
|
||||
|
@ -3394,6 +3394,37 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_Negative) {
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_Positive);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
InitGraph( \
|
||||
"node { name: 'A' op: '" #INPUT "'}" \
|
||||
"node { name: 'B' op: 'Float32Input'}" \
|
||||
"node { name: 'C' op: 'Float32Input'}" \
|
||||
"node { name: 'D' op: 'Float32Input'}" \
|
||||
"node { name: 'E' op: 'Float32Input'}" \
|
||||
"node { name: 'F' op: 'FusedBatchNormV3'" \
|
||||
" attr { key: 'T' value { type: " #T " } }" \
|
||||
" attr { key: 'U' value { type: DT_FLOAT } }" \
|
||||
" attr { key: 'data_format' value { s: " DATA_FORMAT " } }" \
|
||||
" attr { key: 'epsilon' value { f: 0.0001 } }" \
|
||||
" attr { key: 'is_training' value { b: true } }" \
|
||||
" input: ['A', 'B', 'C', 'D', 'E'] }" \
|
||||
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
|
||||
" input: ['A', 'F'] }"); \
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
|
||||
"A(" #INPUT ");B(Float32Input);C(Float32Input);" \
|
||||
"D(Float32Input);E(Float32Input);F(FusedBatchNormV3);G(Zeta)" \
|
||||
"|A->F;A->G;B->F:1;C->F:2;D->F:3;E->F:4;F->G:1"); \
|
||||
}
|
||||
#define DATA_FORMAT "'NCDHW'"
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_1);
|
||||
|
||||
#define DATA_FORMAT "'NDHWC'"
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_2);
|
||||
|
||||
#undef DATA_FORMAT
|
||||
#undef REGISTER_TEST
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
|
||||
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
@ -3417,6 +3448,38 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
|
||||
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
|
||||
}
|
||||
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
InitGraph( \
|
||||
"node { name: 'A' op: '" #INPUT "'}" \
|
||||
"node { name: 'B' op: '" #INPUT "'}" \
|
||||
"node { name: 'C' op: 'Float32Input'}" \
|
||||
"node { name: 'D' op: 'Float32Input'}" \
|
||||
"node { name: 'E' op: 'Float32Input'}" \
|
||||
"node { name: 'F' op: 'Float32Input'}" \
|
||||
"node { name: 'G' op: 'FusedBatchNormGradV3'" \
|
||||
" attr { key: 'T' value { type: " #T " } }" \
|
||||
" attr { key: 'U' value { type: DT_FLOAT } }" \
|
||||
" attr { key: 'data_format' value { s: " DATA_FORMAT " } }" \
|
||||
" attr { key: 'epsilon' value { f: 0.0001 } }" \
|
||||
" attr { key: 'is_training' value { b: true } }" \
|
||||
" input: ['A', 'B', 'C', 'D', 'E', 'F'] }" \
|
||||
"node { name: 'H' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
|
||||
" input: ['A', 'G'] }"); \
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
|
||||
"A(" #INPUT ");B(" #INPUT ");C(Float32Input);D(Float32Input);" \
|
||||
"E(Float32Input);F(Float32Input);G(FusedBatchNormGradV3);H(Zeta)" \
|
||||
"|A->G;A->H;B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1"); \
|
||||
}
|
||||
#define DATA_FORMAT "'NCDHW'"
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_1);
|
||||
|
||||
#define DATA_FORMAT "'NDHWC'"
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2);
|
||||
|
||||
#undef DATA_FORMAT
|
||||
#undef REGISTER_TEST
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
|
@ -102,6 +102,16 @@ bool inline NativeFormatEnabled() {
|
||||
return native_fmt_enabled;
|
||||
}
|
||||
|
||||
// Check if the data_format attribute in the node def represents 5D tensor
|
||||
bool inline Check5DFormat(const NodeDef& ndef) {
|
||||
string data_format;
|
||||
TF_CHECK_OK(GetNodeAttr(ndef, "data_format", &data_format));
|
||||
if (data_format.compare("NCDHW") == 0 || data_format.compare("NDHWC") == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace mkl_op_registry {
|
||||
// MKL operators whose kernels are registered with 'MklLayoutDependentOp' label
|
||||
// (e.g., MklConv2D) understand input tensors in MKL layout. These operators
|
||||
|
Loading…
Reference in New Issue
Block a user