Enable MKL in configure and various bug fixes (#9580)

* relu grad and maxpooling grad fixes for perf

* Graph layout pass and conversion pass changes

This commit makes following changes:

- Enables support for ReluGrad and BiasAddGrad
- Adds support for detecting depthwise/batchwise pooling
- Adds more unit tests for Graph rewrite pass
- Improvements to handling control-flow edges
- Bug fixes

* Defaulting to Eigen when LRN depth_radius!=2

* Fixed mkl_conv_grad_filter.cc for conv_ops_tests.py

* Style fix to mkl_matmul and remove unnecessary 'MKL' label on matmul kernel

* Style fixes based on clang-format to mkl_conv_* and mkl_matmul

* Bug fixes

* Adding OP_REQUIRES_OK check in Concat

* Making some style changes

* Enabled the configuration of MKL settings

* relu grad and maxpooling grad fixes for perf

* Graph layout pass and conversion pass changes

This commit makes following changes:

- Enables support for ReluGrad and BiasAddGrad
- Adds support for detecting depthwise/batchwise pooling
- Adds more unit tests for Graph rewrite pass
- Improvements to handling control-flow edges
- Bug fixes

* Defaulting to Eigen when LRN depth_radius!=2

* Fixed mkl_conv_grad_filter.cc for conv_ops_tests.py

* Style fix to mkl_matmul and remove unnecessary 'MKL' label on matmul kernel

* Style fixes based on clang-format to mkl_conv_* and mkl_matmul

* Bug fixes

* Adding OP_REQUIRES_OK check in Concat

* Making some style changes

* Enabled the configuration of MKL settings

* Fixing graph unit tests with Mkl op name change to _Mkl; Fixed missing _ in MklToTf op

* Fixed missing libdl.so.2 in BUILD file

* Fixes for unit test build failures.

* Changes in mkl_conv_grad_filter_ops.cc for Google code style

* Fixes to remove dead code

* removed the dead code and added a TODO for mkl implementation to handle this case in the future

* Fixed buildifier sanity check error

* Adding support for google's CI automation

* Updated link to new MKL version

* Fix for missing locate command in CI

* Adding updatedb to populate the database after installing mlocate

* Fixed buildifier issue

* setting tf_need_mkl=0 in libtf files

* Added third_party/mkl/* to .gitignore

* Added third_party/eigen3/mkl_include to .gitignore

* In configured, set MKL-enabling options only for Linux.
This commit is contained in:
Vivek Rane 2017-05-04 14:02:25 -07:00 committed by Vijay Vasudevan
parent 3273cf4f4d
commit 27dd167c5f
27 changed files with 1244 additions and 407 deletions

2
.gitignore vendored
View File

@ -4,6 +4,8 @@ node_modules
/.bazelrc
/.tf_configure.bazelrc
/bazel-*
/third_party/eigen3/mkl_include
/third_party/mkl/*
/third_party/py/numpy/numpy_include
/tools/python_bin_path.sh
/tools/git/gen

87
configure vendored
View File

@ -180,25 +180,35 @@ fi
setup_python
## Set up MKL related environment settings
if false; then # Disable building with MKL for now
while [ "$TF_NEED_MKL" == "" ]; do
fromuser=""
read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
fromuser="1"
case $INPUT in
[Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;;
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
while [ "$TF_NEED_MKL" == "" ]; do
fromuser=""
read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
fromuser="1"
case $INPUT in
[Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
OSNAME=`uname -s`
OSNAME=`uname -s`
if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
fromuser=""
read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT
fromuser="1"
case $INPUT in
[Yy]* ) TF_DOWNLOAD_MKL=1;;
[Nn]* ) TF_DOWNLOAD_MKL=0;;
"" ) TF_DOWNLOAD_MKL=1;;
* ) echo "Invalid selection: " $INPUT; exit 1;;
esac
if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then
DST=`dirname $0`
ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170209.tgz
GITHUB_RELEASE_TAG=v0.5
ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz
GITHUB_RELEASE_TAG=v0.7
MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME"
if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then
wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL
@ -208,7 +218,20 @@ if false; then # Disable building with MKL for now
MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name
MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
if [ "$OSNAME" == "Linux" ]; then
else
default_mkl_path=/opt/intel/mklml
fromuser=""
read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH
fromuser="1"
if [ -z "$MKL_INSTALL_PATH" ]; then
MKL_INSTALL_PATH=$default_mkl_path
fi
# Result returned from "read" will be used unexpanded. That make "~" unuseable.
# Going through one more level of expansion to handle that.
MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
fi
if [ "$OSNAME" == "Linux" ]; then
# Full MKL configuration
MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version?
MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION?
@ -216,24 +239,33 @@ if false; then # Disable building with MKL for now
# MKL-ML configuration
MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version?
MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION?
elif [ "$OSNAME" == "Darwin" ]; then
elif [ "$OSNAME" == "Darwin" ]; then
echo "Darwin is unsupported yet";
exit 1
fi
fi
if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then
if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then
ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
else
echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist";
loc=$(locate -e libdl.so.2 | sed -n 1p)
ln -sf $loc third_party/mkl/libdl.so.2
elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then
ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
loc=$(locate -e libdl.so.2 | sed -n 1p)
ln -sf $loc third_party/mkl/libdl.so.2
else
echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists";
exit 1
fi
fi
if [ -z "$fromuser" ]; then
if [ -z "$fromuser" ]; then
exit 1
fi
fi
cat > third_party/mkl/mkl.config <<EOF
# MKL_INSTALL_PATH refers to the location of MKL root folder. The MKL header and library
@ -241,9 +273,8 @@ cat > third_party/mkl/mkl.config <<EOF
MKL_INSTALL_PATH=$MKL_INSTALL_PATH
EOF
fi # TF_NEED_MKL
################## MKL
fi # Disable building with MKL for now
fi # TF_NEED_MKL
## End MKL setup
## Set up architecture-dependent optimization flags.
if [ -z "$CC_OPT_FLAGS" ]; then

View File

@ -98,6 +98,7 @@ cc_library(
name = "simple_orc_jit",
srcs = ["simple_orc_jit.cc"],
hdrs = ["simple_orc_jit.h"],
linkopts = ["-ldl"],
deps = [
":compiler_functor",
":cpu_runtime",

View File

@ -17,6 +17,7 @@
#include <memory>
#include <vector>
#include <cmath>
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h"

View File

@ -1282,7 +1282,10 @@ cc_library(
] + tf_additional_verbs_lib_defines(),
linkopts = select({
"//tensorflow:freebsd": [],
"//conditions:default": ["-ldl"],
"//conditions:default": [
"-ldl",
"-lpthread",
],
}),
deps = tf_additional_lib_deps() + [
":lib_hash_crc32c_accelerate_internal",

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/util/mkl_util.h"
@ -280,51 +281,72 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
"_MklConv2DWithBiasBackpropBias";
csinfo_.relu = "Relu";
csinfo_.reshape = "Reshape";
csinfo_.relu_grad = "ReluGrad";
csinfo_.split = "Split";
"_MklConv2DWithBiasBackpropBias";
csinfo_.relu = "Relu";
csinfo_.relu_grad = "ReluGrad";
csinfo_.reshape = "Reshape";
csinfo_.split = "Split";
// NOTE: names are alphabetically sorted.
rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1,
CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avg_pool,
GetMklOpName(csinfo_.avg_pool),
CopyAttrsPooling, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.avg_pool_grad,
GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling,
AlwaysRewrite});
rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0,
CopyAttrsConcat, AlwaysRewrite});
rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0,
CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2,
CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsPooling, AlwaysRewrite, nullptr});
// BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending
// on if context contains Conv2D.
rinfo_.push_back({csinfo_.bias_add_grad,
csinfo_.mkl_conv2d_with_bias_backprop_bias,
CopyAttrsBiasAddGrad, ContextMatchRewrite,
&biasaddgrad_conv2dwithbias_context_});
// BiasAddGrad gets written into BiasAddGrad depending on if context
// contains MatMul.
rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul,
CopyAttrsBiasAddGrad, ContextMatchRewrite,
&biasaddgrad_matmul_context_});
rinfo_.push_back({csinfo_.concat,
GetMklOpName(csinfo_.concat),
CopyAttrsConcat, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.concatv2,
GetMklOpName(csinfo_.concatv2),
CopyAttrsConcatV2, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d,
GetMklOpName(csinfo_.conv2d),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
GetMklOpName(csinfo_.conv2d_grad_filter), 3,
CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.conv2d_grad_filter),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d_grad_input,
GetMklOpName(csinfo_.conv2d_grad_input), 3,
CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.fused_batch_norm,
GetMklOpName(csinfo_.fused_batch_norm), 5,
CopyAttrsFusedBatchNorm, AlwaysRewrite});
GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.fused_batch_norm_grad,
GetMklOpName(csinfo_.fused_batch_norm_grad), 5,
CopyAttrsFusedBatchNorm, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN,
AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3,
CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1,
CopyAttrsPooling, AlwaysRewrite});
GetMklOpName(csinfo_.fused_batch_norm_grad),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.lrn,
GetMklOpName(csinfo_.lrn),
CopyAttrsLRN, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.lrn_grad,
GetMklOpName(csinfo_.lrn_grad),
CopyAttrsLRN, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.max_pool,
GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr});
rinfo_.push_back({csinfo_.max_pool_grad,
GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling,
AlwaysRewrite});
rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
CopyAttrsRelu, AlwaysRewrite});
rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2,
CopyAttrsReshape, AlwaysRewrite});
// TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet.
GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.relu,
GetMklOpName(csinfo_.relu),
CopyAttrsRelu, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.relu_grad,
GetMklOpName(csinfo_.relu_grad),
CopyAttrsRelu, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.reshape,
GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite, nullptr});
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
@ -338,8 +360,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// maxhops in backward data-flow graph. Since input of forward nodes
// (Conv2D) directly goes to backward nodes, we do not expect the
// hop-distance would be more than few nodes.
cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
kNodeMergeContextMaxDepth});
biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
kNodeMergeContextMaxDepth};
biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
csinfo_.mkl_conv2d_with_bias,
kNodeMergeContextMaxDepth};
cinfo_.push_back(&biasaddgrad_matmul_context_);
cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
}
// Standard interface to run pass
@ -354,7 +383,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @return true, if and only if graph is mutated; false otherwise.
bool RunPass(std::unique_ptr<Graph>* g);
private:
/// Structure to specify the context information used in a node rewrite rule
typedef struct {
string node; // Name of the node to be rewritten
string fwd; // Name of the node in the forward pass that this node
// corresponds to
size_t max_hop; // Maximum number of hops the fwd is located
// from this node. If the fwd is farther than max_hop
// then we do not rewrite the node.
} ContextInfo;
/// Structure to specify the name of an original node, its new name after
/// rewrite, the number of inputs to the original node, the function to
/// be used to copy attributes for the op, and the rule (if any) which
@ -362,11 +400,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
typedef struct {
string name; // Original name of op of the node in the graph
string new_name; // New name of the op of the node in the graph
int num_ins; // The number of inputs to the original op type
// A function handler to copy attributes from an old node to a new node.
std::function<void(const Node*, NodeBuilder*)> copy_attrs;
std::function<bool(const Node*)> rewrite_rule; // A rule under which to
// rewrite this node.
// A rule under which to rewrite this node
std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule;
// ContextInfo, if any, to be used for rewrite
ContextInfo* context;
} RewriteInfo;
/// Structure to specify a forward op, a backward op, and the slot numbers
@ -393,16 +432,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string new_node; // Name of the node after merge
} MergeInfo;
/// Structure to specify the context information used in a node rewrite rule
typedef struct {
string node; // Name of the node to be rewritten
string fwd; // Name of the node in the forward pass that this node
// corresponds to
size_t max_hop; // Maximum number of hops the fwd is located
// from this node. If the fwd is farther than max_hop
// then we do not rewrite the node.
} ContextInfo;
/// Structure to store all constant strings
/// NOTE: names are alphabetically sorted.
struct {
@ -427,10 +456,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string mkl_conv2d_with_bias_backprop_bias;
string relu;
string relu_grad;
string split;
string reshape;
string split;
} csinfo_;
private:
/// Maintain info about nodes to rewrite
std::vector<RewriteInfo> rinfo_;
@ -441,7 +471,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
std::vector<MergeInfo> minfo_;
/// Maintain info about nodes to rewrite
static std::vector<ContextInfo> cinfo_;
static std::vector<ContextInfo*> cinfo_;
/// Context variables used in referencing rules
static ContextInfo biasaddgrad_matmul_context_;
static ContextInfo biasaddgrad_conv2dwithbias_context_;
/// Hash table to maintain nodes visited in the graph.
std::unordered_set<const Node*> visited_nodes_;
@ -464,19 +498,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Clear all visited nodes
inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }
// Is this a graph node that can accept variable number of inputs?
// Return true if yes, false otherwise.
//
// Concat, Split are vararg nodes.
inline bool IsVarArgNode(Node* n) {
if (n->type_string() == csinfo_.concat ||
n->type_string() == csinfo_.concatv2 ||
n->type_string() == csinfo_.split) {
return true;
}
return false;
}
// Is OpDef::ArgDef a list type? It could be N * T or list(type).
// Refer to opdef.proto for details of list type.
inline bool ArgIsList(const OpDef::ArgDef& arg) const {
@ -510,6 +531,39 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return string(kMklOpPrefix) + name;
}
// Can op represented by node 'n' run on DEVICE_CPU?
// Op can run on CPU with MKL if the runtime assigned device or the
// user requested device contains device CPU, or both are empty.
bool CanOpRunOnCPUDevice(const Node* n) {
bool result = true;
string reason;
// Substring that should be checked for in device name for CPU device.
const char* const kCPUDeviceSubStr = "cpu";
// If Op has been specifically assigned to a non-CPU device, then No.
if (!n->assigned_device_name().empty() &&
!StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) {
result = false;
reason = "Op has been assigned a runtime device that is not CPU.";
}
// If user has specifically assigned this op to a non-CPU device, then No.
if (!n->def().device().empty() &&
!StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) {
result = false;
reason = "User has assigned a device that is not CPU.";
}
if (result == false) {
VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
<< n->type_string() << ", reason: " << reason;
}
// Otherwise Yes.
return result;
}
// Return a node that can be merged with input node 'n'
//
// @return pointer to the node if we can find such a
@ -538,13 +592,46 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Default rewrite rule to be used in scenario 1 for rewrite.
// @return - true (since we want to always rewrite)
static bool AlwaysRewrite(const Node* n) { return true; }
// Rewrite rule that uses context-information for matching
static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) {
return true;
}
// Check if we are performing pooling on depth or batch. If it is, then we
// do not rewrite MaxPool node to Mkl version.
// @return - true (if it is not a depth/batch wise pooling case);
// false otherwise.
static bool NonDepthBatchWisePoolRewrite(const Node* n,
const ContextInfo* c) {
CHECK_NOTNULL(n);
string data_format_str;
TensorFormat data_format;
std::vector<int32> ksize, strides;
CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
true);
CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
// Condition that specifies non-batch-wise and non-depth-wise pooling.
if (GetTensorDim(ksize, data_format, 'N') == 1 &&
GetTensorDim(strides, data_format, 'N') == 1 &&
GetTensorDim(ksize, data_format, 'C') == 1 &&
GetTensorDim(strides, data_format, 'C') == 1) {
return true;
}
return false;
}
// Rewrite rule that uses context-information for matching,
// used in scenario 2.
//
// @input - Node 'n' for which to search for matching context
// @return - true if matching context is found; false otherwise.
static bool ContextMatchRewrite(const Node* n);
// @input - The context 'c' under which to rewrite
// @return - true if we can rewrite node under context 'c';
// false otherwise.
static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
// Helper function that searches the matching contextinfo for the node.
// Implements depth-first search in the data dependence graph for the
@ -598,6 +685,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// node that we are constructing.
//
// @input g - input graph,
// @input orig_node - Original node that we are rewriting
// @input inputs - inputs to old node that we are using for constructing
// new inputs,
// @input input_idx - the index in the 'inputs' vector pointing to the
@ -608,11 +696,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output output_nodes - the list of new nodes creating Mkl tensors
//
// @return None
void GetNodesProducingMklTensorList(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);
void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);
// Get a node that will feed an Mkl tensor to the new
// node that we are constructing. The output node could be (1) 'n'
@ -620,6 +707,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// if 'n' is not an Mkl layer.
//
// @input g - input graph,
// @input orig_node - Original node that we are rewriting,
// @input n - Node based on which we are creating Mkl node,
// @input n_output_slot - the output slot of node 'n'
// which is feeding to the node that we are constructing
@ -627,9 +715,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output mkl_node_output_slot - the slot number of mkl_node that
// will feed the tensor
// @return None
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n,
int n_output_slot, Node** mkl_node,
int* mkl_node_output_slot);
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@ -695,13 +782,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
Node* orig_node);
};
std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;
MklLayoutRewritePass::ContextInfo
MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
MklLayoutRewritePass::ContextInfo
MklLayoutRewritePass::biasaddgrad_matmul_context_;
std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
// We register Mkl rewrite pass for phase 1 in post rewrite group.
// We register Mkl rewrite pass for phase 1 in post partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1,
MklLayoutRewritePass);
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@ -737,27 +829,14 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList(
while (list_length != 0) {
CHECK_GT(list_length, 0);
CHECK_LE(*input_idx, inputs.size());
CHECK_LT(*input_idx, inputs.size());
Node* n = inputs[*input_idx].first;
int slot = inputs[*input_idx].second;
const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
// If input node 'n' is producing a list/array output at output
// slot 'slot' then we need to find out the length of that list/array.
if (ArgIsList(arg)) {
int N = GetTensorListLength(arg, n);
CHECK_LE(N, list_length);
for (int j = 0; j < N; j++) {
output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
}
(*input_idx)++;
list_length -= N;
} else {
// But if input node 'n' is just producing a single tensor at
// output slot 'slot' then we just add that single node.
output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
(*input_idx)++;
list_length--;
}
// If input node 'n' is just producing a single tensor at
// output slot 'slot' then we just add that single node.
output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
(*input_idx)++;
list_length--;
}
}
@ -775,20 +854,39 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// the same device as the
// device of the original
// node.
.Finalize(&**g, out));
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// the same device as the
// device of the original
// node.
.Finalize(&**g, out));
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
// the dummy Mkl node. This is needed because control-flow ops such as Enter,
// Merge, etc, require frame_name of the dummy Mkl node to be same as the
// rewritten node. Adding control edge between 1st input of the original node
// and the dummy Mkl node ensures that the dummy node is in the same frame
// as the original node. Choosing 1st input is not necessary - any input of
// the original node is fine because all the inputs of a node are always in
// the same frame.
if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr;
TF_CHECK_OK(orig_node->input_node(0,
const_cast<const Node**>(&orig_input0)));
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
}
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
void MklLayoutRewritePass::GetNodesProducingMklTensorList(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
Node* orig_node,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes);
@ -796,38 +894,19 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
while (list_length != 0) {
CHECK_GT(list_length, 0);
CHECK_LE(*input_idx, inputs.size());
CHECK_LT(*input_idx, inputs.size());
Node* n = inputs[*input_idx].first;
int slot = inputs[*input_idx].second;
const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
// We need to check first if the input edge is going to carry a
// single tensor or a list of tensors. If it is a list of tensors,
// then we need to create list of Mkl dummy nodes.
if (ArgIsList(arg)) {
// If input node 'n' is producing a list/array output at output
// slot 'slot' then we need to find out the length of that list/array.
int N = GetTensorListLength(arg, n);
CHECK_LE(N, list_length);
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
// If it is a list, then create a list of Mkl dummy nodes.
for (int j = 0; j < N; j++) {
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
output_nodes->push_back(
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
}
(*input_idx)++;
list_length -= N;
} else {
// If it is not a list, then create a single Mkl tensor node.
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
output_nodes->push_back(
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
(*input_idx)++;
list_length--;
}
// If 'n' is producing a single tensor, then create a single Mkl tensor
// node.
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
&mkl_node_output_slot);
output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
mkl_node_output_slot));
(*input_idx)++;
list_length--;
}
}
@ -835,9 +914,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// node that we are constructing. An input node could be (1) 'n'
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
// if 'n' is not an Mkl layer.
void MklLayoutRewritePass::GetNodeProducingMklTensor(
std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node,
int* mkl_node_output_slot) {
void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
Node* orig_node, Node* n,
int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
CHECK_NOTNULL(n);
CHECK_NOTNULL(mkl_node);
CHECK_NOTNULL(mkl_node_output_slot);
@ -860,7 +939,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(
// to create a dummy node that will feed a dummy Mkl tensor to this node.
// DummyMklTensor node has no input and generates only 1 output
// (dummy Mkl tensor) as output slot number 0.
GetDummyMklTensorNode(g, mkl_node, n);
GetDummyMklTensorNode(g, mkl_node, orig_node);
CHECK_NOTNULL(*mkl_node);
*mkl_node_output_slot = 0;
}
@ -926,16 +1005,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
&new_node_inputs);
GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
N, &new_node_inputs);
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
old_node_inputs[iidx].second, &mkl_node,
&mkl_node_output_slot);
GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
old_node_inputs[iidx].second,
&mkl_node, &mkl_node_output_slot);
nb->Input(mkl_node, mkl_node_output_slot);
iidx++;
nn_slot_idx++;
@ -1020,13 +1099,30 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// same the device as the
// device of the original
// node.
.Finalize(&**g, out));
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// same the device as the
// device of the original
// node.
.Finalize(&**g, out));
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
// the dummy Mkl node. This is needed because control-flow ops such as Enter,
// Merge, etc, require frame_name of the dummy Mkl node to be same as the
// rewritten node. Adding control edge between 1st input of the original node
// and the dummy Mkl node ensures that the dummy node is in the same frame
// as the original node. Choosing 1st input is not necessary - any input of
// the original node is fine because all the inputs of a node are always in
// the same frame.
if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr;
TF_CHECK_OK(orig_node->input_node(0,
const_cast<const Node**>(&orig_input0)));
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
}
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
@ -1235,6 +1331,19 @@ void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node,
nb->Attr("T", T);
}
void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
DataType Tshape;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
// Add attributes to new node.
nb->Attr("T", T);
nb->Attr("Tshape", Tshape);
}
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
@ -1303,20 +1412,6 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
nb->Attr("is_training", is_training);
}
void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
DataType Tshape;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
// Add attributes to new node.
nb->Attr("T", T);
nb->Attr("Tshape", Tshape);
}
//////////////////////////////////////////////////////////////////////////
// Helper functions related to node merge pass
//////////////////////////////////////////////////////////////////////////
@ -1353,8 +1448,9 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
continue;
}
const int B_in = b->num_inputs();
gtl::InlinedVector<Node*, 4> b_control_edges;
gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
FillInputs(b, &b_control_edges, &b_in);
// Shouldn't merge if a and b have different control edges.
@ -1438,7 +1534,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
CHECK_EQ(succ->in_edges().size(), 2);
Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3
int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0.
GetDummyMklTensorNode(g, &oper3_mkl, succ); // Get dummy Mkl tensor node
GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get dummy Mkl tensor node
// as BiasAdd does not have Mkl tensor as input.
CHECK_NOTNULL(oper3_mkl);
@ -1483,9 +1579,38 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
// Set the Mkl layer label for this op.
new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
// Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
// node are already copied in BuildNode. We handle control edges now.
for (const Edge* e : pred->in_edges()) {
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
}
}
for (const Edge* e : succ->in_edges()) {
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
}
}
// Incoming edges are fixed, we will fix the outgoing edges now.
// First, we will fix outgoing control edges from 'pred' node.
// We don't need to handle outgoing data edges from 'pred' node
// because pred has only 1 output going to succ node (we enforced
// this check for merge already).
for (const Edge* e : pred->out_edges()) {
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
}
}
// Second, we will fix outgoing control and data edges from 'succ' node.
for (const Edge* e : succ->out_edges()) {
(*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
} else {
CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(),
e->dst_input()));
}
}
// Copy device assigned to old node to new node.
@ -1550,18 +1675,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
"data_format or T attribute or devices of BiasAddGrad and "
"Conv2D do not match. Will skip node rewrite optimization");
}
} else if (orig_node->type_string() == csinfo_.bias_add_grad &&
ri->new_name == csinfo_.matmul) {
// When BiasAddGrad has MatMul in context, we do not do any rewrite
// and leave BiasAddGrad as it is. But we check for this condition
// when we check for node rewrite rule. So we should not even come
// here for MatMul. So we will fail now.
return Status(
error::Code::INVALID_ARGUMENT,
"No rewrite is required for BiasAddGrad for MatMul context.");
}
}
// Get all inputs.
const int num = orig_node->in_edges().size();
// Check the number of inputs against the user-specified value for non-vararg
// nodes.
if (!IsVarArgNode(orig_node)) {
CHECK_EQ(num, ri->num_ins);
}
const int num_inputs = orig_node->in_edges().size();
gtl::InlinedVector<Node*, 4> control_edges;
gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num);
gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
FillInputs(orig_node, &control_edges, &inputs);
// Build new node. We use same name as original node, but change the op name.
@ -1596,8 +1725,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
TF_CHECK_OK(nb.Finalize(&**g, &new_node));
CHECK_NOTNULL(new_node);
// Incoming edges from 'orig_node' node to new 'new_node' node are already
// copied in BuildNode. Copy outgoing edges from 'orig_node' node to new
// Incoming data edges from 'orig_node' node to new 'new_node' node are
// already copied in BuildNode. We need to handle control edges now.
for (const Edge* e : orig_node->in_edges()) {
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
}
}
// Copy outgoing edges from 'orig_node' node to new
// 'new_node' node, since the output also follows same ordering among
// Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
// tensors appropriately. Specifically, nth output of the original node
@ -1605,15 +1741,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
// of the tensors. For the contiguous ordering of the tensors, it will be n.
// GetTensorDataIndex provides this mapping function.
for (const Edge* e : orig_node->out_edges()) {
// We need to handle control-edges by using their original slot number.
// Generally, -1 is reserved for control slot.
if (e->src_output() < 0) {
(*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
if (e->IsControlEdge()) {
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
} else {
(*g)->AddEdge(
new_node,
GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
e->dst(), e->dst_input());
CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
e->src()->num_outputs()),
e->dst(), e->dst_input()));
}
}
@ -1640,8 +1773,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
bool is_matching_cinfo_found = false;
std::vector<const ContextInfo*> mci;
for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
if (n->type_string() == ci->node) {
mci.push_back(&*ci);
if (n->type_string() == (*ci)->node) {
mci.push_back(*ci);
is_matching_cinfo_found = true;
}
}
@ -1701,9 +1834,10 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
return nullptr;
}
bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) {
bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n,
const ContextInfo* c) {
const Node* fwd_node = nullptr;
return SearchMatchingContext(n, &fwd_node) != nullptr;
return SearchMatchingContext(n, &fwd_node) == c;
}
const MklLayoutRewritePass::RewriteInfo*
@ -1719,18 +1853,29 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
return nullptr;
}
if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
return nullptr;
// BiasAddGrad is not an Mkl layer, so we make an exception for it.
if (n->type_string() != csinfo_.bias_add_grad) {
if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
return nullptr;
}
}
// We support 2 types of node rewrites:
// 1. Rewriting BiasAddGrad depending on its context.
// 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context.
// 2. Rewriting an op to Mkl op always
// We return true if any of these 2 conditions is met.
// Find matching RewriteInfo and then check that rewrite rule applies.
for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
if (n->type_string().compare(ri->name) == 0 &&
ri->rewrite_rule(n, ri->context)) {
// If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context,
// then we just return directly.
if (n->type_string() == csinfo_.bias_add_grad &&
ri->context->fwd == csinfo_.matmul &&
ri->new_name == csinfo_.bias_add_grad) {
return nullptr;
}
return &*ri;
}
}
@ -1753,7 +1898,8 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
GetReversePostOrder(**g, &order); // This will give us topological sort.
for (Node* n : order) {
if (!n->IsOp()) {
// If node is not an op or it cannot run on CPU device, then skip.
if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
continue;
}
@ -1801,18 +1947,31 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
return MklLayoutRewritePass().RunPass(g);
}
Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr) {
Status MklLayoutRewritePass::Run(
const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
// Get the ownership of graph
std::unique_ptr<Graph>* g = std::move(options.graph);
auto process_graph = [&](std::unique_ptr<Graph>* g) {
// Get the ownership of a graph
std::unique_ptr<Graph>* ng = std::move(g);
RunPass(ng);
// Return the ownership of a graph back
g->reset(ng->release());
};
RunPass(g);
// Return the ownership of graph back
options.graph->reset(g->release());
if (kMklLayoutRewritePassGroup !=
OptimizationPassRegistry::POST_PARTITIONING) {
// For any pre-partitioning phase, a graph is stored in options.graph.
process_graph(options.graph);
} else {
// For post partitioning phase, graphs are stored in
// options.partition_graphs.
for (auto& pg : *options.partition_graphs) {
process_graph(&pg.second);
}
}
return Status::OK();
}

View File

@ -39,7 +39,11 @@ limitations under the License.
namespace tensorflow {
namespace {
static void InitGraph(const string& s, Graph* graph) {
const char kCPUDevice[] = "/job:a/replica:0/task:0/cpu:0";
const char kGPUDevice[] = "/job:a/replica:0/task:0/gpu:0";
static void InitGraph(const string& s, Graph* graph,
const string& device = kCPUDevice) {
GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser();
@ -47,14 +51,18 @@ static void InitGraph(const string& s, Graph* graph) {
CHECK(parser.MergeFromString(s, &graph_def)) << s;
GraphConstructorOptions opts;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(device);
}
}
class MklLayoutPassTest : public ::testing::Test {
public:
MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
void InitGraph(const string& s) {
::tensorflow::InitGraph(s, &graph_);
void InitGraph(const string& s, const string& device = kCPUDevice) {
::tensorflow::InitGraph(s, &graph_, device);
original_ = CanonicalGraphString(&graph_);
}
@ -114,7 +122,8 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
REGISTER_OP("_MklInput2").Output("o: uint8")
.Output("o1: uint8").SetIsStateful();
/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optiimization
@ -162,8 +171,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
"DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;"
"A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;"
"N->E:4;Y->Z:1");
}
// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
@ -194,8 +204,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
"M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
"DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
"M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;"
"A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;"
"M:1->E:3;N:1->E:4;Y->Z:1");
}
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
@ -226,8 +237,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
"A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
"E->Z;Y->Z:1");
"A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
"DMT/_2->E:5;E->Z;Y->Z:1");
}
// Graph contains only _MklConv2D, no AddBias.
@ -330,9 +342,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
"N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
}
// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
// of BiasAddGrad into BackpropBias
#if 0
// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
// rewrite tests
@ -361,18 +370,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
"E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
"M->D:3;N->D:4;O->D:5");
"E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);"
"N(_MklInput);O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;"
"DMT/_0->F:1;E->F;E:control->DMT/_0:control;M->D:3;N->D:4;"
"O->D:5");
}
#endif
// No _MklConv2D in context, but Conv2D in context.
// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
// for BiasAddGrad should happen.
// No _MklConv2DWithBias in context, but _MklConv2D in context.
// No rewrite for BiasAddGrad should happen.
// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
@ -507,8 +515,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
"A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
"DMT/_1->C:3");
}
// 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
@ -535,7 +545,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
"C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
@ -558,6 +570,50 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
"A->C;B->C:1;B->D;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:4;DMT/_2->D:5");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropInput'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['B', 'A', 'C']}"
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
"A->D:1;A->E;B->D;B:control->DMT/_0:control;"
"B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Concat Op test: Concat with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
InitGraph(
@ -572,13 +628,14 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
"node { name: 'D' op: 'Concat'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['A', 'B']}"
" input: ['A', 'B:0', 'B:1']}"
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
"B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Concat with 2 Mkl layers feeding it
@ -616,9 +673,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
"F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
"DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
"DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;"
"G:control->DMT/_4:control;H->I:1");
}
// Concat with 1 Mkl and 1 non-Mkl layer feeding it
@ -651,12 +711,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
"H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"H(_MklConcat);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
"G->H;H->I:1");
"G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
}
#if 0
// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
InitGraph(
@ -676,11 +736,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
"A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;"
"B:control->DMT/_0:control;B:control->DMT/_1:control;"
"B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:4;DMT/_2->D:5");
}
#endif
// ConcatV2 with 2 Mkl layers feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
@ -718,9 +779,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
"C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
"DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
"DMT/_4->H:5;E->H;E:1->H:3;E:control->DMT/_4:control;F->H:1;"
"F:1->H:4;G->H:2;H->I:1");
}
// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
@ -754,11 +818,175 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
"H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
"H(_MklConcatV2);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;"
"E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
"G->H:2;H->I:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklRelu);C(Mul);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklRelu);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
"DMT/_1->C:2");
}
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'AvgPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklAvgPool);C(Mul);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Int32Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'AvgPoolGrad' "
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
"DMT/_1->C:3");
}
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'I' op: 'Int32Input'}"
"node { name: 'B' op: 'AvgPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A'] }"
"node { name: 'C' op: 'AvgPoolGrad' "
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['I', 'B'] }"
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
"DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
"B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
"I:control->DMT/_1:control");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNormGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
"F(_MklFusedBatchNormGrad);G(Mul)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
"F(_MklFusedBatchNorm);G(Mul)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
"E->F:4;F->G:1");
}
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node for workspace edges
/////////////////////////////////////////////////////////////////////
@ -802,13 +1030,13 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['H', 'G'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
"A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
"C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
"DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
"I(Mul)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
"B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
"C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
"E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
}
/* Test LRN->LRNGrad replacement by workspace nodes. */
@ -838,8 +1066,9 @@ TEST_F(MklLayoutPassTest, LRN_Positive) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
"A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
"A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
"D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
}
/* Test LRN->LRNGrad replacement when only one of them is present. */
@ -858,7 +1087,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) {
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
"A->B;A->C;B->C:1;DMT/_0->B:1");
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
/* Test LRN->LRNGrad replacement when only one of them is present. */
@ -880,8 +1109,10 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}
/* Test LRN->LRNGrad negative case, where single LRN feeds
@ -920,9 +1151,13 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
"DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
"D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
"DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;"
"A:control->DMT/_0:control;B->E:2;"
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
"C:control->DMT/_1:control;C:control->DMT/_2:control;"
"C:control->DMT/_3:control;C:control->DMT/_4:control;"
"C:control->DMT/_5:control;C:control->DMT/_6:control;"
"D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
"DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
}
@ -951,8 +1186,9 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
"A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
"A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
"D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}
// Test MaxPool>MaxPoolGrad replacement when only one of them is present.
@ -972,7 +1208,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
"A->B;A->C;B->C:1;DMT/_0->B:1");
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
// Test MaxPoolGrad replacement when only one of them is present.
@ -995,8 +1231,374 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}
// Test MaxPool handling for batch-wise pooling (NCHW)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NCHW)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NHWC)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NCHW)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NHWC)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NHWC)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NHWC)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NHWC)
// No rewrite should take place in such case
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
/////////////////////////////////////////////////////////////////////
// Single Conv2D Op on GPU device
// No rewrite should happen
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Conv2D);D(Mul)|A->C;B->C:1;B->D;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'O' op: '_MklInput'}"
"node { name: 'D' op: '_MklConv2DWithBias'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['E'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
"E(Sub);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
"M->D:3;N->D:4;O->D:5");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Mul)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Relu);C(Mul)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(ReluGrad);D(Mul)|A->C;A->D;B->C:1;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'AvgPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(AvgPool);C(Mul)|A->B;A->C;B->C:1");
}
// Concat Op test: Concat with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Concat'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['A', 'B:0', 'B:1']}"
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(Concat);E(Mul)|A->D;"
"B->D:1;B:1->D:2;C->E;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'ConcatV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['B:0', 'B:1', 'A']}"
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(ConcatV2);E(Mul)|"
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNorm);G(Mul)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}", kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->C;"
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
/////////////////////////////////////////////////////////////////////

View File

@ -98,12 +98,13 @@ class MklToTfConversionPass : public GraphOptimizationPass {
Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);
};
// We register MklToTf insertion for phase 1 in post-partition grouping.
// We register this pass after partitioning so that we get a complete
// picture of inputs and outputs of the nodes in the graphs.
// We register MklToTf insertion for phase 2 in post-partition grouping
// because we register MklLayoutRewritePass for phase 1 in post-partition
// grouping. We register this pass after partitioning so that we get a
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 1, MklToTfConversionPass);
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
@ -121,10 +122,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
string data_format;
TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype));
if (src_datatype != dst_datatype) {
string err_msg = "T attribute of " + src->name() + " and " + dst->name() +
" do not match. Will not insert" +
bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) ==
Status::OK();
// We compare source and destination datatypes only when both are found.
if (dst_dtype_found && (src_datatype != dst_datatype)) {
string err_msg = "T attribute of " + src->name() + " and " +
dst->name() + " do not match. Will not insert" +
" MklToTf node in such case.";
return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
}
@ -202,18 +205,19 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
<< src->type_string() << " and " << dst->type_string();
// Let's get source and destination data type.
DataType src_datatype = DT_INVALID;
if (GetNodeAttr(src->def(), "T", &src_datatype) != Status::OK()) {
continue;
}
// We cannot check datatype on destination node because destination node
// may not be Mkl node.
DataType dst_datatype = DT_INVALID;
GetNodeAttr(dst->def(), "T", &dst_datatype);
DataType src_datatype;
DataType dst_datatype;
bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) ==
Status::OK() &&
IsMklSupportedOp(src->type_string(), src_datatype));
bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) ==
Status::OK() &&
IsMklSupportedOp(dst->type_string(), dst_datatype));
// Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
if (IsMklSupportedOp(src->type_string(), src_datatype) &&
!IsMklSupportedOp(dst->type_string(), dst_datatype)) {
if (src_is_mkl_op && !dst_is_mkl_op) {
VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
<< " and " << dst->name() << " for inserting conversion nodes";
candidate_edges.push_back(const_cast<Edge*>(e));

View File

@ -149,7 +149,7 @@ TEST_F(MklToTfConversionPass, Positive) {
" input: ['C', 'D']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
"Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
@ -172,7 +172,7 @@ TEST_F(MklToTfConversionPass, Positive) {
" input: ['C', 'D']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
"Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
}
}

View File

@ -265,6 +265,7 @@ class MklConcatOp : public OpKernel {
s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1;
}
mkl_context.MklCreateInputLayouts(context, input_shapes);
OP_REQUIRES_OK(context, context->status());
CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N,
&mkl_context.lt_inputs[0]),
@ -316,12 +317,14 @@ class MklConcatOp : public OpKernel {
mkl_context.mkl_tmp_tensors.resize(N);
mkl_context.MklPrepareConcatInputs(context, input_tensors);
OP_REQUIRES_OK(context, context->status());
// Execute primitive.
CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res),
E_SUCCESS);
mkl_context.MklCleanup();
OP_REQUIRES_OK(context, context->status());
}
private:

View File

@ -38,9 +38,9 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
@ -252,7 +252,7 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel {
};
#define REGISTER_CPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias") \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \

View File

@ -37,9 +37,9 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
@ -266,8 +266,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
int input_offsets[2];
size_t conv_strides[2];
MklShape input_shape, grad_filter_shape, out_backprop_shape;
dnnPrimitive_t prim_conv_bwdfilter, convert_bwdfilter;
dnnLayout_t lt_input, lt_grad_filter, lt_out_backprop;
dnnPrimitive_t prim_conv_bwdfilter = nullptr;
dnnPrimitive_t convert_bwdfilter = nullptr;
dnnLayout_t lt_input = nullptr;
dnnLayout_t lt_grad_filter = nullptr;
dnnLayout_t lt_out_backprop = nullptr;
void* conv_res[dnnResourceNumber];
void MklCleanup() {
@ -409,7 +412,7 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
};
#define REGISTER_MKL_FILTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \

View File

@ -23,8 +23,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <vector>
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@ -42,6 +40,8 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
namespace tensorflow {
@ -342,7 +342,7 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
};
#define REGISTER_MKL_CPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \

View File

@ -36,9 +36,9 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
@ -98,19 +98,18 @@ class MklConv2DOp : public OpKernel {
filter.shape().DebugString()));
for (int i = 0; i < 3; i++) {
OP_REQUIRES(
context,
FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
errors::InvalidArgument("filter too large"));
OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
std::numeric_limits<int>::max()),
errors::InvalidArgument("filter too large"));
}
const int64 input_depth =
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C')
: GetTensorDim(input, data_format_, 'C');
OP_REQUIRES(context, input_depth == filter.dim_size(2),
errors::InvalidArgument(
"input and filter must have the same depth: ", input_depth,
" vs ", filter.dim_size(2)));
OP_REQUIRES(
context, input_depth == filter.dim_size(2),
errors::InvalidArgument("input and filter must have the same depth: ",
input_depth, " vs ", filter.dim_size(2)));
// The last dimension for filter is out_depth.
const int out_depth = static_cast<int>(filter.dim_size(3));
@ -119,10 +118,9 @@ class MklConv2DOp : public OpKernel {
const int64 input_rows_raw =
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H')
: GetTensorDim(input, data_format_, 'H');
OP_REQUIRES(
context,
FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("Input rows too large"));
OP_REQUIRES(context, FastBoundsCheck(input_rows_raw,
std::numeric_limits<int>::max()),
errors::InvalidArgument("Input rows too large"));
const int input_rows = static_cast<int>(input_rows_raw);
const int filter_rows = static_cast<int>(filter.dim_size(0));
@ -131,10 +129,9 @@ class MklConv2DOp : public OpKernel {
const int64 input_cols_raw =
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W')
: GetTensorDim(input, data_format_, 'W');
OP_REQUIRES(
context,
FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("Input cols too large"));
OP_REQUIRES(context, FastBoundsCheck(input_cols_raw,
std::numeric_limits<int>::max()),
errors::InvalidArgument("Input cols too large"));
const int input_cols = static_cast<int>(input_cols_raw);
const int filter_cols = static_cast<int>(filter.dim_size(1));
@ -142,10 +139,9 @@ class MklConv2DOp : public OpKernel {
const int64 input_batch_raw =
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N')
: GetTensorDim(input, data_format_, 'N');
OP_REQUIRES(
context,
FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("batch is too large"));
OP_REQUIRES(context, FastBoundsCheck(input_batch_raw,
std::numeric_limits<int>::max()),
errors::InvalidArgument("batch is too large"));
const int batch = static_cast<int>(input_batch_raw);
// For now we take the stride from the second and third dimensions only (we
@ -438,12 +434,12 @@ class MklConv2DOp : public OpKernel {
};
#define REGISTER_MKL_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklConv2DOp<CPUDevice, T, false>); \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \

View File

@ -104,6 +104,15 @@ class MklLRNOp : public OpKernel {
return;
}
// TODO(inteltf) MKL will support depth radius not equal to 2 in the future
if (depth_radius_ != 2) {
Tensor converted_tensor =
ConvertMklToTF<T>(context, input, mkl_context.input_shape);
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
beta_, converted_tensor);
return;
}
if (input_in_mkl_format) {
// MKL supports normalization over channel dimension only
if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
@ -112,8 +121,10 @@ class MklLRNOp : public OpKernel {
static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout());
workspace_enabled_ = true;
} else {
Tensor converted_tensor =
ConvertMklToTF<T>(context, input, mkl_context.input_shape);
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
beta_, input);
beta_, converted_tensor);
return;
}
}
@ -267,7 +278,7 @@ class MklLRNOp : public OpKernel {
}
// Fallback implementation - Taken from lrn_op.cc
// TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
// TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
// copy.
void MklDefaultToEigen(OpKernelContext* context, int depth_radius_,
float bias_, float alpha_, float beta_,
@ -378,6 +389,12 @@ class MklLRNGradOp : public OpKernel {
mkl_context.MklDefaultToEigen(context);
return;
}
if (depth_radius_ != 2) {
mkl_context.MklDefaultToEigen(context);
return;
}
if (ingrad_in_mkl_format || inimage_in_mkl_format) {
const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
? &mkl_context.ingrad_shape
@ -489,14 +506,11 @@ class MklLRNGradOp : public OpKernel {
MklShape ingrad_shape, inimage_shape, outimage_shape;
dnnPrimitive_t lrn_bwd = nullptr;
dnnPrimitive_t convert_input = nullptr;
/* dnnPrimitive_t convert_output; */
dnnLayout_t lt_input = nullptr;
dnnLayout_t lt_output = nullptr;
dnnLayout_t lt_bdw_input = nullptr;
dnnLayout_t lt_workspace = nullptr;
dnnLayout_t lt_internal_input = nullptr;
/* dnnLayout_t lt_internal_workspace;
dnnLayout_t lt_internal_output; */
void* res_lrn_bwd[dnnResourceNumber];
// prepare mkl input
@ -619,14 +633,36 @@ class MklLRNGradOp : public OpKernel {
// copy.
void MklDefaultToEigen(OpKernelContext* context) {
// CHECK(false);
Tensor in_grads = MklGetInput(context, 0);
Tensor in_image = MklGetInput(context, 1);
Tensor out_image = MklGetInput(context, 2);
Tensor in_grads;
Tensor in_image;
Tensor out_image;
GetMklShape(context, 0, &ingrad_shape);
GetMklShape(context, 1, &inimage_shape);
GetMklShape(context, 2, &outimage_shape);
if (ingrad_shape.IsMklTensor()) {
in_grads =
ConvertMklToTF<T>(context, MklGetInput(context, 0), ingrad_shape);
} else {
in_grads = MklGetInput(context, 0);
}
if (inimage_shape.IsMklTensor()) {
in_image =
ConvertMklToTF<T>(context, MklGetInput(context, 1), inimage_shape);
} else {
in_image = MklGetInput(context, 1);
}
if (outimage_shape.IsMklTensor()) {
out_image =
ConvertMklToTF<T>(context, MklGetInput(context, 2), outimage_shape);
} else {
out_image = MklGetInput(context, 2);
}
const int64 batch = static_cast<int64>(in_grads.dim_size(0));
const int64 rows = static_cast<int64>(in_grads.dim_size(1));
const int64 cols = static_cast<int64>(in_grads.dim_size(2));

View File

@ -199,15 +199,13 @@ class MklMatMulOp : public OpKernel {
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("MKL"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>)
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
// TODO:Consider template specialization when adding/removing additional types
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);

View File

@ -276,11 +276,6 @@ class MklMaxPoolingGradOp : public OpKernel {
mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
static_cast<const void*>(output_tensor->flat<T>().data()));
int64 output_size = output_tensor->NumElements();
for (int64 i = 0; i < output_size; ++i) {
(static_cast<float*>(mkl_context.pooling_res[dnnResourceDiffSrc]))[i] = 0;
}
CHECK_EQ(
dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
E_SUCCESS);

View File

@ -16,17 +16,17 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
namespace tensorflow {
@ -194,45 +194,29 @@ class MklReluGradOp : public OpKernel {
void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
dnnPrimitive_t cv_input_to_grad = NULL;
Tensor mkl_tmp_buf_tensor;
void* mkl_buffer_convert = nullptr;
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
&mkl_lt_internal_grad, prim_relu_bwd, dnnResourceDiffDst),
E_SUCCESS);
// if input and grad are not in the same layout, do a conversion between
// them.
if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad,
&mkl_buffer_convert);
CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
E_SUCCESS);
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
prim_relu_bwd, dnnResourceSrc),
E_SUCCESS);
if (!dnnLayoutCompare_F32(mkl_lt_internal_grad, lt_grad)) {
AllocTmpBuffer(context, mkl_tmp_grad_buf_tensor, mkl_lt_internal_grad,
&relu_res[dnnResourceDiffDst]);
CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_grad, lt_grad,
mkl_lt_internal_grad),
CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i,
mkl_buffer_convert),
E_SUCCESS);
CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_grad, user_g,
relu_res[dnnResourceDiffDst]),
E_SUCCESS);
dnnDelete_F32(cv_user_to_reluB_grad);
} else {
relu_res[dnnResourceDiffDst] = user_g;
}
if (!dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input)) {
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
&relu_res[dnnResourceSrc]);
CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_input, lt_input,
mkl_lt_internal_input),
E_SUCCESS);
CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_input, user_i,
relu_res[dnnResourceSrc]),
E_SUCCESS);
dnnDelete_F32(cv_user_to_reluB_input);
relu_res[dnnResourceSrc] = mkl_buffer_convert;
dnnDelete_F32(cv_input_to_grad);
} else {
relu_res[dnnResourceSrc] = user_i;
}
dnnLayoutDelete_F32(mkl_lt_internal_input);
dnnLayoutDelete_F32(mkl_lt_internal_grad);
relu_res[dnnResourceDiffDst] = user_g;
}
void MklCreateInputLayouts(OpKernelContext* context) {
@ -331,7 +315,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
mkl_context.MklCreateInputLayouts(context);
float negative_slope = 0.0;
CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
mkl_context.lt_grad, mkl_context.lt_input,
mkl_context.lt_grad, mkl_context.lt_grad,
negative_slope),
E_SUCCESS);
Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
@ -380,12 +364,12 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
/* Register DNN kernels for supported operations and supported types - right now
* it is only Relu and f32*/
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
REGISTER_KERNEL_BUILDER(Name("_MklRelu") \
REGISTER_KERNEL_BUILDER(Name("_MklRelu") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklReluOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \
REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklOpLabel), \

View File

@ -106,7 +106,7 @@ class MklToTfOp : public OpKernel {
///////////////////////////////////////////////////////////
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("MklToTf") \
REGISTER_KERNEL_BUILDER(Name("_MklToTf") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \

View File

@ -28,6 +28,12 @@ set -e
CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' )
shift 1
# Enable support for MKL, for Linux only.
if [[ $(uname) == "Linux" ]]; then
export TF_NEED_MKL=1
export TF_DOWNLOAD_MKL=1
fi
# Enable support for Google Cloud Platform (GCP)
export TF_NEED_GCP=1
# Enable support for HDFS

View File

@ -46,6 +46,7 @@ apt-get install -y --no-install-recommends \
git \
libcurl4-openssl-dev \
libtool \
mlocate \
openjdk-8-jdk \
openjdk-8-jre-headless \
pkg-config \
@ -63,6 +64,9 @@ apt-get install -y --no-install-recommends \
zip \
zlib1g-dev
# populate the database
updatedb
if [[ "$1" != "--without_cmake" ]]; then
apt-get install -y --no-install-recommends \
cmake

View File

@ -28,6 +28,7 @@ export TF_NEED_GCP=0
export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
export TF_NEED_OPENCL=0
export TF_NEED_MKL=0
export COMPUTECPP_PATH="/usr/local"
export PATH="/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"

View File

@ -29,6 +29,7 @@ export PYTHON_BIN_PATH="/usr/bin/python"
export TF_NEED_GCP=0
export TF_NEED_HDFS=0
export TF_NEED_OPENCL=0
export TF_NEED_MKL=0
export COMPUTECPP_PATH="/usr/local"
export PATH="/usr/local/cuda/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"

View File

@ -178,6 +178,7 @@ cc_library(
],
deps = [
],
linkopts = ["-lpthread"],
)
cc_library(
@ -1787,6 +1788,7 @@ cc_library(
":grpc_unsecure",
"//external:protobuf_clib",
],
linkopts = ["-lpthread"],
)
cc_library(

View File

@ -94,6 +94,9 @@ cc_library(
"@%ws%//tensorflow:linux_ppc64le": [
"-lpthread",
],
"@%ws%//tensorflow:linux_x86_64": [
"-lpthread",
],
"//conditions:default": [
],
}),

View File

@ -1695,6 +1695,7 @@ cc_library(
":demangle",
"@zlib_archive//:zlib",
],
linkopts = ["-lpthread", "-ldl"],
)
cc_library(

View File

@ -16,6 +16,7 @@ load(
cc_library(
name = "intel_binary_blob",
srcs = if_mkl([
"libdl.so.2",
"libmklml_intel.so",
"libiomp5.so",
]),