From 243ac0132001a5b8fa9c1d4435842d0b588b82be Mon Sep 17 00:00:00 2001 From: Vivek Rane Date: Thu, 4 May 2017 16:32:16 -0700 Subject: [PATCH] Added identity op and a fix for LRN (#9641) * relu grad and maxpooling grad fixes for perf * Graph layout pass and conversion pass changes This commit makes following changes: - Enables support for ReluGrad and BiasAddGrad - Adds support for detecting depthwise/batchwise pooling - Adds more unit tests for Graph rewrite pass - Improvements to handling control-flow edges - Bug fixes * Defaulting to Eigen when LRN depth_radius!=2 * Fixed mkl_conv_grad_filter.cc for conv_ops_tests.py * Style fix to mkl_matmul and remove unnecessary 'MKL' label on matmul kernel * Style fixes based on clang-format to mkl_conv_* and mkl_matmul * Bug fixes * Adding OP_REQUIRES_OK check in Concat * Making some style changes * Enabled the configuration of MKL settings * relu grad and maxpooling grad fixes for perf * Graph layout pass and conversion pass changes This commit makes following changes: - Enables support for ReluGrad and BiasAddGrad - Adds support for detecting depthwise/batchwise pooling - Adds more unit tests for Graph rewrite pass - Improvements to handling control-flow edges - Bug fixes * Defaulting to Eigen when LRN depth_radius!=2 * Fixed mkl_conv_grad_filter.cc for conv_ops_tests.py * Style fix to mkl_matmul and remove unnecessary 'MKL' label on matmul kernel * Style fixes based on clang-format to mkl_conv_* and mkl_matmul * Bug fixes * Adding OP_REQUIRES_OK check in Concat * Making some style changes * Enabled the configuration of MKL settings * Fixing graph unit tests with Mkl op name change to _Mkl; Fixed missing _ in MklToTf op * Fixed missing libdl.so.2 in BUILD file * Fixes for unit test build failures. * Changes in mkl_conv_grad_filter_ops.cc for Google code style * Fixes to remove dead code * removed the dead code and added a TODO for mkl implementation to handle this case in the future * Enabling MklIdentityOp * Calling MKL for all values of depth radius in LRN * Fixed buildifier sanity check error * Adding support for google's CI automation * Updated link to new MKL version * Enabling MklIdentityOp * Calling MKL for all values of depth radius in LRN * Fix for missing locate binary * Fix for missing locate command in CI * Adding updatedb to populate the database after installing mlocate * Fixed buildifier issue * setting tf_need_mkl=0 in libtf files * Added third_party/mkl/* to .gitignore * Added third_party/eigen3/mkl_include to .gitignore * In configured, set MKL-enabling options only for Linux. * Enabling MklIdentityOp * Calling MKL for all values of depth radius in LRN * Making style fix in LRN * Fixed Indentation --- tensorflow/core/BUILD | 2 + tensorflow/core/graph/mkl_layout_pass.cc | 16 ++++++ tensorflow/core/kernels/BUILD | 8 +++ tensorflow/core/kernels/mkl_identity_op.cc | 63 ++++++++++++++++++++++ tensorflow/core/kernels/mkl_lrn_op.cc | 30 +++++------ tensorflow/core/kernels/mkl_reshape_op.cc | 2 +- tensorflow/core/ops/array_ops.cc | 17 ++++++ tensorflow/core/util/mkl_util.h | 44 +++++++++++++-- 8 files changed, 160 insertions(+), 22 deletions(-) create mode 100644 tensorflow/core/kernels/mkl_identity_op.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 2d27b5d9d42..4cfdf844ce4 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -736,6 +736,7 @@ cc_library( "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", "//tensorflow/core/kernels:mkl_fused_batch_norm_op", + "//tensorflow/core/kernels:mkl_identity_op", "//tensorflow/core/kernels:mkl_lrn_op", "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", @@ -2123,6 +2124,7 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", "//tensorflow/core/kernels:mkl_fused_batch_norm_op", + "//tensorflow/core/kernels:mkl_identity_op", "//tensorflow/core/kernels:mkl_lrn_op", "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 87388ebe7b6..94741a11ffa 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; csinfo_.fused_batch_norm = "FusedBatchNorm"; csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; + csinfo_.identity = "Identity"; csinfo_.lrn = "LRN"; csinfo_.lrn_grad = "LRNGrad"; csinfo_.matmul = "MatMul"; @@ -326,6 +327,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.fused_batch_norm_grad, GetMklOpName(csinfo_.fused_batch_norm_grad), CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.identity, + GetMklOpName(csinfo_.identity), + CopyAttrsIdentity, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), CopyAttrsLRN, AlwaysRewrite, nullptr}); @@ -446,6 +450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string conv2d_grad_filter; string fused_batch_norm; string fused_batch_norm_grad; + string identity; string lrn; string lrn_grad; string matmul; @@ -767,6 +772,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb); @@ -1275,6 +1281,16 @@ void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, nb->Attr("data_format", data_format); } +void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + // Add attributes to new node. + nb->Attr("T", T); +} + void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb) { DataType T; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0847d1279b8..abce506aba2 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4918,6 +4918,14 @@ tf_mkl_kernel_library( ], ) +tf_mkl_kernel_library( + name = "mkl_identity_op", + prefix = "mkl_identity_op", + deps = ARRAY_DEPS + [ + "//third_party/mkl:intel_binary_blob", + ], +) + tf_mkl_kernel_library( name = "mkl_lrn_op", prefix = "mkl_lrn_op", diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc new file mode 100644 index 00000000000..e138cc2e959 --- /dev/null +++ b/tensorflow/core/kernels/mkl_identity_op.cc @@ -0,0 +1,63 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/array_ops.cc. +#ifdef INTEL_MKL + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +#include "tensorflow/core/util/mkl_util.h" +#include "third_party/mkl/include/mkl_dnn.h" +#include "third_party/mkl/include/mkl_dnn_types.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; +template +class MklIdentityOp : public OpKernel { + public: + explicit MklIdentityOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + MklShape mkl_shape_input; + GetMklShape(context, 0, &mkl_shape_input); + bool input_in_mkl_format = mkl_shape_input.IsMklTensor(); + + if (input_in_mkl_format) { + ForwarMklTensorInToOut(context, 0, 0); + } else { + FowardTfTensorInToOut(context, 0, 0); + } + } + + bool IsExpensive() override { return false; } +}; + +#define REGISTER_MKL_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklIdentity") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklIdentityOp); \ + +TF_CALL_float(REGISTER_MKL_CPU); +#undef REGISTER_MKL_CPU +} // namespace tensorflow +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc index 9d050d430ae..ac432e13ced 100644 --- a/tensorflow/core/kernels/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl_lrn_op.cc @@ -171,9 +171,7 @@ class MklLRNOp : public OpKernel { MklShape input_shape; dnnPrimitive_t lrn_fwd = nullptr; dnnPrimitive_t convert_input = nullptr; - /* dnnPrimitive_t convert_output; */ dnnLayout_t lt_input = nullptr; - /* dnnLayout_t lt_output; */ dnnLayout_t lt_internal_input = nullptr; dnnLayout_t lt_internal_workspace = nullptr; dnnLayout_t lt_internal_output = nullptr; @@ -390,11 +388,6 @@ class MklLRNGradOp : public OpKernel { return; } - if (depth_radius_ != 2) { - mkl_context.MklDefaultToEigen(context); - return; - } - if (ingrad_in_mkl_format || inimage_in_mkl_format) { const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format) ? &mkl_context.ingrad_shape @@ -476,11 +469,11 @@ class MklLRNGradOp : public OpKernel { const_cast(static_cast(output->flat().data())); Tensor mkl_tmp_input_buf_tensor, mkl_tmp_image_buf_tensor, - mkl_tmp_outimage_buf_tensor, mkl_tmp_workspace_buf_tensor; + mkl_tmp_outimage_buf_tensor; // Convert Inputs if needed - mkl_context.MklPrepareLRNGradInput( - context, &mkl_tmp_input_buf_tensor, &mkl_tmp_image_buf_tensor, - &mkl_tmp_outimage_buf_tensor, &mkl_tmp_workspace_buf_tensor); + mkl_context.MklPrepareLRNGradInput(context, &mkl_tmp_input_buf_tensor, + &mkl_tmp_image_buf_tensor, + &mkl_tmp_outimage_buf_tensor); // We do not do any conversion for output. But we simply emit it // in MKL format. @@ -537,11 +530,13 @@ class MklLRNGradOp : public OpKernel { void MklPrepareLRNGradInput(OpKernelContext* context, Tensor* mkl_tmp_input_buf_tensor, Tensor* mkl_tmp_image_buf_tensor, - Tensor* mkl_tmp_outimage_buf_tensor, - Tensor* mkl_tmp_workspace_buf_tensor) { + Tensor* mkl_tmp_outimage_buf_tensor) { const Tensor& in_grads = MklGetInput(context, 0); const Tensor& in_image = MklGetInput(context, 1); const Tensor& out_image = MklGetInput(context, 2); + const Tensor& workspace = MklGetInput( + context, + 3); /*Worskpsace is enabled, get the buffer to the workspace */ void* user_input = const_cast( static_cast(in_grads.flat().data())); @@ -549,6 +544,9 @@ class MklLRNGradOp : public OpKernel { static_cast(in_image.flat().data())); void* user_fwd_output = const_cast( static_cast(out_image.flat().data())); + void* workspace_buffer = const_cast( + static_cast(workspace.flat().data())); + CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, lrn_bwd, dnnResourceWorkspace), E_SUCCESS); @@ -623,9 +621,7 @@ class MklLRNGradOp : public OpKernel { res_lrn_bwd[dnnResourceDst] = user_fwd_output; } - // Allocate buffer for workspace. - AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor, lt_workspace, - &res_lrn_bwd[dnnResourceWorkspace]); + res_lrn_bwd[dnnResourceWorkspace] = workspace_buffer; } // Fallback implementation - Taken from lrn_op.cc @@ -713,7 +709,7 @@ class MklLRNGradOp : public OpKernel { Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch, depth * depth, shard); } - + // release mkl resources void Mklcleanup() { bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor(); diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index 753a8b52b42..593aa3a2fd6 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -129,7 +129,7 @@ class MklReshapeOp : public OpKernel { return; } } else { - CopyTFTensorInToOut(context, 0, 0, shape); + CopyTfTensorInToOutWithShape(context, 0, 0, shape); } } }; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index e528ae47aa7..7b44ff1918d 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1506,6 +1506,23 @@ REGISTER_OP("Identity") Return a tensor with the same shape and contents as the input tensor or value. )Doc"); +#ifdef INTEL_MKL +REGISTER_OP("_MklIdentity") + .Input("input: T") + .Input("mkl_input: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + c->set_output_handle_dtype(0, c->input_handle_dtype(0)); + c->set_output_handle_shape(0, c->input_handle_shape(0)); + return Status::OK(); + }) + .Doc(R"Doc( Mkl implementation of IdentityOp +)Doc"); +#endif + // -------------------------------------------------------------------------- REGISTER_OP("RefIdentity") .Input("input: Ref(T)") diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 897b174eff2..6a37256ea9f 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -542,8 +542,8 @@ inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) { return mkl_shape.dim_size(index); } -inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in, - int idx_out) { +inline void CopyMklTensorInToOut(OpKernelContext* context, + int idx_in, int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -563,8 +563,9 @@ inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in, context->set_output(idx_meta_out, meta_output); } -inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in, - int idx_out, const TensorShape& shape) { +inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, + int idx_in, int idx_out, + const TensorShape& shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -580,6 +581,41 @@ inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in, context->set_output(idx_data_out, output); } +inline void FowardTfTensorInToOut(OpKernelContext* context, + int idx_in, int idx_out) { + int num_inputs = context->num_inputs(); + int num_outputs = context->num_outputs(); + int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); + int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); + + MklShape mkl_shape_output; + mkl_shape_output.SetMklTensor(false); + AllocateOutputSetMklShape(context, idx_out, mkl_shape_output); + if (IsRefType(context->input_dtype(idx_data_in))) { + context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); + } else { + context->set_output(idx_data_out, context->input(idx_data_in)); + } +} + +inline void ForwarMklTensorInToOut(OpKernelContext* context, + int idx_in, int idx_out) { + int num_inputs = context->num_inputs(); + int num_outputs = context->num_outputs(); + int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); + int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs); + int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); + int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs); + + if (IsRefType(context->input_dtype(idx_data_in))) { + context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); + context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out); + } else { + context->set_output(idx_data_out, context->input(idx_data_in)); + context->set_output(idx_meta_out, context->input(idx_meta_in)); + } +} + namespace mkl_op_registry { static const char* kMklOpLabel = "MklOp"; static const char* kMklOpLabelPattern = "label='MklOp'";