diff --git a/.gitignore b/.gitignore index d8ecef1e1e7..be0a287a8e3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ node_modules /.bazelrc /.tf_configure.bazelrc /bazel-* +/third_party/eigen3/mkl_include +/third_party/mkl/* /third_party/py/numpy/numpy_include /tools/python_bin_path.sh /tools/git/gen diff --git a/configure b/configure index 15715e5ef21..4104651cbbb 100755 --- a/configure +++ b/configure @@ -180,25 +180,35 @@ fi setup_python ## Set up MKL related environment settings -if false; then # Disable building with MKL for now - while [ "$TF_NEED_MKL" == "" ]; do - fromuser="" - read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT - fromuser="1" - case $INPUT in - [Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;; - [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - * ) echo "Invalid selection: " $INPUT;; - esac - done +while [ "$TF_NEED_MKL" == "" ]; do + fromuser="" + read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT + fromuser="1" + case $INPUT in + [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;; + [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; + "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; + * ) echo "Invalid selection: " $INPUT;; + esac +done - OSNAME=`uname -s` +OSNAME=`uname -s` - if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL +if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL + fromuser="" + read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT + fromuser="1" + case $INPUT in + [Yy]* ) TF_DOWNLOAD_MKL=1;; + [Nn]* ) TF_DOWNLOAD_MKL=0;; + "" ) TF_DOWNLOAD_MKL=1;; + * ) echo "Invalid selection: " $INPUT; exit 1;; + esac + + if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then DST=`dirname $0` - ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170209.tgz - GITHUB_RELEASE_TAG=v0.5 + ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz + GITHUB_RELEASE_TAG=v0.7 MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME" if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL @@ -208,7 +218,20 @@ if false; then # Disable building with MKL for now MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` - if [ "$OSNAME" == "Linux" ]; then + else + default_mkl_path=/opt/intel/mklml + fromuser="" + read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH + fromuser="1" + if [ -z "$MKL_INSTALL_PATH" ]; then + MKL_INSTALL_PATH=$default_mkl_path + fi + # Result returned from "read" will be used unexpanded. That make "~" unuseable. + # Going through one more level of expansion to handle that. + MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` + fi + + if [ "$OSNAME" == "Linux" ]; then # Full MKL configuration MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version? MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION? @@ -216,24 +239,33 @@ if false; then # Disable building with MKL for now # MKL-ML configuration MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version? MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION? - elif [ "$OSNAME" == "Darwin" ]; then + elif [ "$OSNAME" == "Darwin" ]; then echo "Darwin is unsupported yet"; exit 1 - fi + fi - if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then + if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/ ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/ ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include - else - echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist"; + loc=$(locate -e libdl.so.2 | sed -n 1p) + ln -sf $loc third_party/mkl/libdl.so.2 + elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then + ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include + loc=$(locate -e libdl.so.2 | sed -n 1p) + ln -sf $loc third_party/mkl/libdl.so.2 + else + echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists"; exit 1 - fi + fi - if [ -z "$fromuser" ]; then + if [ -z "$fromuser" ]; then exit 1 - fi + fi cat > third_party/mkl/mkl.config < third_party/mkl/mkl.config < #include +#include #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h" #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 119bc0f8997..2d27b5d9d42 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1282,7 +1282,10 @@ cc_library( ] + tf_additional_verbs_lib_defines(), linkopts = select({ "//tensorflow:freebsd": [], - "//conditions:default": ["-ldl"], + "//conditions:default": [ + "-ldl", + "-lpthread", + ], }), deps = tf_additional_lib_deps() + [ ":lib_hash_crc32c_accelerate_internal", diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 09b632a1650..87388ebe7b6 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/util/mkl_util.h" @@ -280,51 +281,72 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_conv2d = "_MklConv2D"; csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_with_bias_backprop_bias = - "_MklConv2DWithBiasBackpropBias"; - csinfo_.relu = "Relu"; - csinfo_.reshape = "Reshape"; - csinfo_.relu_grad = "ReluGrad"; - csinfo_.split = "Split"; + "_MklConv2DWithBiasBackpropBias"; + csinfo_.relu = "Relu"; + csinfo_.relu_grad = "ReluGrad"; + csinfo_.reshape = "Reshape"; + csinfo_.split = "Split"; // NOTE: names are alphabetically sorted. - rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1, - CopyAttrsPooling, AlwaysRewrite}); + rinfo_.push_back({csinfo_.avg_pool, + GetMklOpName(csinfo_.avg_pool), + CopyAttrsPooling, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.avg_pool_grad, - GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling, - AlwaysRewrite}); - rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0, - CopyAttrsConcat, AlwaysRewrite}); - rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0, - CopyAttrsConcatV2, AlwaysRewrite}); - rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2, - CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.avg_pool_grad), + CopyAttrsPooling, AlwaysRewrite, nullptr}); + // BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending + // on if context contains Conv2D. + rinfo_.push_back({csinfo_.bias_add_grad, + csinfo_.mkl_conv2d_with_bias_backprop_bias, + CopyAttrsBiasAddGrad, ContextMatchRewrite, + &biasaddgrad_conv2dwithbias_context_}); + // BiasAddGrad gets written into BiasAddGrad depending on if context + // contains MatMul. + rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul, + CopyAttrsBiasAddGrad, ContextMatchRewrite, + &biasaddgrad_matmul_context_}); + rinfo_.push_back({csinfo_.concat, + GetMklOpName(csinfo_.concat), + CopyAttrsConcat, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.concatv2, + GetMklOpName(csinfo_.concatv2), + CopyAttrsConcatV2, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.conv2d, + GetMklOpName(csinfo_.conv2d), + CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d_grad_filter, - GetMklOpName(csinfo_.conv2d_grad_filter), 3, - CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.conv2d_grad_filter), + CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d_grad_input, - GetMklOpName(csinfo_.conv2d_grad_input), 3, - CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.conv2d_grad_input), + CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.fused_batch_norm, - GetMklOpName(csinfo_.fused_batch_norm), 5, - CopyAttrsFusedBatchNorm, AlwaysRewrite}); + GetMklOpName(csinfo_.fused_batch_norm), + CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.fused_batch_norm_grad, - GetMklOpName(csinfo_.fused_batch_norm_grad), 5, - CopyAttrsFusedBatchNorm, AlwaysRewrite}); - rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN, - AlwaysRewrite}); - rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3, - CopyAttrsLRN, AlwaysRewrite}); - rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1, - CopyAttrsPooling, AlwaysRewrite}); + GetMklOpName(csinfo_.fused_batch_norm_grad), + CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.lrn, + GetMklOpName(csinfo_.lrn), + CopyAttrsLRN, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.lrn_grad, + GetMklOpName(csinfo_.lrn_grad), + CopyAttrsLRN, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.max_pool, + GetMklOpName(csinfo_.max_pool), + CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr}); rinfo_.push_back({csinfo_.max_pool_grad, - GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling, - AlwaysRewrite}); - rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1, - CopyAttrsRelu, AlwaysRewrite}); - rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2, - CopyAttrsReshape, AlwaysRewrite}); - - // TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet. + GetMklOpName(csinfo_.max_pool_grad), + CopyAttrsPooling, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.relu, + GetMklOpName(csinfo_.relu), + CopyAttrsRelu, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.relu_grad, + GetMklOpName(csinfo_.relu_grad), + CopyAttrsRelu, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.reshape, + GetMklOpName(csinfo_.reshape), + CopyAttrsReshape, AlwaysRewrite, nullptr}); // Add info about which ops to add workspace edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); @@ -338,8 +360,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // maxhops in backward data-flow graph. Since input of forward nodes // (Conv2D) directly goes to backward nodes, we do not expect the // hop-distance would be more than few nodes. - cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, - kNodeMergeContextMaxDepth}); + biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, + kNodeMergeContextMaxDepth}; + + biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad, + csinfo_.mkl_conv2d_with_bias, + kNodeMergeContextMaxDepth}; + + cinfo_.push_back(&biasaddgrad_matmul_context_); + cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); } // Standard interface to run pass @@ -354,7 +383,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @return true, if and only if graph is mutated; false otherwise. bool RunPass(std::unique_ptr* g); - private: + /// Structure to specify the context information used in a node rewrite rule + typedef struct { + string node; // Name of the node to be rewritten + string fwd; // Name of the node in the forward pass that this node + // corresponds to + size_t max_hop; // Maximum number of hops the fwd is located + // from this node. If the fwd is farther than max_hop + // then we do not rewrite the node. + } ContextInfo; + /// Structure to specify the name of an original node, its new name after /// rewrite, the number of inputs to the original node, the function to /// be used to copy attributes for the op, and the rule (if any) which @@ -362,11 +400,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass { typedef struct { string name; // Original name of op of the node in the graph string new_name; // New name of the op of the node in the graph - int num_ins; // The number of inputs to the original op type // A function handler to copy attributes from an old node to a new node. std::function copy_attrs; - std::function rewrite_rule; // A rule under which to - // rewrite this node. + // A rule under which to rewrite this node + std::function rewrite_rule; + // ContextInfo, if any, to be used for rewrite + ContextInfo* context; } RewriteInfo; /// Structure to specify a forward op, a backward op, and the slot numbers @@ -393,16 +432,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string new_node; // Name of the node after merge } MergeInfo; - /// Structure to specify the context information used in a node rewrite rule - typedef struct { - string node; // Name of the node to be rewritten - string fwd; // Name of the node in the forward pass that this node - // corresponds to - size_t max_hop; // Maximum number of hops the fwd is located - // from this node. If the fwd is farther than max_hop - // then we do not rewrite the node. - } ContextInfo; - /// Structure to store all constant strings /// NOTE: names are alphabetically sorted. struct { @@ -427,10 +456,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string mkl_conv2d_with_bias_backprop_bias; string relu; string relu_grad; - string split; string reshape; + string split; } csinfo_; + private: /// Maintain info about nodes to rewrite std::vector rinfo_; @@ -441,7 +471,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { std::vector minfo_; /// Maintain info about nodes to rewrite - static std::vector cinfo_; + static std::vector cinfo_; + + /// Context variables used in referencing rules + static ContextInfo biasaddgrad_matmul_context_; + static ContextInfo biasaddgrad_conv2dwithbias_context_; /// Hash table to maintain nodes visited in the graph. std::unordered_set visited_nodes_; @@ -464,19 +498,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Clear all visited nodes inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); } - // Is this a graph node that can accept variable number of inputs? - // Return true if yes, false otherwise. - // - // Concat, Split are vararg nodes. - inline bool IsVarArgNode(Node* n) { - if (n->type_string() == csinfo_.concat || - n->type_string() == csinfo_.concatv2 || - n->type_string() == csinfo_.split) { - return true; - } - return false; - } - // Is OpDef::ArgDef a list type? It could be N * T or list(type). // Refer to opdef.proto for details of list type. inline bool ArgIsList(const OpDef::ArgDef& arg) const { @@ -510,6 +531,39 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return string(kMklOpPrefix) + name; } + // Can op represented by node 'n' run on DEVICE_CPU? + // Op can run on CPU with MKL if the runtime assigned device or the + // user requested device contains device CPU, or both are empty. + bool CanOpRunOnCPUDevice(const Node* n) { + bool result = true; + string reason; + + // Substring that should be checked for in device name for CPU device. + const char* const kCPUDeviceSubStr = "cpu"; + + // If Op has been specifically assigned to a non-CPU device, then No. + if (!n->assigned_device_name().empty() && + !StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) { + result = false; + reason = "Op has been assigned a runtime device that is not CPU."; + } + + // If user has specifically assigned this op to a non-CPU device, then No. + if (!n->def().device().empty() && + !StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) { + result = false; + reason = "User has assigned a device that is not CPU."; + } + + if (result == false) { + VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " + << n->type_string() << ", reason: " << reason; + } + + // Otherwise Yes. + return result; + } + // Return a node that can be merged with input node 'n' // // @return pointer to the node if we can find such a @@ -538,13 +592,46 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Default rewrite rule to be used in scenario 1 for rewrite. // @return - true (since we want to always rewrite) - static bool AlwaysRewrite(const Node* n) { return true; } - // Rewrite rule that uses context-information for matching + static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) { + return true; + } + + // Check if we are performing pooling on depth or batch. If it is, then we + // do not rewrite MaxPool node to Mkl version. + // @return - true (if it is not a depth/batch wise pooling case); + // false otherwise. + static bool NonDepthBatchWisePoolRewrite(const Node* n, + const ContextInfo* c) { + CHECK_NOTNULL(n); + + string data_format_str; + TensorFormat data_format; + std::vector ksize, strides; + CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); + CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); + CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), + true); + CHECK_EQ(FormatFromString(data_format_str, &data_format), true); + + // Condition that specifies non-batch-wise and non-depth-wise pooling. + if (GetTensorDim(ksize, data_format, 'N') == 1 && + GetTensorDim(strides, data_format, 'N') == 1 && + GetTensorDim(ksize, data_format, 'C') == 1 && + GetTensorDim(strides, data_format, 'C') == 1) { + return true; + } + + return false; + } + + // Rewrite rule that uses context-information for matching, // used in scenario 2. // // @input - Node 'n' for which to search for matching context - // @return - true if matching context is found; false otherwise. - static bool ContextMatchRewrite(const Node* n); + // @input - The context 'c' under which to rewrite + // @return - true if we can rewrite node under context 'c'; + // false otherwise. + static bool ContextMatchRewrite(const Node* n, const ContextInfo* c); // Helper function that searches the matching contextinfo for the node. // Implements depth-first search in the data dependence graph for the @@ -598,6 +685,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // node that we are constructing. // // @input g - input graph, + // @input orig_node - Original node that we are rewriting // @input inputs - inputs to old node that we are using for constructing // new inputs, // @input input_idx - the index in the 'inputs' vector pointing to the @@ -608,11 +696,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output output_nodes - the list of new nodes creating Mkl tensors // // @return None - void GetNodesProducingMklTensorList( - std::unique_ptr* g, - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes); + void GetNodesProducingMklTensorList(std::unique_ptr* g, + Node* orig_node, const gtl::InlinedVector, 4>& inputs, + int* input_idx, int list_length, + std::vector* output_nodes); // Get a node that will feed an Mkl tensor to the new // node that we are constructing. The output node could be (1) 'n' @@ -620,6 +707,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // if 'n' is not an Mkl layer. // // @input g - input graph, + // @input orig_node - Original node that we are rewriting, // @input n - Node based on which we are creating Mkl node, // @input n_output_slot - the output slot of node 'n' // which is feeding to the node that we are constructing @@ -627,9 +715,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output mkl_node_output_slot - the slot number of mkl_node that // will feed the tensor // @return None - void GetNodeProducingMklTensor(std::unique_ptr* g, Node* n, - int n_output_slot, Node** mkl_node, - int* mkl_node_output_slot); + void GetNodeProducingMklTensor(std::unique_ptr* g, Node* orig_node, + Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are @@ -695,13 +782,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass { Node* orig_node); }; -std::vector MklLayoutRewritePass::cinfo_; +MklLayoutRewritePass::ContextInfo + MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; +MklLayoutRewritePass::ContextInfo + MklLayoutRewritePass::biasaddgrad_matmul_context_; +std::vector MklLayoutRewritePass::cinfo_; -// We register Mkl rewrite pass for phase 1 in post rewrite group. +// We register Mkl rewrite pass for phase 1 in post partitioning group. // We register it here so that we get a complete picture of all users of Mkl // nodes. Do not change the ordering of the Mkl passes. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1, - MklLayoutRewritePass); +const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = + OptimizationPassRegistry::POST_PARTITIONING; +REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); ////////////////////////////////////////////////////////////////////////// // Helper functions for creating new node @@ -737,27 +829,14 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList( while (list_length != 0) { CHECK_GT(list_length, 0); - CHECK_LE(*input_idx, inputs.size()); + CHECK_LT(*input_idx, inputs.size()); Node* n = inputs[*input_idx].first; int slot = inputs[*input_idx].second; - const OpDef::ArgDef& arg = n->op_def().output_arg(slot); - // If input node 'n' is producing a list/array output at output - // slot 'slot' then we need to find out the length of that list/array. - if (ArgIsList(arg)) { - int N = GetTensorListLength(arg, n); - CHECK_LE(N, list_length); - for (int j = 0; j < N; j++) { - output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); - } - (*input_idx)++; - list_length -= N; - } else { - // But if input node 'n' is just producing a single tensor at - // output slot 'slot' then we just add that single node. - output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); - (*input_idx)++; - list_length--; - } + // If input node 'n' is just producing a single tensor at + // output slot 'slot' then we just add that single node. + output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); + (*input_idx)++; + list_length--; } } @@ -775,20 +854,39 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, TensorShape dummy_shape({8}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // the same device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); + + // If number of inputs to the original node is > 0, then we add + // control dependency between 1st input (index 0) of the original node and + // the dummy Mkl node. This is needed because control-flow ops such as Enter, + // Merge, etc, require frame_name of the dummy Mkl node to be same as the + // rewritten node. Adding control edge between 1st input of the original node + // and the dummy Mkl node ensures that the dummy node is in the same frame + // as the original node. Choosing 1st input is not necessary - any input of + // the original node is fine because all the inputs of a node are always in + // the same frame. + if (orig_node->num_inputs() > 0) { + Node* orig_input0 = nullptr; + TF_CHECK_OK(orig_node->input_node(0, + const_cast(&orig_input0))); + CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); + } + (*out)->set_assigned_device_name(orig_node->assigned_device_name()); } void MklLayoutRewritePass::GetNodesProducingMklTensorList( std::unique_ptr* g, - const gtl::InlinedVector, 4>& inputs, int* input_idx, - int list_length, std::vector* output_nodes) { + Node* orig_node, + const gtl::InlinedVector, 4>& inputs, + int* input_idx, int list_length, + std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); CHECK_NOTNULL(output_nodes); @@ -796,38 +894,19 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( while (list_length != 0) { CHECK_GT(list_length, 0); - CHECK_LE(*input_idx, inputs.size()); + CHECK_LT(*input_idx, inputs.size()); Node* n = inputs[*input_idx].first; int slot = inputs[*input_idx].second; - const OpDef::ArgDef& arg = n->op_def().output_arg(slot); - // We need to check first if the input edge is going to carry a - // single tensor or a list of tensors. If it is a list of tensors, - // then we need to create list of Mkl dummy nodes. - if (ArgIsList(arg)) { - // If input node 'n' is producing a list/array output at output - // slot 'slot' then we need to find out the length of that list/array. - int N = GetTensorListLength(arg, n); - CHECK_LE(N, list_length); - Node* mkl_node = nullptr; - int mkl_node_output_slot = 0; - // If it is a list, then create a list of Mkl dummy nodes. - for (int j = 0; j < N; j++) { - GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back( - NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); - } - (*input_idx)++; - list_length -= N; - } else { - // If it is not a list, then create a single Mkl tensor node. - Node* mkl_node = nullptr; - int mkl_node_output_slot = 0; - GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back( - NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); - (*input_idx)++; - list_length--; - } + // If 'n' is producing a single tensor, then create a single Mkl tensor + // node. + Node* mkl_node = nullptr; + int mkl_node_output_slot = 0; + GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, + &mkl_node_output_slot); + output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, + mkl_node_output_slot)); + (*input_idx)++; + list_length--; } } @@ -835,9 +914,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( // node that we are constructing. An input node could be (1) 'n' // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor // if 'n' is not an Mkl layer. -void MklLayoutRewritePass::GetNodeProducingMklTensor( - std::unique_ptr* g, Node* n, int n_output_slot, Node** mkl_node, - int* mkl_node_output_slot) { +void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr* g, + Node* orig_node, Node* n, + int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { CHECK_NOTNULL(n); CHECK_NOTNULL(mkl_node); CHECK_NOTNULL(mkl_node_output_slot); @@ -860,7 +939,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor( // to create a dummy node that will feed a dummy Mkl tensor to this node. // DummyMklTensor node has no input and generates only 1 output // (dummy Mkl tensor) as output slot number 0. - GetDummyMklTensorNode(g, mkl_node, n); + GetDummyMklTensorNode(g, mkl_node, orig_node); CHECK_NOTNULL(*mkl_node); *mkl_node_output_slot = 0; } @@ -926,16 +1005,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs( if (ArgIsList(arg)) { std::vector new_node_inputs; int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N, - &new_node_inputs); + GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, + N, &new_node_inputs); nb->Input(new_node_inputs); nn_slot_idx++; } else { Node* mkl_node = nullptr; int mkl_node_output_slot = 0; - GetNodeProducingMklTensor(g, old_node_inputs[iidx].first, - old_node_inputs[iidx].second, &mkl_node, - &mkl_node_output_slot); + GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, + old_node_inputs[iidx].second, + &mkl_node, &mkl_node_output_slot); nb->Input(mkl_node, mkl_node_output_slot); iidx++; nn_slot_idx++; @@ -1020,13 +1099,30 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( TensorShape dummy_shape({1}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // same the device as the + // device of the original + // node. + .Finalize(&**g, out)); + + // If number of inputs to the original node is > 0, then we add + // control dependency between 1st input (index 0) of the original node and + // the dummy Mkl node. This is needed because control-flow ops such as Enter, + // Merge, etc, require frame_name of the dummy Mkl node to be same as the + // rewritten node. Adding control edge between 1st input of the original node + // and the dummy Mkl node ensures that the dummy node is in the same frame + // as the original node. Choosing 1st input is not necessary - any input of + // the original node is fine because all the inputs of a node are always in + // the same frame. + if (orig_node->num_inputs() > 0) { + Node* orig_input0 = nullptr; + TF_CHECK_OK(orig_node->input_node(0, + const_cast(&orig_input0))); + CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); + } + (*out)->set_assigned_device_name(orig_node->assigned_device_name()); } @@ -1235,6 +1331,19 @@ void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node, nb->Attr("T", T); } +void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + DataType Tshape; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("Tshape", Tshape); +} + void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb) { DataType T; @@ -1303,20 +1412,6 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, nb->Attr("is_training", is_training); } -void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - DataType Tshape; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("Tshape", Tshape); -} - ////////////////////////////////////////////////////////////////////////// // Helper functions related to node merge pass ////////////////////////////////////////////////////////////////////////// @@ -1353,8 +1448,9 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { continue; } + const int B_in = b->num_inputs(); gtl::InlinedVector b_control_edges; - gtl::InlinedVector, 4> b_in(N_in); + gtl::InlinedVector, 4> b_in(B_in); FillInputs(b, &b_control_edges, &b_in); // Shouldn't merge if a and b have different control edges. @@ -1438,7 +1534,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* succ, CHECK_EQ(succ->in_edges().size(), 2); Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3 int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0. - GetDummyMklTensorNode(g, &oper3_mkl, succ); // Get dummy Mkl tensor node + GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get dummy Mkl tensor node // as BiasAdd does not have Mkl tensor as input. CHECK_NOTNULL(oper3_mkl); @@ -1483,9 +1579,38 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* succ, // Set the Mkl layer label for this op. new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); + // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' + // node are already copied in BuildNode. We handle control edges now. + for (const Edge* e : pred->in_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + } + } + for (const Edge* e : succ->in_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + } + } + // Incoming edges are fixed, we will fix the outgoing edges now. + // First, we will fix outgoing control edges from 'pred' node. + // We don't need to handle outgoing data edges from 'pred' node + // because pred has only 1 output going to succ node (we enforced + // this check for merge already). + for (const Edge* e : pred->out_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + } + } + + // Second, we will fix outgoing control and data edges from 'succ' node. for (const Edge* e : succ->out_edges()) { - (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()); + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + } else { + CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(), + e->dst_input())); + } } // Copy device assigned to old node to new node. @@ -1550,18 +1675,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, "data_format or T attribute or devices of BiasAddGrad and " "Conv2D do not match. Will skip node rewrite optimization"); } + } else if (orig_node->type_string() == csinfo_.bias_add_grad && + ri->new_name == csinfo_.matmul) { + // When BiasAddGrad has MatMul in context, we do not do any rewrite + // and leave BiasAddGrad as it is. But we check for this condition + // when we check for node rewrite rule. So we should not even come + // here for MatMul. So we will fail now. + return Status( + error::Code::INVALID_ARGUMENT, + "No rewrite is required for BiasAddGrad for MatMul context."); } } // Get all inputs. - const int num = orig_node->in_edges().size(); - // Check the number of inputs against the user-specified value for non-vararg - // nodes. - if (!IsVarArgNode(orig_node)) { - CHECK_EQ(num, ri->num_ins); - } + const int num_inputs = orig_node->in_edges().size(); gtl::InlinedVector control_edges; - gtl::InlinedVector, 4> inputs(num); + gtl::InlinedVector, 4> inputs(num_inputs); FillInputs(orig_node, &control_edges, &inputs); // Build new node. We use same name as original node, but change the op name. @@ -1596,8 +1725,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, TF_CHECK_OK(nb.Finalize(&**g, &new_node)); CHECK_NOTNULL(new_node); - // Incoming edges from 'orig_node' node to new 'new_node' node are already - // copied in BuildNode. Copy outgoing edges from 'orig_node' node to new + // Incoming data edges from 'orig_node' node to new 'new_node' node are + // already copied in BuildNode. We need to handle control edges now. + for (const Edge* e : orig_node->in_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + } + } + + // Copy outgoing edges from 'orig_node' node to new // 'new_node' node, since the output also follows same ordering among // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow // tensors appropriately. Specifically, nth output of the original node @@ -1605,15 +1741,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, // of the tensors. For the contiguous ordering of the tensors, it will be n. // GetTensorDataIndex provides this mapping function. for (const Edge* e : orig_node->out_edges()) { - // We need to handle control-edges by using their original slot number. - // Generally, -1 is reserved for control slot. - if (e->src_output() < 0) { - (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()); + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); } else { - (*g)->AddEdge( - new_node, - GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), - e->dst(), e->dst_input()); + CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), + e->src()->num_outputs()), + e->dst(), e->dst_input())); } } @@ -1640,8 +1773,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, bool is_matching_cinfo_found = false; std::vector mci; for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) { - if (n->type_string() == ci->node) { - mci.push_back(&*ci); + if (n->type_string() == (*ci)->node) { + mci.push_back(*ci); is_matching_cinfo_found = true; } } @@ -1701,9 +1834,10 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, return nullptr; } -bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) { +bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n, + const ContextInfo* c) { const Node* fwd_node = nullptr; - return SearchMatchingContext(n, &fwd_node) != nullptr; + return SearchMatchingContext(n, &fwd_node) == c; } const MklLayoutRewritePass::RewriteInfo* @@ -1719,18 +1853,29 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { return nullptr; } - if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) { - return nullptr; + // BiasAddGrad is not an Mkl layer, so we make an exception for it. + if (n->type_string() != csinfo_.bias_add_grad) { + if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) { + return nullptr; + } } // We support 2 types of node rewrites: - // 1. Rewriting BiasAddGrad depending on its context. + // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context. // 2. Rewriting an op to Mkl op always // We return true if any of these 2 conditions is met. // Find matching RewriteInfo and then check that rewrite rule applies. for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { - if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { + if (n->type_string().compare(ri->name) == 0 && + ri->rewrite_rule(n, ri->context)) { + // If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context, + // then we just return directly. + if (n->type_string() == csinfo_.bias_add_grad && + ri->context->fwd == csinfo_.matmul && + ri->new_name == csinfo_.bias_add_grad) { + return nullptr; + } return &*ri; } } @@ -1753,7 +1898,8 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr* g) { GetReversePostOrder(**g, &order); // This will give us topological sort. for (Node* n : order) { - if (!n->IsOp()) { + // If node is not an op or it cannot run on CPU device, then skip. + if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { continue; } @@ -1801,18 +1947,31 @@ bool RunMklLayoutRewritePass(std::unique_ptr* g) { return MklLayoutRewritePass().RunPass(g); } -Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { - if (options.graph == nullptr) { +Status MklLayoutRewritePass::Run( + const GraphOptimizationPassOptions& options) { + if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } - // Get the ownership of graph - std::unique_ptr* g = std::move(options.graph); + auto process_graph = [&](std::unique_ptr* g) { + // Get the ownership of a graph + std::unique_ptr* ng = std::move(g); + RunPass(ng); + // Return the ownership of a graph back + g->reset(ng->release()); + }; - RunPass(g); - - // Return the ownership of graph back - options.graph->reset(g->release()); + if (kMklLayoutRewritePassGroup != + OptimizationPassRegistry::POST_PARTITIONING) { + // For any pre-partitioning phase, a graph is stored in options.graph. + process_graph(options.graph); + } else { + // For post partitioning phase, graphs are stored in + // options.partition_graphs. + for (auto& pg : *options.partition_graphs) { + process_graph(&pg.second); + } + } return Status::OK(); } diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 6e72baf84e2..3c4a5263afd 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -39,7 +39,11 @@ limitations under the License. namespace tensorflow { namespace { -static void InitGraph(const string& s, Graph* graph) { +const char kCPUDevice[] = "/job:a/replica:0/task:0/cpu:0"; +const char kGPUDevice[] = "/job:a/replica:0/task:0/gpu:0"; + +static void InitGraph(const string& s, Graph* graph, + const string& device = kCPUDevice) { GraphDef graph_def; auto parser = protobuf::TextFormat::Parser(); @@ -47,14 +51,18 @@ static void InitGraph(const string& s, Graph* graph) { CHECK(parser.MergeFromString(s, &graph_def)) << s; GraphConstructorOptions opts; TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); + + for (Node* node : graph->nodes()) { + node->set_assigned_device_name(device); + } } class MklLayoutPassTest : public ::testing::Test { public: MklLayoutPassTest() : graph_(OpRegistry::Global()) {} - void InitGraph(const string& s) { - ::tensorflow::InitGraph(s, &graph_); + void InitGraph(const string& s, const string& device = kCPUDevice) { + ::tensorflow::InitGraph(s, &graph_, device); original_ = CanonicalGraphString(&graph_); } @@ -114,7 +122,8 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); -REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful(); +REGISTER_OP("_MklInput2").Output("o: uint8") + .Output("o1: uint8").SetIsStateful(); ///////////////////////////////////////////////////////////////////// // Unit tests related to node merge optiimization @@ -162,8 +171,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) { " input: ['E', 'Y']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);" - "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;" - "DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1"); + "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;" + "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;" + "N->E:4;Y->Z:1"); } // C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved) @@ -194,8 +204,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) { " input: ['E', 'Y']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);" - "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;" - "DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1"); + "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;" + "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;" + "M:1->E:3;N:1->E:4;Y->Z:1"); } // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y); @@ -226,8 +237,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|" - "A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;" - "E->Z;Y->Z:1"); + "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" + "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;" + "DMT/_2->E:5;E->Z;Y->Z:1"); } // Graph contains only _MklConv2D, no AddBias. @@ -330,9 +342,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) { "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3"); } -// Disabling Conv2DBackpropBias test for now as we have disabled rewrite -// of BiasAddGrad into BackpropBias -#if 0 // Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias // rewrite tests @@ -361,18 +370,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) { " input: ['E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" - "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);" - "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;" - "M->D:3;N->D:4;O->D:5"); + "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);" + "N(_MklInput);O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;" + "DMT/_0->F:1;E->F;E:control->DMT/_0:control;M->D:3;N->D:4;" + "O->D:5"); } -#endif -// No _MklConv2D in context, but Conv2D in context. -// Only Conv2D would be rewritten to _MklConv2D, but no rewrite -// for BiasAddGrad should happen. +// No _MklConv2DWithBias in context, but _MklConv2D in context. +// No rewrite for BiasAddGrad should happen. // C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved) // C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous) -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) { +TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" @@ -507,8 +515,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) { "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['B', 'C'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|" - "A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); + "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);" + "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" + "DMT/_1->C:3"); } // 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will @@ -535,7 +545,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) { " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;" + "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;" + "A:control->DMT/_0:control;A:control->DMT/_1:control;" + "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;" "C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2"); } @@ -558,6 +570,50 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) { "A->C;B->C:1;B->D;C->D:1"); } +TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Int32Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Conv2DBackpropFilter'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B', 'C']}" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'D'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);" + "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|" + "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" + "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" + "DMT/_1->D:4;DMT/_2->D:5"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Int32Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Conv2DBackpropInput'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['B', 'A', 'C']}" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'D'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);" + "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|" + "A->D:1;A->E;B->D;B:control->DMT/_0:control;" + "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;" + "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); +} + // Concat Op test: Concat with no Mkl layer feeding it TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { InitGraph( @@ -572,13 +628,14 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { "node { name: 'D' op: 'Concat'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'N' value { i: 2 } }" - " input: ['A', 'B']}" + " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;" - "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); + "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" + "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); } // Concat with 2 Mkl layers feeding it @@ -616,9 +673,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" - "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;" + "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;" + "A:control->DMT/_2:control;A:control->DMT/_3:control;" + "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;" "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" - "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1"); + "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;" + "G:control->DMT/_4:control;H->I:1"); } // Concat with 1 Mkl and 1 non-Mkl layer feeding it @@ -651,12 +711,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);" - "H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;" + "H(_MklConcat);I(Mul)|A->E;A->I;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;" - "G->H;H->I:1"); + "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1"); } -#if 0 // ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) { InitGraph( @@ -676,11 +736,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) { "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;" - "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); + "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;" + "B:control->DMT/_0:control;B:control->DMT/_1:control;" + "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;" + "DMT/_1->D:4;DMT/_2->D:5"); } -#endif // ConcatV2 with 2 Mkl layers feeding it TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) { @@ -718,9 +779,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" - "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;" + "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;" + "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;" + "C:control->DMT/_0:control;C:control->DMT/_1:control;" "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" - "DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1"); + "DMT/_4->H:5;E->H;E:1->H:3;E:control->DMT/_4:control;F->H:1;" + "F:1->H:4;G->H:2;H->I:1"); } // ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it @@ -754,11 +818,175 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);" - "H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;" - "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;" + "H(_MklConcatV2);I(Mul)|A->E;A->I;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" + "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;" + "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;" "G->H:2;H->I:1"); } +TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Relu'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(_MklRelu);C(Mul);DMT/_0(Const)|A->B;A->C;" + "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'ReluGrad'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'C'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(_MklReluGrad);D(Mul);DMT/_0(Const);" + "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Relu'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A'] }" + "node { name: 'C' op: 'ReluGrad'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'C'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(_MklRelu);C(_MklReluGrad);D(Mul);DMT/_0(Const);" + "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;" + "DMT/_1->C:2"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'AvgPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(_MklAvgPool);C(Mul);DMT/_0(Const)|A->B;A->C;" + "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) { + InitGraph( + "node { name: 'A' op: 'Int32Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'AvgPoolGrad' " + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'C'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);" + "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" + "DMT/_1->C:3"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'I' op: 'Int32Input'}" + "node { name: 'B' op: 'AvgPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" + " input: ['A'] }" + "node { name: 'C' op: 'AvgPoolGrad' " + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" + " input: ['I', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'C'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);" + "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;" + "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;" + "I:control->DMT/_1:control"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Input'}" + "node { name: 'E' op: 'Input'}" + "node { name: 'F' op: 'FusedBatchNormGrad'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'epsilon' value { f: 0.0001 } }" + " attr { key: 'is_training' value { b: true } }" + " input: ['A', 'B', 'C', 'D', 'E'] }" + "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'F'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" + "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" + "F(_MklFusedBatchNormGrad);G(Mul)|A->F;A->G;" + "A:control->DMT/_0:control;A:control->DMT/_1:control;" + "A:control->DMT/_2:control;A:control->DMT/_3:control;" + "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" + "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" + "E->F:4;F->G:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Input'}" + "node { name: 'E' op: 'Input'}" + "node { name: 'F' op: 'FusedBatchNorm'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'epsilon' value { f: 0.0001 } }" + " attr { key: 'is_training' value { b: true } }" + " input: ['A', 'B', 'C', 'D', 'E'] }" + "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'F'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" + "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" + "F(_MklFusedBatchNorm);G(Mul)|A->F;A->G;" + "A:control->DMT/_0:control;A:control->DMT/_1:control;" + "A:control->DMT/_2:control;A:control->DMT/_3:control;" + "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" + "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" + "E->F:4;F->G:1"); +} + ///////////////////////////////////////////////////////////////////// // Unit tests related to rewriting node for workspace edges ///////////////////////////////////////////////////////////////////// @@ -802,13 +1030,13 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { "node { name: 'H' op: 'Input'}" "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['H', 'G'] }"); - EXPECT_EQ( - DoMklLayoutOptimizationPass(), + EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|" - "A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;" - "C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;" - "DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I"); + "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" + "I(Mul)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" + "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;" + "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;" + "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I"); } /* Test LRN->LRNGrad replacement by workspace nodes. */ @@ -838,8 +1066,9 @@ TEST_F(MklLayoutPassTest, LRN_Positive) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);E(_MklLRNGrad);F(Mul)|" - "A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;" - "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1"); + "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;" + "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" + "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1"); } /* Test LRN->LRNGrad replacement when only one of them is present. */ @@ -858,7 +1087,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) { " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|" - "A->B;A->C;B->C:1;DMT/_0->B:1"); + "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); } /* Test LRN->LRNGrad replacement when only one of them is present. */ @@ -880,8 +1109,10 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);" "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|" - "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;" - "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); + "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" + "A:control->DMT/_2:control;A:control->DMT/_3:control;" + "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" + "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); } /* Test LRN->LRNGrad negative case, where single LRN feeds @@ -920,9 +1151,13 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);" - "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;" - "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;" - "D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;" + "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;" + "A:control->DMT/_0:control;B->E:2;" + "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;" + "C:control->DMT/_1:control;C:control->DMT/_2:control;" + "C:control->DMT/_3:control;C:control->DMT/_4:control;" + "C:control->DMT/_5:control;C:control->DMT/_6:control;" + "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;" "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1"); } @@ -951,8 +1186,9 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);" "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|" - "A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;" - "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1"); + "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;" + "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" + "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1"); } // Test MaxPool>MaxPoolGrad replacement when only one of them is present. @@ -972,7 +1208,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) { " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|" - "A->B;A->C;B->C:1;DMT/_0->B:1"); + "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); } // Test MaxPoolGrad replacement when only one of them is present. @@ -995,8 +1231,374 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) { EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);" "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|" - "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;" - "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); + "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" + "A:control->DMT/_2:control;A:control->DMT/_3:control;" + "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" + "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); +} + +// Test MaxPool handling for batch-wise pooling (NCHW) +// No rewrite should take place in such case +TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +// Test MaxPool handling for batch-wise pooling (NCHW) +// No rewrite should take place in such case +TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +// Test MaxPool handling for depth-wise pooling (NHWC) +// No rewrite should take place in such case +TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +// Test MaxPool handling for depth-wise pooling (NCHW) +// No rewrite should take place in such case +TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +// Test MaxPool handling for batch-wise pooling (NHWC) +// No rewrite should take place in such case +TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NHWC' } }" + " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +// Test MaxPool handling for batch-wise pooling (NHWC) +// No rewrite should take place in such case +TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NHWC' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +// Test MaxPool handling for depth-wise pooling (NHWC) +// No rewrite should take place in such case +TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NHWC' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +// Test MaxPool handling for depth-wise pooling (NHWC) +// No rewrite should take place in such case +TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NHWC' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +///////////////////////////////////////////////////////////////////// + +// Single Conv2D Op on GPU device +// No rewrite should happen +TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Conv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B']}" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'C'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Conv2D);D(Mul)|A->C;B->C:1;B->D;C->D:1"); +} + +TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'O' op: '_MklInput'}" + "node { name: 'D' op: '_MklConv2DWithBias'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B', 'C', 'M', 'N', 'O']}" + "node { name: 'E' op: 'Sub'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['D', 'A']}" + "node { name: 'F' op: 'BiasAddGrad'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " input: ['E'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" + "E(Sub);F(BiasAddGrad);M(_MklInput);N(_MklInput);" + "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;" + "M->D:3;N->D:4;O->D:5"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Int32Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Conv2DBackpropFilter'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B', 'C']}" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'D'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Mul)|" + "A->D;A->E;B->D:1;C->D:2;D->E:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Relu'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Relu);C(Mul)|A->B;A->C;B->C:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'ReluGrad'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'C'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(ReluGrad);D(Mul)|A->C;A->D;B->C:1;C->D:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'MaxPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NHWC' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'AvgPool'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NHWC' } }" + " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'VALID' } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " input: ['A'] }" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(AvgPool);C(Mul)|A->B;A->C;B->C:1"); +} + +// Concat Op test: Concat with no Mkl layer feeding it +TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value { " + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'B' op: 'InputList'" + " attr { key: 'N' value { i: 2 } }}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Concat'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'N' value { i: 2 } }" + " input: ['A', 'B:0', 'B:1']}" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Const);B(InputList);C(Input);D(Concat);E(Mul)|A->D;" + "B->D:1;B:1->D:2;C->E;D->E:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value { " + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'B' op: 'InputList'" + " attr { key: 'N' value { i: 2 } }}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'ConcatV2'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'Tidx' value { type: DT_INT32 } }" + " attr { key: 'N' value { i: 2 } }" + " input: ['B:0', 'B:1', 'A']}" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Const);B(InputList);C(Input);D(ConcatV2);E(Mul)|" + "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); +} + +TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Input'}" + "node { name: 'E' op: 'Input'}" + "node { name: 'F' op: 'FusedBatchNorm'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'epsilon' value { f: 0.0001 } }" + " attr { key: 'is_training' value { b: true } }" + " input: ['A', 'B', 'C', 'D', 'E'] }" + "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'F'] }", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Input);D(Input);E(Input);" + "F(FusedBatchNorm);G(Mul)|A->F;A->G;B->F:1;C->F:2;D->F:3;" + "E->F:4;F->G:1"); +} + +TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'C' op: '_MklConv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B', 'M', 'N']}" + "node { name: 'D' op: 'Input'}" + "node { name: 'E' op: 'BiasAdd'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " input: ['C', 'D'] }" + "node { name: 'Y' op: 'Input'}" + "node { name: 'Z' op: 'Sub'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['E', 'Y']}", kGPUDevice); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" + "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->C;" + "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1"); } ///////////////////////////////////////////////////////////////////// diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 55c280719c3..590b3d030fa 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -98,12 +98,13 @@ class MklToTfConversionPass : public GraphOptimizationPass { Status InsertConversionNodeOnEdge(std::unique_ptr* g, Edge*); }; -// We register MklToTf insertion for phase 1 in post-partition grouping. -// We register this pass after partitioning so that we get a complete -// picture of inputs and outputs of the nodes in the graphs. +// We register MklToTf insertion for phase 2 in post-partition grouping +// because we register MklLayoutRewritePass for phase 1 in post-partition +// grouping. We register this pass after partitioning so that we get a +// complete picture of inputs and outputs of the nodes in the graphs. const OptimizationPassRegistry::Grouping kMklTfConvPassGroup = OptimizationPassRegistry::POST_PARTITIONING; -REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 1, MklToTfConversionPass); +REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass); Status MklToTfConversionPass::InsertConversionNodeOnEdge( std::unique_ptr* g, Edge* e) { @@ -121,10 +122,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge( string data_format; TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype)); - TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype)); - if (src_datatype != dst_datatype) { - string err_msg = "T attribute of " + src->name() + " and " + dst->name() + - " do not match. Will not insert" + + bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) == + Status::OK(); + // We compare source and destination datatypes only when both are found. + if (dst_dtype_found && (src_datatype != dst_datatype)) { + string err_msg = "T attribute of " + src->name() + " and " + + dst->name() + " do not match. Will not insert" + " MklToTf node in such case."; return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); } @@ -202,18 +205,19 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr* g) { << src->type_string() << " and " << dst->type_string(); // Let's get source and destination data type. - DataType src_datatype = DT_INVALID; - if (GetNodeAttr(src->def(), "T", &src_datatype) != Status::OK()) { - continue; - } // We cannot check datatype on destination node because destination node // may not be Mkl node. - DataType dst_datatype = DT_INVALID; - GetNodeAttr(dst->def(), "T", &dst_datatype); + DataType src_datatype; + DataType dst_datatype; + bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) == + Status::OK() && + IsMklSupportedOp(src->type_string(), src_datatype)); + bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) == + Status::OK() && + IsMklSupportedOp(dst->type_string(), dst_datatype)); // Check if src with is Mkl-compliant, while dst is not Mkl-compliant. - if (IsMklSupportedOp(src->type_string(), src_datatype) && - !IsMklSupportedOp(dst->type_string(), dst_datatype)) { + if (src_is_mkl_op && !dst_is_mkl_op) { VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name() << " and " << dst->name() << " for inserting conversion nodes"; candidate_edges.push_back(const_cast(e)); diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc index bd2cb0989c1..90bef111648 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc @@ -149,7 +149,7 @@ TEST_F(MklToTfConversionPass, Positive) { " input: ['C', 'D']}"); EXPECT_EQ(DoRunMklToTfConversionPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);" - "_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;" + "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;" "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3"); } else { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); @@ -172,7 +172,7 @@ TEST_F(MklToTfConversionPass, Positive) { " input: ['C', 'D']}"); EXPECT_EQ(DoRunMklToTfConversionPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);" - "_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;" + "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;" "C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3"); } } diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index 27930c44a65..094ab1c6c64 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -265,6 +265,7 @@ class MklConcatOp : public OpKernel { s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1; } mkl_context.MklCreateInputLayouts(context, input_shapes); + OP_REQUIRES_OK(context, context->status()); CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N, &mkl_context.lt_inputs[0]), @@ -316,12 +317,14 @@ class MklConcatOp : public OpKernel { mkl_context.mkl_tmp_tensors.resize(N); mkl_context.MklPrepareConcatInputs(context, input_tensors); + OP_REQUIRES_OK(context, context->status()); // Execute primitive. CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res), E_SUCCESS); mkl_context.MklCleanup(); + OP_REQUIRES_OK(context, context->status()); } private: diff --git a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc index 8a1006a8e95..d4364d31e41 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc @@ -38,9 +38,9 @@ limitations under the License. #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" +#include "tensorflow/core/util/mkl_util.h" #include "third_party/mkl/include/mkl_dnn.h" #include "third_party/mkl/include/mkl_dnn_types.h" -#include "tensorflow/core/util/mkl_util.h" namespace tensorflow { @@ -252,7 +252,7 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel { }; #define REGISTER_CPU_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias") \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 6381b527a1b..dc6b88e953a 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -37,9 +37,9 @@ limitations under the License. #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" +#include "tensorflow/core/util/mkl_util.h" #include "third_party/mkl/include/mkl_dnn.h" #include "third_party/mkl/include/mkl_dnn_types.h" -#include "tensorflow/core/util/mkl_util.h" namespace tensorflow { @@ -266,8 +266,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { int input_offsets[2]; size_t conv_strides[2]; MklShape input_shape, grad_filter_shape, out_backprop_shape; - dnnPrimitive_t prim_conv_bwdfilter, convert_bwdfilter; - dnnLayout_t lt_input, lt_grad_filter, lt_out_backprop; + dnnPrimitive_t prim_conv_bwdfilter = nullptr; + dnnPrimitive_t convert_bwdfilter = nullptr; + dnnLayout_t lt_input = nullptr; + dnnLayout_t lt_grad_filter = nullptr; + dnnLayout_t lt_out_backprop = nullptr; void* conv_res[dnnResourceNumber]; void MklCleanup() { @@ -409,7 +412,7 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { }; #define REGISTER_MKL_FILTER_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 638ce4c0243..c97f1dd7b73 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -23,8 +23,6 @@ limitations under the License. #define EIGEN_USE_THREADS #include #include -#include "third_party/mkl/include/mkl_dnn.h" -#include "third_party/mkl/include/mkl_dnn_types.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -42,6 +40,8 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" +#include "third_party/mkl/include/mkl_dnn.h" +#include "third_party/mkl/include/mkl_dnn_types.h" namespace tensorflow { @@ -342,7 +342,7 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { }; #define REGISTER_MKL_CPU_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index b818819b020..76b9f1798dd 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -36,9 +36,9 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/mkl_util.h" #include "third_party/mkl/include/mkl_dnn.h" #include "third_party/mkl/include/mkl_dnn_types.h" -#include "tensorflow/core/util/mkl_util.h" namespace tensorflow { @@ -98,19 +98,18 @@ class MklConv2DOp : public OpKernel { filter.shape().DebugString())); for (int i = 0; i < 3; i++) { - OP_REQUIRES( - context, - FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i), + std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); } const int64 input_depth = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C') : GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES(context, input_depth == filter.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", input_depth, - " vs ", filter.dim_size(2))); + OP_REQUIRES( + context, input_depth == filter.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + input_depth, " vs ", filter.dim_size(2))); // The last dimension for filter is out_depth. const int out_depth = static_cast(filter.dim_size(3)); @@ -119,10 +118,9 @@ class MklConv2DOp : public OpKernel { const int64 input_rows_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H') : GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES( - context, - FastBoundsCheck(input_rows_raw, std::numeric_limits::max()), - errors::InvalidArgument("Input rows too large")); + OP_REQUIRES(context, FastBoundsCheck(input_rows_raw, + std::numeric_limits::max()), + errors::InvalidArgument("Input rows too large")); const int input_rows = static_cast(input_rows_raw); const int filter_rows = static_cast(filter.dim_size(0)); @@ -131,10 +129,9 @@ class MklConv2DOp : public OpKernel { const int64 input_cols_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W') : GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES( - context, - FastBoundsCheck(input_cols_raw, std::numeric_limits::max()), - errors::InvalidArgument("Input cols too large")); + OP_REQUIRES(context, FastBoundsCheck(input_cols_raw, + std::numeric_limits::max()), + errors::InvalidArgument("Input cols too large")); const int input_cols = static_cast(input_cols_raw); const int filter_cols = static_cast(filter.dim_size(1)); @@ -142,10 +139,9 @@ class MklConv2DOp : public OpKernel { const int64 input_batch_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N') : GetTensorDim(input, data_format_, 'N'); - OP_REQUIRES( - context, - FastBoundsCheck(input_batch_raw, std::numeric_limits::max()), - errors::InvalidArgument("batch is too large")); + OP_REQUIRES(context, FastBoundsCheck(input_batch_raw, + std::numeric_limits::max()), + errors::InvalidArgument("batch is too large")); const int batch = static_cast(input_batch_raw); // For now we take the stride from the second and third dimensions only (we @@ -438,12 +434,12 @@ class MklConv2DOp : public OpKernel { }; #define REGISTER_MKL_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ MklConv2DOp); \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc index edca8e2553d..9d050d430ae 100644 --- a/tensorflow/core/kernels/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl_lrn_op.cc @@ -104,6 +104,15 @@ class MklLRNOp : public OpKernel { return; } + // TODO(inteltf) MKL will support depth radius not equal to 2 in the future + if (depth_radius_ != 2) { + Tensor converted_tensor = + ConvertMklToTF(context, input, mkl_context.input_shape); + mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_, + beta_, converted_tensor); + return; + } + if (input_in_mkl_format) { // MKL supports normalization over channel dimension only if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) == @@ -112,8 +121,10 @@ class MklLRNOp : public OpKernel { static_cast(mkl_context.input_shape.GetCurLayout()); workspace_enabled_ = true; } else { + Tensor converted_tensor = + ConvertMklToTF(context, input, mkl_context.input_shape); mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_, - beta_, input); + beta_, converted_tensor); return; } } @@ -267,7 +278,7 @@ class MklLRNOp : public OpKernel { } // Fallback implementation - Taken from lrn_op.cc - // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a + // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a // copy. void MklDefaultToEigen(OpKernelContext* context, int depth_radius_, float bias_, float alpha_, float beta_, @@ -378,6 +389,12 @@ class MklLRNGradOp : public OpKernel { mkl_context.MklDefaultToEigen(context); return; } + + if (depth_radius_ != 2) { + mkl_context.MklDefaultToEigen(context); + return; + } + if (ingrad_in_mkl_format || inimage_in_mkl_format) { const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format) ? &mkl_context.ingrad_shape @@ -489,14 +506,11 @@ class MklLRNGradOp : public OpKernel { MklShape ingrad_shape, inimage_shape, outimage_shape; dnnPrimitive_t lrn_bwd = nullptr; dnnPrimitive_t convert_input = nullptr; - /* dnnPrimitive_t convert_output; */ dnnLayout_t lt_input = nullptr; dnnLayout_t lt_output = nullptr; dnnLayout_t lt_bdw_input = nullptr; dnnLayout_t lt_workspace = nullptr; dnnLayout_t lt_internal_input = nullptr; - /* dnnLayout_t lt_internal_workspace; - dnnLayout_t lt_internal_output; */ void* res_lrn_bwd[dnnResourceNumber]; // prepare mkl input @@ -619,14 +633,36 @@ class MklLRNGradOp : public OpKernel { // copy. void MklDefaultToEigen(OpKernelContext* context) { // CHECK(false); - Tensor in_grads = MklGetInput(context, 0); - Tensor in_image = MklGetInput(context, 1); - Tensor out_image = MklGetInput(context, 2); + + Tensor in_grads; + Tensor in_image; + Tensor out_image; GetMklShape(context, 0, &ingrad_shape); GetMklShape(context, 1, &inimage_shape); GetMklShape(context, 2, &outimage_shape); + if (ingrad_shape.IsMklTensor()) { + in_grads = + ConvertMklToTF(context, MklGetInput(context, 0), ingrad_shape); + } else { + in_grads = MklGetInput(context, 0); + } + + if (inimage_shape.IsMklTensor()) { + in_image = + ConvertMklToTF(context, MklGetInput(context, 1), inimage_shape); + } else { + in_image = MklGetInput(context, 1); + } + + if (outimage_shape.IsMklTensor()) { + out_image = + ConvertMklToTF(context, MklGetInput(context, 2), outimage_shape); + } else { + out_image = MklGetInput(context, 2); + } + const int64 batch = static_cast(in_grads.dim_size(0)); const int64 rows = static_cast(in_grads.dim_size(1)); const int64 cols = static_cast(in_grads.dim_size(2)); diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 3ba28c13ed5..e43b75e2504 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -199,15 +199,13 @@ class MklMatMulOp : public OpKernel { } }; -#define REGISTER_CPU(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("MatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ - MklMatMulOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("MatMul").Device(DEVICE_CPU).TypeConstraint("T").Label("MKL"), \ - MklMatMulOp) +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ + MklMatMulOp); -// TODO:Consider template specialization when adding/removing additional types +// TODO(inteltf) Consider template specialization when adding/removing +// additional types TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU); diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index ba2d347d941..1e0ee258b09 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -276,11 +276,6 @@ class MklMaxPoolingGradOp : public OpKernel { mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast( static_cast(output_tensor->flat().data())); - int64 output_size = output_tensor->NumElements(); - for (int64 i = 0; i < output_size; ++i) { - (static_cast(mkl_context.pooling_res[dnnResourceDiffSrc]))[i] = 0; - } - CHECK_EQ( dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res), E_SUCCESS); diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 25c8359cc53..0c66f731410 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -16,17 +16,17 @@ limitations under the License. // See docs in ../ops/nn_ops.cc. #ifdef INTEL_MKL -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/mkl/include/mkl_dnn.h" -#include "third_party/mkl/include/mkl_dnn_types.h" #include "tensorflow/core/platform/default/logging.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/mkl/include/mkl_dnn.h" +#include "third_party/mkl/include/mkl_dnn_types.h" namespace tensorflow { @@ -194,45 +194,29 @@ class MklReluGradOp : public OpKernel { void* user_i = static_cast(const_cast(a.flat().data())); void* user_g = static_cast(const_cast(g.flat().data())); + dnnPrimitive_t cv_input_to_grad = NULL; + Tensor mkl_tmp_buf_tensor; + void* mkl_buffer_convert = nullptr; - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( - &mkl_lt_internal_grad, prim_relu_bwd, dnnResourceDiffDst), - E_SUCCESS); + // if input and grad are not in the same layout, do a conversion between + // them. + if (!dnnLayoutCompare_F32(lt_input, lt_grad)) { + AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad, + &mkl_buffer_convert); + CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad), + E_SUCCESS); - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input, - prim_relu_bwd, dnnResourceSrc), - E_SUCCESS); - - if (!dnnLayoutCompare_F32(mkl_lt_internal_grad, lt_grad)) { - AllocTmpBuffer(context, mkl_tmp_grad_buf_tensor, mkl_lt_internal_grad, - &relu_res[dnnResourceDiffDst]); - CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_grad, lt_grad, - mkl_lt_internal_grad), + CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i, + mkl_buffer_convert), E_SUCCESS); - CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_grad, user_g, - relu_res[dnnResourceDiffDst]), - E_SUCCESS); - dnnDelete_F32(cv_user_to_reluB_grad); - } else { - relu_res[dnnResourceDiffDst] = user_g; - } - - if (!dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input)) { - AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input, - &relu_res[dnnResourceSrc]); - CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_input, lt_input, - mkl_lt_internal_input), - E_SUCCESS); - CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_input, user_i, - relu_res[dnnResourceSrc]), - E_SUCCESS); - dnnDelete_F32(cv_user_to_reluB_input); + relu_res[dnnResourceSrc] = mkl_buffer_convert; + dnnDelete_F32(cv_input_to_grad); } else { relu_res[dnnResourceSrc] = user_i; } - dnnLayoutDelete_F32(mkl_lt_internal_input); - dnnLayoutDelete_F32(mkl_lt_internal_grad); + relu_res[dnnResourceDiffDst] = user_g; + } void MklCreateInputLayouts(OpKernelContext* context) { @@ -331,7 +315,7 @@ void MklReluGradOp::Compute(OpKernelContext* context) { mkl_context.MklCreateInputLayouts(context); float negative_slope = 0.0; CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL, - mkl_context.lt_grad, mkl_context.lt_input, + mkl_context.lt_grad, mkl_context.lt_grad, negative_slope), E_SUCCESS); Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor; @@ -380,12 +364,12 @@ void MklReluGradOp::Compute(OpKernelContext* context) { /* Register DNN kernels for supported operations and supported types - right now * it is only Relu and f32*/ #define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \ - REGISTER_KERNEL_BUILDER(Name("_MklRelu") \ + REGISTER_KERNEL_BUILDER(Name("_MklRelu") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ MklReluOp); \ - REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \ + REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.cc index c31ef5c2554..588d6874dd6 100644 --- a/tensorflow/core/kernels/mkl_tfconv_op.cc +++ b/tensorflow/core/kernels/mkl_tfconv_op.cc @@ -106,7 +106,7 @@ class MklToTfOp : public OpKernel { /////////////////////////////////////////////////////////// #define REGISTER_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("MklToTf") \ + REGISTER_KERNEL_BUILDER(Name("_MklToTf") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ diff --git a/tensorflow/tools/ci_build/builds/configured b/tensorflow/tools/ci_build/builds/configured index 25cb51ea7cc..9fcf9169026 100755 --- a/tensorflow/tools/ci_build/builds/configured +++ b/tensorflow/tools/ci_build/builds/configured @@ -28,6 +28,12 @@ set -e CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' ) shift 1 +# Enable support for MKL, for Linux only. +if [[ $(uname) == "Linux" ]]; then + export TF_NEED_MKL=1 + export TF_DOWNLOAD_MKL=1 +fi + # Enable support for Google Cloud Platform (GCP) export TF_NEED_GCP=1 # Enable support for HDFS diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh index 6b160bbe032..da1f2199d0d 100755 --- a/tensorflow/tools/ci_build/install/install_deb_packages.sh +++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh @@ -46,6 +46,7 @@ apt-get install -y --no-install-recommends \ git \ libcurl4-openssl-dev \ libtool \ + mlocate \ openjdk-8-jdk \ openjdk-8-jre-headless \ pkg-config \ @@ -63,6 +64,9 @@ apt-get install -y --no-install-recommends \ zip \ zlib1g-dev +# populate the database +updatedb + if [[ "$1" != "--without_cmake" ]]; then apt-get install -y --no-install-recommends \ cmake diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh index 762c5317258..d90a1b905d9 100755 --- a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh @@ -28,6 +28,7 @@ export TF_NEED_GCP=0 export TF_NEED_HDFS=0 export TF_NEED_CUDA=0 export TF_NEED_OPENCL=0 +export TF_NEED_MKL=0 export COMPUTECPP_PATH="/usr/local" export PATH="/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin" diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh index 1da5e8c2bf3..79973647c11 100755 --- a/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh +++ b/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh @@ -29,6 +29,7 @@ export PYTHON_BIN_PATH="/usr/bin/python" export TF_NEED_GCP=0 export TF_NEED_HDFS=0 export TF_NEED_OPENCL=0 +export TF_NEED_MKL=0 export COMPUTECPP_PATH="/usr/local" export PATH="/usr/local/cuda/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin" diff --git a/third_party/grpc.BUILD b/third_party/grpc.BUILD index 1d1e2222dea..d776574f79e 100644 --- a/third_party/grpc.BUILD +++ b/third_party/grpc.BUILD @@ -178,6 +178,7 @@ cc_library( ], deps = [ ], + linkopts = ["-lpthread"], ) cc_library( @@ -1787,6 +1788,7 @@ cc_library( ":grpc_unsecure", "//external:protobuf_clib", ], + linkopts = ["-lpthread"], ) cc_library( diff --git a/third_party/jemalloc.BUILD b/third_party/jemalloc.BUILD index 8ed13c51a5d..b1c639a4544 100644 --- a/third_party/jemalloc.BUILD +++ b/third_party/jemalloc.BUILD @@ -94,6 +94,9 @@ cc_library( "@%ws%//tensorflow:linux_ppc64le": [ "-lpthread", ], + "@%ws%//tensorflow:linux_x86_64": [ + "-lpthread", + ], "//conditions:default": [ ], }), diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index d5ab3262835..306ac517673 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -1695,6 +1695,7 @@ cc_library( ":demangle", "@zlib_archive//:zlib", ], + linkopts = ["-lpthread", "-ldl"], ) cc_library( diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD index 7e95ebd3551..8c86766effa 100644 --- a/third_party/mkl/BUILD +++ b/third_party/mkl/BUILD @@ -16,6 +16,7 @@ load( cc_library( name = "intel_binary_blob", srcs = if_mkl([ + "libdl.so.2", "libmklml_intel.so", "libiomp5.so", ]),