Changes to graph rewrite to support native format

Mahmoud Abuzaina, 2020-08-26 20:01:22 -07:00
parent 47afb26c2d
commit ff85b93f52
6 changed files with 305 additions and 237 deletions

View File

@@ -142,11 +142,13 @@ Status MklEagerOpRewrite::SetupNewOp(
Status MklEagerOpRewrite::CreateGenericMklOp(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op) {
const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
const string mkl_op_name =
mkl_op_registry::GetMklNativeOpName(orig_op->Name());
TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_op));
return Status::OK();
}
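The TODO below tracks retiring the remaining per-op factories in favor of this generic one. A minimal, hypothetical sketch of that direction (FakeOp and the free function are simplified stand-ins, not the real EagerOperation or mkl_op_registry API):

#include <iostream>
#include <string>

struct FakeOp { std::string name; };

// Simplified stand-in for mkl_op_registry::GetMklNativeOpName.
std::string GetMklNativeOpName(const std::string& name) {
  return "_MklNative" + name;
}

// One generic factory replaces the per-op Create*Op functions.
FakeOp CreateGenericMklOp(const FakeOp& orig) {
  return FakeOp{GetMklNativeOpName(orig.name)};
}

int main() {
  for (const auto& op : {FakeOp{"Conv2D"}, FakeOp{"Softmax"}}) {
    std::cout << op.name << " -> " << CreateGenericMklOp(op).name << "\n";
  }
}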
// TODO(mabuzain): Replace this call with the above generic one.
Status MklEagerOpRewrite::CreateMklConv2DOp(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
const string mkl_op_name =

View File

@@ -300,6 +300,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
csinfo_.mkl_native_fused_depthwise_conv2d =
"_MklNativeFusedDepthwiseConv2dNative";
csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
csinfo_.pad = "Pad";
@@ -367,257 +374,241 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.sub = "Sub";
// End - element-wise ops. See note above.
const bool native_fmt = NativeFormatEnabled();
// NOTE: names are alphabetically sorted.
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.add_v2,
mkl_op_registry::GetMklOpName(csinfo_.add_v2),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back(
{csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
{csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
rinfo_.push_back({csinfo_.avg_pool,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.avg_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.avg_pool3d, mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.avg_pool3d,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.avg_pool3d_grad,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.batch_matmul,
mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.batch_matmul_v2,
mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.concatv2, mkl_op_registry::GetMklOpName(csinfo_.concatv2),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.concat,
mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.concatv2,
mkl_op_registry::GetMklOpName(csinfo_.concatv2),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.conjugate_transpose,
mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.conv2d,
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
rinfo_.push_back(
{csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.conv2d_with_bias,
native_fmt ? csinfo_.mkl_native_conv2d_with_bias
: csinfo_.mkl_conv2d_with_bias,
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
CopyAttrsConv, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
AlwaysRewrite, kRewriteForLayoutPropagation});
AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.conv3d,
mkl_op_registry::GetMklOpName(csinfo_.conv3d),
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.conv3d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
CopyAttrsConv, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.conv3d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
CopyAttrsConv, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.depthwise_conv2d,
mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
CopyAttrsConv2DDepthwiseCheckConstFilter, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back(
{csinfo_.depthwise_conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.depthwise_conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
CopyAttrsAll, DequantizeRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.dequantize,
mkl_op_registry::GetMklOpName(csinfo_.dequantize),
CopyAttrsAll, DequantizeRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.fused_batch_norm_grad,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.fused_batch_norm_v2,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.fused_batch_norm_grad_v2,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
// Using CopyAttrsAll for V3 on CPU, as there are no additional
// attributes.
rinfo_.push_back(
{csinfo_.fused_batch_norm_v3,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.fused_batch_norm_grad_v3,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.fused_batch_norm_ex,
csinfo_.mkl_fused_batch_norm_ex, CopyAttrsAll,
FusedBatchNormExRewrite, kRewriteForLayoutPropagation});
native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
: csinfo_.mkl_fused_batch_norm_ex,
CopyAttrsAll, FusedBatchNormExRewrite,
GetRewriteCause()});
#endif
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
rinfo_.push_back({csinfo_.fused_conv2d,
native_fmt ? csinfo_.mkl_native_fused_conv2d
: csinfo_.mkl_fused_conv2d,
CopyAttrsFusedConv2D, FusedConv2DRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D,
FusedDepthwiseConv2DRewrite,
kRewriteForLayoutPropagation});
native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
: csinfo_.mkl_fused_depthwise_conv2d,
CopyAttrsFusedConv2D, FusedDepthwiseConv2DRewrite,
GetRewriteCause()});
rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
rinfo_.push_back({csinfo_.identity,
mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
CopyAttrsAll, LrnRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation});
{csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
CopyAttrsAll, LrnRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.matmul,
mkl_op_registry::GetMklOpName(csinfo_.matmul),
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.leakyrelu,
mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.leakyrelu_grad,
mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
CopyAttrsAll, LeakyReluRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.max_pool,
mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsAll, NonDepthBatchWisePoolRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsAll, MaxpoolGradRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.max_pool3d,
mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
CopyAttrsAll, NonDepthBatchWisePoolRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.max_pool3d_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.pad_with_conv2d, csinfo_.mkl_pad_with_conv2d,
GetRewriteCause()});
rinfo_.push_back({csinfo_.pad_with_conv2d,
native_fmt ? csinfo_.mkl_native_pad_with_conv2d
: csinfo_.mkl_pad_with_conv2d,
CopyAttrsPadWithConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
csinfo_.mkl_pad_with_fused_conv2d,
native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
: csinfo_.mkl_pad_with_fused_conv2d,
CopyAttrsPadWithFusedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_avg_pool,
mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_concatv2,
mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_conv2d,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quantized_conv2d_per_channel,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quantized_conv2d_with_bias,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quantized_conv2d_and_relu,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_and_relu),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_max_pool,
mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_sum_and_relu),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quantized_matmul_with_bias,
mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
@@ -643,72 +634,65 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back(
{csinfo_.quantized_depthwise_conv2d,
mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_depthwise_conv2d_with_bias),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
mkl_op_registry::GetMklOpName(
csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
mkl_op_registry::GetMklOpName(
csinfo_
.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
kRewriteForLayoutPropagation});
CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantize_v2,
mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
CopyAttrsAll, QuantizeOpRewrite,
kRewriteForLayoutPropagation});
CopyAttrsAll, QuantizeOpRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.relu_grad, mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.relu6, mkl_op_registry::GetMklOpName(csinfo_.relu6),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.relu_grad,
mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.relu6,
mkl_op_registry::GetMklOpName(csinfo_.relu6),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.relu6_grad,
mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.requantize,
mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
// Optimized TanhGrad support exists only in DNNL 1.x.
rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.tanh_grad, mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.tanh_grad,
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#endif // ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.slice,
mkl_op_registry::GetMklOpName(csinfo_.slice),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
{csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
rinfo_.push_back({csinfo_.softmax,
mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.squared_difference,
mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
kRewriteForLayoutPropagation});
GetRewriteCause()});
rinfo_.push_back({csinfo_.transpose,
mkl_op_registry::GetMklOpName(csinfo_.transpose),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
@@ -723,9 +707,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
csinfo_.conv2d_grad_filter_with_bias,
GetConv2DBackpropFilterOrBiasAddGrad});
// Merge Pad and Conv2D only if the pad op is "Pad";
// don't merge if the pad op is "PadV2" or "MirrorPad".
minfo_.push_back(
@@ -734,76 +715,82 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});
// The fusion patterns in "finfo_" that show up first will get applied
// first, for example, graph "A->B->C->D" and finfo_ is {A->B->C to ABC,
// A->B->C->D to ABCD}, since the first gets applied first, the final
// graph will be ABC->D.
if (!native_fmt) {
minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
csinfo_.conv2d_grad_filter_with_bias,
GetConv2DBackpropFilterOrBiasAddGrad});
//
// Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
// (NHWC) + Transpose (NHWC -> NCHW)" into "Conv2D (NCHW)". Such patterns
// occur frequently in Keras.
// Note: we use the term "merge" to combine (exactly) 2 nodes into one,
// while "fusion" is for 3+ nodes situation.
//
// The fusion patterns in "finfo_" that show up first will get applied
// first, for example, graph "A->B->C->D" and finfo_ is {A->B->C to ABC,
// A->B->C->D to ABCD}, since the first gets applied first, the final
// graph will be ABC->D.
// Transpose + Conv2d + Transpose:
std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
NCHW::dim::W, NCHW::dim::C};
std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
NHWC::dim::H, NHWC::dim::W};
auto CheckForTransposeToNHWC =
std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
auto CheckForConv2dOp =
std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
auto CheckForTransposeToNCHW =
std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
auto FuseConv2D =
std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, "NCHW");
finfo_.push_back(
{"transpose-elimination for Conv2D",
{CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
// CheckForMklOp
FuseConv2D,
CopyAttrsConv});
//
// Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
// (NHWC) + Transpose (NHWC -> NCHW)" into "Conv2D (NCHW)". Such patterns
// occur frequently in Keras.
// Note: we use the term "merge" to combine (exactly) 2 nodes into one,
// while "fusion" is for 3+ nodes situation.
//
// Transpose + Conv3d + Transpose:
std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
NCDHW::dim::H, NCDHW::dim::W,
NCDHW::dim::C};
std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
NDHWC::dim::D, NDHWC::dim::H,
NDHWC::dim::W};
// Transpose + Conv2d + Transpose:
std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
NCHW::dim::W, NCHW::dim::C};
std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
NHWC::dim::H, NHWC::dim::W};
auto CheckForTransposeToNHWC = std::bind(
CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
auto CheckForConv2dOp =
std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
auto CheckForTransposeToNCHW = std::bind(
CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
auto FuseConv2D =
std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, "NCHW");
finfo_.push_back(
{"transpose-elimination for Conv2D",
{CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
// CheckForMklOp
FuseConv2D,
CopyAttrsConv});
auto CheckForTransposeToNDHWC =
std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
auto CheckForConv3dOp =
std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
auto CheckForTransposeToNCDHW =
std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
auto FuseConv3D =
std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, "NCDHW");
// Transpose + Conv3d + Transpose:
std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
NCDHW::dim::H, NCDHW::dim::W,
NCDHW::dim::C};
std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
NDHWC::dim::D, NDHWC::dim::H,
NDHWC::dim::W};
finfo_.push_back(
{"transpose-elimination for Conv3D",
{CheckForTransposeToNDHWC, CheckForConv3dOp, CheckForTransposeToNCDHW},
// CheckForMklOp
FuseConv3D,
CopyAttrsConv});
auto CheckForTransposeToNDHWC = std::bind(
CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
auto CheckForConv3dOp =
std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
auto CheckForTransposeToNCDHW = std::bind(
CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
auto FuseConv3D =
std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, "NCDHW");
auto CheckForMaxPool3DOp =
std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.max_pool3d);
auto FuseMaxPool3D =
std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, "NCDHW");
finfo_.push_back({"transpose-elimination for MaxPool3D",
{CheckForTransposeToNDHWC, CheckForMaxPool3DOp,
CheckForTransposeToNCDHW},
// CheckForMklOp
FuseMaxPool3D,
CopyAttrsPooling});
finfo_.push_back({"transpose-elimination for Conv3D",
{CheckForTransposeToNDHWC, CheckForConv3dOp,
CheckForTransposeToNCDHW},
// CheckForMklOp
FuseConv3D,
CopyAttrsConv});
auto CheckForMaxPool3DOp =
std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.max_pool3d);
auto FuseMaxPool3D =
std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, "NCDHW");
finfo_.push_back({"transpose-elimination for MaxPool3D",
{CheckForTransposeToNDHWC, CheckForMaxPool3DOp,
CheckForTransposeToNCDHW},
// CheckForMklOp
FuseMaxPool3D,
CopyAttrsPooling});
}
}
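The ordering note above (the first matching finfo_ pattern wins) can be illustrated with a self-contained sketch; the types here are simplified stand-ins, not the real pass machinery:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct FusionRule {
  std::string name;
  std::function<bool(const std::vector<std::string>&)> matches;
};

int main() {
  // Rules are tried in registration order, as with finfo_.
  std::vector<FusionRule> finfo = {
      {"A->B->C => ABC",
       [](const std::vector<std::string>& g) {
         return g.size() >= 3 && g[0] == "A" && g[1] == "B" && g[2] == "C";
       }},
      {"A->B->C->D => ABCD",
       [](const std::vector<std::string>& g) { return g.size() >= 4; }},
  };

  const std::vector<std::string> graph = {"A", "B", "C", "D"};
  for (const auto& rule : finfo) {
    if (rule.matches(graph)) {
      std::cout << "applied: " << rule.name << "\n";  // ABC fires; result is ABC->D
      break;
    }
  }
}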
// Standard interface to run pass
@@ -824,6 +811,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// of ops like MatMul, Transpose, which do not support Mkl layout)
enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };
// Get the op rewrite cause depending on whether native format mode
// is enabled or not.
RewriteCause GetRewriteCause() {
if (NativeFormatEnabled()) {
return kRewriteForOpNameChange;
} else {
return kRewriteForLayoutPropagation;
}
}
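The constructor above combines this helper with a native_fmt ternary when registering rewrites. A self-contained sketch of that pattern (RewriteInfo is a cut-down stand-in for the pass's real struct, and the hard-coded NativeFormatEnabled is only for the demo):

#include <iostream>
#include <string>
#include <vector>

enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };

bool NativeFormatEnabled() { return true; }  // demo stand-in for the env-var check

RewriteCause GetRewriteCause() {
  return NativeFormatEnabled() ? kRewriteForOpNameChange
                               : kRewriteForLayoutPropagation;
}

struct RewriteInfo {
  std::string orig_name;
  std::string new_name;
  RewriteCause cause;
};

int main() {
  const bool native_fmt = NativeFormatEnabled();
  std::vector<RewriteInfo> rinfo;
  rinfo.push_back({"_FusedConv2D",
                   native_fmt ? "_MklNativeFusedConv2D" : "_MklFusedConv2D",
                   GetRewriteCause()});
  // Native mode: the rewrite is a plain op-name change, no layout propagation.
  std::cout << rinfo[0].orig_name << " -> " << rinfo[0].new_name << "\n";
}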
/// Structure to specify the name of an original node, its new name after
/// rewrite, the number of inputs to the original node, the function to
/// be used to copy attributes for the op, and the rule (if any) which
@@ -960,6 +957,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string mkl_fused_conv2d;
string mkl_fused_depthwise_conv2d;
string mkl_fused_matmul;
string mkl_native_conv2d_with_bias;
string mkl_native_fused_batch_norm_ex;
string mkl_native_fused_conv2d;
string mkl_native_fused_depthwise_conv2d;
string mkl_native_pad_with_conv2d;
string mkl_native_pad_with_fused_conv2d;
string mkl_pad_with_conv2d;
string mkl_pad_with_fused_conv2d;
string mul;
@@ -2403,8 +2406,10 @@ Status MklLayoutRewritePass::CopyInputs(
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
&new_node_inputs);
if (N != 0) {
GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
&new_node_inputs);
}
nb->Input(new_node_inputs);
} else {
nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
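A small sketch of why the added guard matters (simplified stand-ins, assuming the helper advances a shared input index as it collects): for an empty tensor list (N == 0) there is nothing to gather, so the helper is now skipped entirely.

#include <iostream>
#include <vector>

// Stand-in for GetNodesProducingTFTensorList: copies n inputs, advancing idx.
void CollectListInputs(const std::vector<int>& inputs, int* idx, int n,
                       std::vector<int>* out) {
  for (int i = 0; i < n; ++i) out->push_back(inputs[(*idx)++]);
}

int main() {
  const std::vector<int> inputs = {10, 20, 30};
  int idx = 0;
  std::vector<int> list_inputs;
  const int n = 0;  // an empty list-typed argument
  if (n != 0) {     // mirrors the added check
    CollectListInputs(inputs, &idx, n, &list_inputs);
  }
  std::cout << "collected " << list_inputs.size() << " inputs\n";  // 0
}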
@@ -3642,7 +3647,12 @@ Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
return s;
}
ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
if (!NativeFormatEnabled()) {
ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
} else {
ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
}
nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
// Finalize graph and get new node.

View File

@@ -430,6 +430,11 @@ Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
VLOG(2) << "TF-MKL: Disabling MKL";
return Status::OK();
}
if (NativeFormatEnabled()) {
VLOG(2)
<< "Running in native format mode, MklToTfConversionPass won't run.";
return Status::OK();
}
auto process_graph = [&](std::unique_ptr<Graph>* g) {
// Get the ownership of graph

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
// Since our ops are going to produce and also consume N additional tensors
@@ -87,6 +88,31 @@ bool inline DoesControlEdgeExist(const Node* src, const Node* dst) {
return false;
}
// Check if graph should run in layout-dependent mode or native format mode
// based on environment variable setting. User can set
// TF_ENABLE_MKL_NATIVE_FORMAT=1 to enable the native format mode.
bool inline NativeFormatEnabled() {
enum MklGraphMode {
MKL_DEFAULT = 0,
MKL_LAYOUT_DEPENDENT = 1,
MKL_NATIVE_FORMAT = 2
};
static MklGraphMode graph_mode = MKL_DEFAULT;
static absl::once_flag once;
absl::call_once(once, [&] {
bool native_fmt_enabled = false;
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MKL_NATIVE_FORMAT",
/*default_value*/ false,
&native_fmt_enabled));
if (native_fmt_enabled) {
graph_mode = MKL_NATIVE_FORMAT;
} else {
graph_mode = MKL_LAYOUT_DEPENDENT;
}
});
return graph_mode == MKL_NATIVE_FORMAT;
}
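A self-contained replica of the read-once behavior (using std::call_once instead of absl, and recognizing only "1" rather than everything ReadBoolFromEnvVar accepts); the point is that flipping the variable after the first call has no effect within the process:

#include <cstdlib>
#include <iostream>
#include <mutex>
#include <string>

bool NativeFormatEnabledDemo() {
  static bool native = false;
  static std::once_flag once;
  std::call_once(once, [] {
    const char* v = std::getenv("TF_ENABLE_MKL_NATIVE_FORMAT");
    native = (v != nullptr && std::string(v) == "1");
  });
  return native;
}

int main() {
  setenv("TF_ENABLE_MKL_NATIVE_FORMAT", "1", /*overwrite=*/1);  // POSIX
  std::cout << NativeFormatEnabledDemo() << "\n";  // 1: native mode on
  setenv("TF_ENABLE_MKL_NATIVE_FORMAT", "0", 1);
  std::cout << NativeFormatEnabledDemo() << "\n";  // still 1: mode is cached
}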
namespace mkl_op_registry {
// MKL operators whose kernels are registered with 'MklLayoutDependentOp' label
// (e.g., MklConv2D) understand input tensors in MKL layout. These operators
@@ -119,10 +145,30 @@ static const char* const kMklEagerOpPrefix = "_MklEager";
// _MklEager prefix.
static const char* const kMklNativeOpPrefix = "_MklNative";
// Get the name of Mkl Native (does not depend on layout propagation) op
// from original TensorFlow op.
inline string GetMklNativeOpName(const string& name) {
// There are a few operators that don't depend on layout propagation but are
// prefixed with _Mkl instead of _MklNative.
bool result =
(0 == name.compare("ConjugateTranspose") ||
0 == name.compare("BatchMatMul") || 0 == name.compare("BatchMatMulV2") ||
0 == name.compare("MatMul") || 0 == name.compare("Transpose"));
if (result) {
return string(kMklOpPrefix) + name;
} else {
return string(kMklNativeOpPrefix) + name;
}
}
// Get the name of Mkl op from original TensorFlow op
// We prefix 'Mkl' to the original op to get Mkl op.
// We prefix the original op with _Mkl or _MklNative to get Mkl op.
inline string GetMklOpName(const string& name) {
return string(kMklOpPrefix) + name;
if (!NativeFormatEnabled()) {
return string(kMklOpPrefix) + name;
} else {
return GetMklNativeOpName(name);
}
}
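Worked examples of the two helpers (a self-contained replica for illustration only; the real functions live in mkl_op_registry and consult NativeFormatEnabled() internally):

#include <iostream>
#include <string>

std::string GetMklNativeOpName(const std::string& name) {
  // Ops that stay layout-independent but keep the plain _Mkl prefix.
  const bool keeps_mkl_prefix =
      name == "ConjugateTranspose" || name == "BatchMatMul" ||
      name == "BatchMatMulV2" || name == "MatMul" || name == "Transpose";
  return (keeps_mkl_prefix ? "_Mkl" : "_MklNative") + name;
}

std::string GetMklOpName(const std::string& name, bool native_format) {
  return native_format ? GetMklNativeOpName(name) : "_Mkl" + name;
}

int main() {
  std::cout << GetMklOpName("Conv2D", false) << "\n";  // _MklConv2D
  std::cout << GetMklOpName("Conv2D", true) << "\n";   // _MklNativeConv2D
  std::cout << GetMklOpName("MatMul", true) << "\n";   // _MklMatMul (exception)
}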
// Get the name of Mkl Eager op from original TensorFlow op
@@ -131,12 +177,6 @@ inline string GetMklEagerOpName(const string& name) {
return string(kMklEagerOpPrefix) + name;
}
// Get the name of Mkl Native (does not depend on layout propagation) op
// from original TensorFlow op.
inline string GetMklNativeOpName(const string& name) {
return string(kMklNativeOpPrefix) + name;
}
#ifdef ENABLE_INTEL_MKL_BFLOAT16
static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
return port::TestCPUFeature(port::CPUFeature::AVX512F);

View File

@@ -34,6 +34,10 @@ limitations under the License.
#include "third_party/gpus/cudnn/cudnn.h"
#endif // GOOGLE_CUDA
#ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_graph_util.h"
#endif // INTEL_MKL
namespace tensorflow {
namespace grappler {
@@ -805,6 +809,12 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
#ifndef ENABLE_MKLDNN_V1
// We fuse FusedBatchNorm on GPU or MKL CPU.
if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
#else
if (NativeFormatEnabled()) {
// Temporarily disable FusedBatchNorm fusion on CPU until we support it
// under native format mode.
if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
}
#endif
DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");

View File

@@ -1704,6 +1704,7 @@ REGISTER_OP("_MklNativeConv2D")
.Attr("T: {bfloat16, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr("is_filter_const: bool = false")
.Attr(GetPaddingAttrStringWithExplicit())
.Attr(GetExplicitPaddingsAttrString())
.Attr(GetConvnetDataFormatAttrString())
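The new is_filter_const attribute mirrors the one on the layout-dependent _MklConv2D op. A hedged kernel-side sketch (hypothetical kernel, assuming the usual OpKernelConstruction::GetAttr API; oneDNN convolutions typically use such a flag to decide whether reordered weights can be cached across steps):

#include "tensorflow/core/framework/op_kernel.h"

class ExampleNativeConvOp : public tensorflow::OpKernel {
 public:
  explicit ExampleNativeConvOp(tensorflow::OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    // Defaults to false per the registration above.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("is_filter_const", &is_filter_const_));
  }
  void Compute(tensorflow::OpKernelContext* /*ctx*/) override {
    // A real kernel would reuse cached, pre-reordered weights here when
    // is_filter_const_ is true.
  }

 private:
  bool is_filter_const_ = false;
};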