Changes to graph rewrite to support native format

This commit is contained in:
parent 47afb26c2d
commit ff85b93f52

Changed paths: tensorflow/core/{common_runtime, graph, grappler/optimizers, ops}
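At a glance: with native format mode on, the rewrite passes stop preparing ops for Mkl layout propagation and instead only rename them to their _MklNative* variants, which work on plain TF tensors. A minimal standalone sketch of that dispatch, with hypothetical free-function signatures (the real helpers in the diff below take no arguments and consult TF_ENABLE_MKL_NATIVE_FORMAT internally):

#include <string>

// Why a node gets rewritten: layout-metadata propagation (classic Mkl path)
// versus a pure op-name change (native format path).
enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };

RewriteCause GetRewriteCause(bool native_format_enabled) {
  return native_format_enabled ? kRewriteForOpNameChange
                               : kRewriteForLayoutPropagation;
}

std::string GetMklOpName(const std::string& name, bool native_format_enabled) {
  // e.g. "Conv2D" -> "_MklConv2D" (layout mode) or "_MklNativeConv2D" (native).
  return (native_format_enabled ? std::string("_MklNative")
                                : std::string("_Mkl")) +
         name;
}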
@@ -142,11 +142,13 @@ Status MklEagerOpRewrite::SetupNewOp(

Status MklEagerOpRewrite::CreateGenericMklOp(
    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op) {
  const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
  const string mkl_op_name =
      mkl_op_registry::GetMklNativeOpName(orig_op->Name());
  TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_op));
  return Status::OK();
}

// TODO(mabuzain): Replace this call with above generic one.
Status MklEagerOpRewrite::CreateMklConv2DOp(
    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
  const string mkl_op_name =
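Note: CreateGenericMklOp now resolves its target kernel through GetMklNativeOpName, so the eager rewrite goes straight to the _MklNative* kernels rather than the layout-dependent _Mkl* ones; per the TODO, per-op helpers such as CreateMklConv2DOp are expected to be folded into this generic path.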
@@ -300,6 +300,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
    csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
    csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
    csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
    csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
    csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
    csinfo_.mkl_native_fused_depthwise_conv2d =
        "_MklNativeFusedDepthwiseConv2dNative";
    csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
    csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
    csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
    csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
    csinfo_.pad = "Pad";
@@ -367,257 +374,241 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    csinfo_.sub = "Sub";
    // End - element-wise ops. See note above.

    const bool native_fmt = NativeFormatEnabled();
    // NOTE: names are alphabetically sorted.
    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.add_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.add_v2),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
        {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.avg_pool3d, mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool3d,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.batch_matmul,
                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.batch_matmul_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.concatv2, mkl_op_registry::GetMklOpName(csinfo_.concatv2),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.concat,
                      mkl_op_registry::GetMklOpName(csinfo_.concat),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.concatv2),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.conjugate_transpose,
         mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
         CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d),
    rinfo_.push_back(
        {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
         CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_with_bias,
                      native_fmt ? csinfo_.mkl_native_conv2d_with_bias
                                 : csinfo_.mkl_conv2d_with_bias,
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
                      csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
                      AlwaysRewrite, kRewriteForLayoutPropagation});
                      AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                      CopyAttrsConv, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.conv3d,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d),
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
         CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv3d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv3d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
                      CopyAttrsConv, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.depthwise_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
                      CopyAttrsConv2DDepthwiseCheckConstFilter, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_input,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_filter,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
         CopyAttrsAll, DequantizeRewrite, kRewriteForLayoutPropagation});
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.dequantize,
                      mkl_op_registry::GetMklOpName(csinfo_.dequantize),
                      CopyAttrsAll, DequantizeRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_batch_norm,
                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_v2,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad_v2,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});

    // Using CopyAttrsAll for V3 on CPU, as there are no additional
    // attributes.
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_v3,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad_v3,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
    rinfo_.push_back({csinfo_.fused_batch_norm_ex,
                      csinfo_.mkl_fused_batch_norm_ex, CopyAttrsAll,
                      FusedBatchNormExRewrite, kRewriteForLayoutPropagation});
                      native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
                                 : csinfo_.mkl_fused_batch_norm_ex,
                      CopyAttrsAll, FusedBatchNormExRewrite,
                      GetRewriteCause()});
#endif
    rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
    rinfo_.push_back({csinfo_.fused_conv2d,
                      native_fmt ? csinfo_.mkl_native_fused_conv2d
                                 : csinfo_.mkl_fused_conv2d,
                      CopyAttrsFusedConv2D, FusedConv2DRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
                      csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D,
                      FusedDepthwiseConv2DRewrite,
                      kRewriteForLayoutPropagation});
                      native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
                                 : csinfo_.mkl_fused_depthwise_conv2d,
                      CopyAttrsFusedConv2D, FusedDepthwiseConv2DRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
                      CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});

    rinfo_.push_back({csinfo_.identity,
                      mkl_op_registry::GetMklOpName(csinfo_.identity),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                      CopyAttrsAll, LrnRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
         CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation});
        {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                      CopyAttrsAll, LrnRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.lrn_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
                      CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.matmul,
                      mkl_op_registry::GetMklOpName(csinfo_.matmul),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
         CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.leakyrelu,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
                      CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.leakyrelu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
                      CopyAttrsAll, LeakyReluRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool),
                      CopyAttrsAll, NonDepthBatchWisePoolRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
         CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.max_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
                      CopyAttrsAll, MaxpoolGradRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.max_pool3d,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
                      CopyAttrsAll, NonDepthBatchWisePoolRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
         CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.max_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.maximum,
                      mkl_op_registry::GetMklOpName(csinfo_.maximum),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.pad_with_conv2d, csinfo_.mkl_pad_with_conv2d,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.pad_with_conv2d,
                      native_fmt ? csinfo_.mkl_native_pad_with_conv2d
                                 : csinfo_.mkl_pad_with_conv2d,
                      CopyAttrsPadWithConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
                      csinfo_.mkl_pad_with_fused_conv2d,
                      native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
                                 : csinfo_.mkl_pad_with_fused_conv2d,
                      CopyAttrsPadWithFusedConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_per_channel,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_and_relu,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_and_relu_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_sum_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quantized_matmul_with_bias,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
@@ -643,72 +634,65 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_depthwise_conv2d_with_bias),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_
                 .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite,
         kRewriteForLayoutPropagation});
         CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantize_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
                      CopyAttrsAll, QuantizeOpRewrite,
                      kRewriteForLayoutPropagation});
                      CopyAttrsAll, QuantizeOpRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.relu_grad, mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.relu6, mkl_op_registry::GetMklOpName(csinfo_.relu6),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu6,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu6_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.requantize,
                      mkl_op_registry::GetMklOpName(csinfo_.requantize),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
    // Optimized TanhGrad support exists only in DNNL 1.x.
    rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.tanh_grad, mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.tanh_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#endif  // ENABLE_MKLDNN_V1
    rinfo_.push_back({csinfo_.reshape,
                      mkl_op_registry::GetMklOpName(csinfo_.reshape),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.slice,
                      mkl_op_registry::GetMklOpName(csinfo_.slice),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
        {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.softmax,
                      mkl_op_registry::GetMklOpName(csinfo_.softmax),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});

    rinfo_.push_back({csinfo_.squared_difference,
                      mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.transpose,
                      mkl_op_registry::GetMklOpName(csinfo_.transpose),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
@@ -723,9 +707,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
                      csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});

    minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                      csinfo_.conv2d_grad_filter_with_bias,
                      GetConv2DBackpropFilterOrBiasAddGrad});
    // Merge Pad and Conv2d, only if the pad op is "Pad"
    // Doesn't merge if pad op is "PadV2" or "MirrorPad"
    minfo_.push_back(
@@ -734,76 +715,82 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
                      csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});

    // The fusion patterns in "finfo_" that show up first will get applied
    // first, for example, graph "A->B->C->D" and finfo_ is {A->B->C to ABC,
    // A->B->C->D to ABCD}, since the first gets applied first, the final
    // graph will be ABC->D.
    if (!native_fmt) {
      minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                        csinfo_.conv2d_grad_filter_with_bias,
                        GetConv2DBackpropFilterOrBiasAddGrad});

      //
      // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
      // (NHWC) + Transpose (NHWC->
      // NCHW)" into "Conv2D (NCHW)". Such patterns occur frequently in Keras.
      // Note: we use the term "merge" to combine (exactly) 2 nodes into one,
      // while "fusion" is for 3+ nodes situation.
      //
      // The fusion patterns in "finfo_" that show up first will get applied
      // first, for example, graph "A->B->C->D" and finfo_ is {A->B->C to ABC,
      // A->B->C->D to ABCD}, since the first gets applied first, the final
      // graph will be ABC->D.

      // Transpose + Conv2d + Transpose:
      std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
                                            NCHW::dim::W, NCHW::dim::C};
      std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
                                            NHWC::dim::H, NHWC::dim::W};
      auto CheckForTransposeToNHWC =
          std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
      auto CheckForConv2dOp =
          std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
      auto CheckForTransposeToNCHW =
          std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
      auto FuseConv2D =
          std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                    std::placeholders::_2, std::placeholders::_3, "NCHW");
      finfo_.push_back(
          {"transpose-elimination for Conv2D",
           {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
           // CheckForMklOp
           FuseConv2D,
           CopyAttrsConv});
      //
      // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
      // (NHWC) + Transpose (NHWC->
      // NCHW)" into "Conv2D (NCHW)". Such patterns occur frequently in Keras.
      // Note: we use the term "merge" to combine (exactly) 2 nodes into one,
      // while "fusion" is for 3+ nodes situation.
      //

      // Transpose + Conv3d + Transpose:
      std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
                                             NCDHW::dim::H, NCDHW::dim::W,
                                             NCDHW::dim::C};
      std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
                                             NDHWC::dim::D, NDHWC::dim::H,
                                             NDHWC::dim::W};
      // Transpose + Conv2d + Transpose:
      std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
                                            NCHW::dim::W, NCHW::dim::C};
      std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
                                            NHWC::dim::H, NHWC::dim::W};
      auto CheckForTransposeToNHWC = std::bind(
          CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
      auto CheckForConv2dOp =
          std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
      auto CheckForTransposeToNCHW = std::bind(
          CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
      auto FuseConv2D =
          std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                    std::placeholders::_2, std::placeholders::_3, "NCHW");
      finfo_.push_back(
          {"transpose-elimination for Conv2D",
           {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
           // CheckForMklOp
           FuseConv2D,
           CopyAttrsConv});

      auto CheckForTransposeToNDHWC =
          std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
      auto CheckForConv3dOp =
          std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
      auto CheckForTransposeToNCDHW =
          std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
      auto FuseConv3D =
          std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                    std::placeholders::_2, std::placeholders::_3, "NCDHW");
      // Transpose + Conv3d + Transpose:
      std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
                                             NCDHW::dim::H, NCDHW::dim::W,
                                             NCDHW::dim::C};
      std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
                                             NDHWC::dim::D, NDHWC::dim::H,
                                             NDHWC::dim::W};

      finfo_.push_back(
          {"transpose-elimination for Conv3D",
           {CheckForTransposeToNDHWC, CheckForConv3dOp, CheckForTransposeToNCDHW},
           // CheckForMklOp
           FuseConv3D,
           CopyAttrsConv});
      auto CheckForTransposeToNDHWC = std::bind(
          CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
      auto CheckForConv3dOp =
          std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
      auto CheckForTransposeToNCDHW = std::bind(
          CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
      auto FuseConv3D =
          std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                    std::placeholders::_2, std::placeholders::_3, "NCDHW");

      auto CheckForMaxPool3DOp =
          std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.max_pool3d);
      auto FuseMaxPool3D =
          std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                    std::placeholders::_2, std::placeholders::_3, "NCDHW");
      finfo_.push_back({"transpose-elimination for MaxPool3D",
                        {CheckForTransposeToNDHWC, CheckForMaxPool3DOp,
                         CheckForTransposeToNCDHW},
                        // CheckForMklOp
                        FuseMaxPool3D,
                        CopyAttrsPooling});
      finfo_.push_back({"transpose-elimination for Conv3D",
                        {CheckForTransposeToNDHWC, CheckForConv3dOp,
                         CheckForTransposeToNCDHW},
                        // CheckForMklOp
                        FuseConv3D,
                        CopyAttrsConv});

      auto CheckForMaxPool3DOp =
          std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.max_pool3d);
      auto FuseMaxPool3D =
          std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                    std::placeholders::_2, std::placeholders::_3, "NCDHW");
      finfo_.push_back({"transpose-elimination for MaxPool3D",
                        {CheckForTransposeToNDHWC, CheckForMaxPool3DOp,
                         CheckForTransposeToNCDHW},
                        // CheckForMklOp
                        FuseMaxPool3D,
                        CopyAttrsPooling});
    }
  }

  // Standard interface to run pass
@@ -824,6 +811,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
  /// of ops like MatMul, Transpose, which do not support Mkl layout)
  enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };

  // Get the op rewrite cause depending on whether native format mode
  // is enabled or not.
  RewriteCause GetRewriteCause() {
    if (NativeFormatEnabled()) {
      return kRewriteForOpNameChange;
    } else {
      return kRewriteForLayoutPropagation;
    }
  }
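With this helper in place, every rinfo_ entry above that previously hard-coded kRewriteForLayoutPropagation now asks GetRewriteCause(), so a single environment-variable flip retargets the whole rewrite table; only ops that were always pure name changes (MatMul, BatchMatMul, Transpose, and friends) keep the kRewriteForOpNameChange literal.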

  /// Structure to specify the name of an original node, its new name after
  /// rewrite, the number of inputs to the original node, the function to
  /// be used to copy attributes for the op, and the rule (if any) which
@@ -960,6 +957,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    string mkl_fused_conv2d;
    string mkl_fused_depthwise_conv2d;
    string mkl_fused_matmul;
    string mkl_native_conv2d_with_bias;
    string mkl_native_fused_batch_norm_ex;
    string mkl_native_fused_conv2d;
    string mkl_native_fused_depthwise_conv2d;
    string mkl_native_pad_with_conv2d;
    string mkl_native_pad_with_fused_conv2d;
    string mkl_pad_with_conv2d;
    string mkl_pad_with_fused_conv2d;
    string mul;
@@ -2403,8 +2406,10 @@ Status MklLayoutRewritePass::CopyInputs(
    if (ArgIsList(arg)) {
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
      GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
                                    &new_node_inputs);
      if (N != 0) {
        GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
                                      &new_node_inputs);
      }
      nb->Input(new_node_inputs);
    } else {
      nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
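(The new N != 0 guard makes the empty-list case explicit: for a zero-length list input the producer lookup is skipped entirely and the builder receives an empty input vector.)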
@@ -3642,7 +3647,12 @@ Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
    return s;
  }

  ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
  if (!NativeFormatEnabled()) {
    ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
  } else {
    ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
  }

  nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);

  // Finalize graph and get new node.
@@ -430,6 +430,11 @@ Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
    VLOG(2) << "TF-MKL: Disabling MKL";
    return Status::OK();
  }
  if (NativeFormatEnabled()) {
    VLOG(2)
        << "Running in native format mode, MklToTfConversionPass won't run.";
    return Status::OK();
  }

  auto process_graph = [&](std::unique_ptr<Graph>* g) {
    // Get the ownership of graph
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
// Since our ops are going to produce and also consume N additional tensors
@@ -87,6 +88,31 @@ bool inline DoesControlEdgeExist(const Node* src, const Node* dst) {
  return false;
}

// Check if graph should run in layout-dependent mode or native format mode
// based on environment variable setting. User can set
// TF_ENABLE_MKL_NATIVE_FORMAT=1 to enable the native format mode.
bool inline NativeFormatEnabled() {
  enum MklGraphMode {
    MKL_DEFAULT = 0,
    MKL_LAYOUT_DEPENDENT = 1,
    MKL_NATIVE_FORMAT = 2
  };
  static MklGraphMode graph_mode = MKL_DEFAULT;
  static absl::once_flag once;
  absl::call_once(once, [&] {
    bool native_fmt_enabled = false;
    TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MKL_NATIVE_FORMAT",
                                   /*default_value*/ false,
                                   &native_fmt_enabled));
    if (native_fmt_enabled) {
      graph_mode = MKL_NATIVE_FORMAT;
    } else {
      graph_mode = MKL_LAYOUT_DEPENDENT;
    }
  });
  return graph_mode == MKL_NATIVE_FORMAT;
}
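Usage note: the mode is latched once per process. Setting TF_ENABLE_MKL_NATIVE_FORMAT=1 before the first call makes absl::call_once record MKL_NATIVE_FORMAT and every subsequent call return true; with the variable unset, the process stays in the classic layout-dependent mode, and the choice cannot change after it has been read.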

namespace mkl_op_registry {
// MKL operators whose kernels are registered with 'MklLayoutDependentOp' label
// (e.g., MklConv2D) understand input tensors in MKL layout. These operators
@@ -119,10 +145,30 @@ static const char* const kMklEagerOpPrefix = "_MklEager";
// _MklEager prefix.
static const char* const kMklNativeOpPrefix = "_MklNative";

// Get the name of Mkl Native (does not depend on layout propagation) op
// from original TensorFlow op.
inline string GetMklNativeOpName(const string& name) {
  // There are a few operators that don't depend on layout propagation but are
  // prefixed with _Mkl instead of _MklNative.
  bool result =
      (0 == name.compare("ConjugateTranspose") ||
       0 == name.compare("BatchMatMul") || 0 == name.compare("BatchMatMulV2") ||
       0 == name.compare("MatMul") || 0 == name.compare("Transpose"));
  if (result) {
    return string(kMklOpPrefix) + name;
  } else {
    return string(kMklNativeOpPrefix) + name;
  }
}

// Get the name of Mkl op from original TensorFlow op
// We prefix 'Mkl' to the original op to get Mkl op.
// We prefix the original op with _Mkl or _MklNative to get Mkl op.
inline string GetMklOpName(const string& name) {
  return string(kMklOpPrefix) + name;
  if (!NativeFormatEnabled()) {
    return string(kMklOpPrefix) + name;
  } else {
    return GetMklNativeOpName(name);
  }
}
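For illustration, the names these helpers resolve to (worked examples, not part of the diff):

// Layout-dependent mode (TF_ENABLE_MKL_NATIVE_FORMAT unset):
//   GetMklOpName("Conv2D")    -> "_MklConv2D"
//   GetMklOpName("MatMul")    -> "_MklMatMul"
// Native format mode (TF_ENABLE_MKL_NATIVE_FORMAT=1):
//   GetMklOpName("Conv2D")    -> "_MklNativeConv2D"  (via GetMklNativeOpName)
//   GetMklOpName("MatMul")    -> "_MklMatMul"        (exception list above)
//   GetMklOpName("Transpose") -> "_MklTranspose"     (exception list above)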

// Get the name of Mkl Eager op from original TensorFlow op
@@ -131,12 +177,6 @@ inline string GetMklEagerOpName(const string& name) {
  return string(kMklEagerOpPrefix) + name;
}

// Get the name of Mkl Native (does not depend on layout propagation) op
// from original TensorFlow op.
inline string GetMklNativeOpName(const string& name) {
  return string(kMklNativeOpPrefix) + name;
}

#ifdef ENABLE_INTEL_MKL_BFLOAT16
static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
  return port::TestCPUFeature(port::CPUFeature::AVX512F);
@@ -34,6 +34,10 @@ limitations under the License.
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA

#ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_graph_util.h"
#endif  // INTEL_MKL

namespace tensorflow {
namespace grappler {

@@ -805,6 +809,12 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
#ifndef ENABLE_MKLDNN_V1
  // We fuse FusedBatchNorm on GPU or MKL CPU.
  if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
#else
  if (NativeFormatEnabled()) {
    // Temporarily disable FusedBatchNorm fusion on CPU until
    // we support it under native format mode.
    if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
  }
#endif

  DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");

@@ -1704,6 +1704,7 @@ REGISTER_OP("_MklNativeConv2D")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())