Enable MKL in configure and various bug fixes (#9580)
* relu grad and maxpooling grad fixes for perf
* Graph layout pass and conversion pass changes. This commit makes the following changes:
  - Enables support for ReluGrad and BiasAddGrad
  - Adds support for detecting depthwise/batchwise pooling
  - Adds more unit tests for the graph rewrite pass
  - Improvements to handling of control-flow edges
  - Bug fixes
* Default to Eigen when LRN depth_radius != 2
* Fixed mkl_conv_grad_filter.cc for conv_ops_tests.py
* Style fix to mkl_matmul; removed unnecessary 'MKL' label on the matmul kernel
* Style fixes based on clang-format to mkl_conv_* and mkl_matmul
* Bug fixes
* Added OP_REQUIRES_OK check in Concat
* Style changes
* Enabled the configuration of MKL settings
* Fixed graph unit tests after the Mkl op name change to _Mkl; fixed missing _ in the MklToTf op
* Fixed missing libdl.so.2 in the BUILD file
* Fixes for unit-test build failures
* Changes in mkl_conv_grad_filter_ops.cc for Google code style
* Removed dead code and added a TODO for the MKL implementation to handle this case in the future
* Fixed buildifier sanity-check error
* Added support for Google's CI automation
* Updated link to the new MKL version
* Fix for missing locate command in CI
* Added updatedb to populate the database after installing mlocate
* Fixed buildifier issue
* Set tf_need_mkl=0 in libtf files
* Added third_party/mkl/* to .gitignore
* Added third_party/eigen3/mkl_include to .gitignore
* In configure, set MKL-enabling options only for Linux.
parent 3273cf4f4d
commit 27dd167c5f

.gitignore (2 changes)
@@ -4,6 +4,8 @@ node_modules
 /.bazelrc
 /.tf_configure.bazelrc
 /bazel-*
+/third_party/eigen3/mkl_include
+/third_party/mkl/*
 /third_party/py/numpy/numpy_include
 /tools/python_bin_path.sh
 /tools/git/gen
configure (87 changes)
@@ -180,25 +180,35 @@ fi
 setup_python
 
 ## Set up MKL related environment settings
-if false; then # Disable building with MKL for now
 while [ "$TF_NEED_MKL" == "" ]; do
   fromuser=""
-  read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
+  read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
   fromuser="1"
   case $INPUT in
-    [Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;;
+    [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
     [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
     "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
     * ) echo "Invalid selection: " $INPUT;;
   esac
 done
 
 OSNAME=`uname -s`
 
 if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
+  fromuser=""
+  read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT
+  fromuser="1"
+  case $INPUT in
+    [Yy]* ) TF_DOWNLOAD_MKL=1;;
+    [Nn]* ) TF_DOWNLOAD_MKL=0;;
+    "" ) TF_DOWNLOAD_MKL=1;;
+    * ) echo "Invalid selection: " $INPUT; exit 1;;
+  esac
+
+  if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then
   DST=`dirname $0`
-  ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170209.tgz
-  GITHUB_RELEASE_TAG=v0.5
+  ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz
+  GITHUB_RELEASE_TAG=v0.7
   MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME"
   if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then
     wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL
@@ -208,7 +218,20 @@ if false; then # Disable building with MKL for now
   MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name
   MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
 
-  if [ "$OSNAME" == "Linux" ]; then
+  else
+    default_mkl_path=/opt/intel/mklml
+    fromuser=""
+    read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH
+    fromuser="1"
+    if [ -z "$MKL_INSTALL_PATH" ]; then
+      MKL_INSTALL_PATH=$default_mkl_path
+    fi
+    # Result returned from "read" will be used unexpanded. That make "~" unuseable.
+    # Going through one more level of expansion to handle that.
+    MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
+  fi
+
+  if [ "$OSNAME" == "Linux" ]; then
     # Full MKL configuration
     MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version?
     MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION?
@@ -216,24 +239,33 @@ if false; then # Disable building with MKL for now
     # MKL-ML configuration
     MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version?
     MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION?
   elif [ "$OSNAME" == "Darwin" ]; then
     echo "Darwin is unsupported yet";
     exit 1
   fi
 
   if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then
     ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/
     ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/
     ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
     ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
-  else
-    echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist";
+    loc=$(locate -e libdl.so.2 | sed -n 1p)
+    ln -sf $loc third_party/mkl/libdl.so.2
+  elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then
+    ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/
+    ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/
+    ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
+    ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
+    loc=$(locate -e libdl.so.2 | sed -n 1p)
+    ln -sf $loc third_party/mkl/libdl.so.2
+  else
+    echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists";
     exit 1
   fi
 
   if [ -z "$fromuser" ]; then
     exit 1
   fi
 
 cat > third_party/mkl/mkl.config <<EOF
 # MKL_INSTALL_PATH refers to the location of MKL root folder. The MKL header and library
@@ -241,9 +273,8 @@ cat > third_party/mkl/mkl.config <<EOF
 MKL_INSTALL_PATH=$MKL_INSTALL_PATH
 EOF
 
 fi # TF_NEED_MKL
-################## MKL
-fi # Disable building with MKL for now
+## End MKL setup
 
 ## Set up architecture-dependent optimization flags.
 if [ -z "$CC_OPT_FLAGS" ]; then
@@ -98,6 +98,7 @@ cc_library(
     name = "simple_orc_jit",
     srcs = ["simple_orc_jit.cc"],
     hdrs = ["simple_orc_jit.h"],
+    linkopts = ["-ldl"],
     deps = [
         ":compiler_functor",
         ":cpu_runtime",
@@ -17,6 +17,7 @@
 
 #include <memory>
 #include <vector>
+#include <cmath>
 
 #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
 #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h"
@@ -1282,7 +1282,10 @@ cc_library(
     ] + tf_additional_verbs_lib_defines(),
     linkopts = select({
         "//tensorflow:freebsd": [],
-        "//conditions:default": ["-ldl"],
+        "//conditions:default": [
+            "-ldl",
+            "-lpthread",
+        ],
     }),
     deps = tf_additional_lib_deps() + [
         ":lib_hash_crc32c_accelerate_internal",
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/tensor_format.h"
 
 #include "tensorflow/core/graph/mkl_layout_pass.h"
 #include "tensorflow/core/util/mkl_util.h"
@@ -280,51 +281,72 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.mkl_conv2d = "_MklConv2D";
     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
     csinfo_.mkl_conv2d_with_bias_backprop_bias =
         "_MklConv2DWithBiasBackpropBias";
     csinfo_.relu = "Relu";
-    csinfo_.reshape = "Reshape";
     csinfo_.relu_grad = "ReluGrad";
+    csinfo_.reshape = "Reshape";
     csinfo_.split = "Split";
 
     // NOTE: names are alphabetically sorted.
-    rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1,
-                      CopyAttrsPooling, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.avg_pool_grad,
-                      GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling,
-                      AlwaysRewrite});
-    rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0,
-                      CopyAttrsConcat, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0,
-                      CopyAttrsConcatV2, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2,
-                      CopyAttrsConv2D, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.conv2d_grad_filter,
-                      GetMklOpName(csinfo_.conv2d_grad_filter), 3,
-                      CopyAttrsConv2D, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.conv2d_grad_input,
-                      GetMklOpName(csinfo_.conv2d_grad_input), 3,
-                      CopyAttrsConv2D, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.fused_batch_norm,
-                      GetMklOpName(csinfo_.fused_batch_norm), 5,
-                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.fused_batch_norm_grad,
-                      GetMklOpName(csinfo_.fused_batch_norm_grad), 5,
-                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN,
-                      AlwaysRewrite});
-    rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3,
-                      CopyAttrsLRN, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1,
-                      CopyAttrsPooling, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.max_pool_grad,
-                      GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling,
-                      AlwaysRewrite});
-    rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
-                      CopyAttrsRelu, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2,
-                      CopyAttrsReshape, AlwaysRewrite});
-
-    // TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet.
+    rinfo_.push_back({csinfo_.avg_pool,
+                      GetMklOpName(csinfo_.avg_pool),
+                      CopyAttrsPooling, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.avg_pool_grad,
+                      GetMklOpName(csinfo_.avg_pool_grad),
+                      CopyAttrsPooling, AlwaysRewrite, nullptr});
+    // BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending
+    // on if context contains Conv2D.
+    rinfo_.push_back({csinfo_.bias_add_grad,
+                      csinfo_.mkl_conv2d_with_bias_backprop_bias,
+                      CopyAttrsBiasAddGrad, ContextMatchRewrite,
+                      &biasaddgrad_conv2dwithbias_context_});
+    // BiasAddGrad gets written into BiasAddGrad depending on if context
+    // contains MatMul.
+    rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul,
+                      CopyAttrsBiasAddGrad, ContextMatchRewrite,
+                      &biasaddgrad_matmul_context_});
+    rinfo_.push_back({csinfo_.concat,
+                      GetMklOpName(csinfo_.concat),
+                      CopyAttrsConcat, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.concatv2,
+                      GetMklOpName(csinfo_.concatv2),
+                      CopyAttrsConcatV2, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.conv2d,
+                      GetMklOpName(csinfo_.conv2d),
+                      CopyAttrsConv2D, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.conv2d_grad_filter,
+                      GetMklOpName(csinfo_.conv2d_grad_filter),
+                      CopyAttrsConv2D, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.conv2d_grad_input,
+                      GetMklOpName(csinfo_.conv2d_grad_input),
+                      CopyAttrsConv2D, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.fused_batch_norm,
+                      GetMklOpName(csinfo_.fused_batch_norm),
+                      CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.fused_batch_norm_grad,
+                      GetMklOpName(csinfo_.fused_batch_norm_grad),
+                      CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.lrn,
+                      GetMklOpName(csinfo_.lrn),
+                      CopyAttrsLRN, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.lrn_grad,
+                      GetMklOpName(csinfo_.lrn_grad),
+                      CopyAttrsLRN, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.max_pool,
+                      GetMklOpName(csinfo_.max_pool),
+                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr});
+    rinfo_.push_back({csinfo_.max_pool_grad,
+                      GetMklOpName(csinfo_.max_pool_grad),
+                      CopyAttrsPooling, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.relu,
+                      GetMklOpName(csinfo_.relu),
+                      CopyAttrsRelu, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.relu_grad,
+                      GetMklOpName(csinfo_.relu_grad),
+                      CopyAttrsRelu, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.reshape,
+                      GetMklOpName(csinfo_.reshape),
+                      CopyAttrsReshape, AlwaysRewrite, nullptr});
 
     // Add info about which ops to add workspace edge to and the slots.
     wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
@@ -338,8 +360,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     // maxhops in backward data-flow graph. Since input of forward nodes
     // (Conv2D) directly goes to backward nodes, we do not expect the
     // hop-distance would be more than few nodes.
-    cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
-                      kNodeMergeContextMaxDepth});
+    biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
+                                   kNodeMergeContextMaxDepth};
+
+    biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
+                                           csinfo_.mkl_conv2d_with_bias,
+                                           kNodeMergeContextMaxDepth};
+
+    cinfo_.push_back(&biasaddgrad_matmul_context_);
+    cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
   }
 
   // Standard interface to run pass
@@ -354,7 +383,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // @return true, if and only if graph is mutated; false otherwise.
   bool RunPass(std::unique_ptr<Graph>* g);
 
- private:
+  /// Structure to specify the context information used in a node rewrite rule
+  typedef struct {
+    string node;     // Name of the node to be rewritten
+    string fwd;      // Name of the node in the forward pass that this node
+                     // corresponds to
+    size_t max_hop;  // Maximum number of hops the fwd is located
+                     // from this node. If the fwd is farther than max_hop
+                     // then we do not rewrite the node.
+  } ContextInfo;
+
   /// Structure to specify the name of an original node, its new name after
   /// rewrite, the number of inputs to the original node, the function to
   /// be used to copy attributes for the op, and the rule (if any) which
@@ -362,11 +400,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   typedef struct {
     string name;      // Original name of op of the node in the graph
     string new_name;  // New name of the op of the node in the graph
-    int num_ins;      // The number of inputs to the original op type
     // A function handler to copy attributes from an old node to a new node.
     std::function<void(const Node*, NodeBuilder*)> copy_attrs;
-    std::function<bool(const Node*)> rewrite_rule;  // A rule under which to
-                                                    // rewrite this node.
+    // A rule under which to rewrite this node
+    std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule;
+    // ContextInfo, if any, to be used for rewrite
+    ContextInfo* context;
   } RewriteInfo;
 
   /// Structure to specify a forward op, a backward op, and the slot numbers
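For illustration only (this block is not part of the commit): a minimal, self-contained C++ sketch of the table-driven scheme that the RewriteInfo/ContextInfo change above sets up — each entry binds an op name to its replacement plus a predicate, and the first entry whose name matches and whose predicate passes wins. The types here are simplified stand-ins, not the TensorFlow definitions.

// Simplified stand-ins for ContextInfo/RewriteInfo, illustration only.
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct ContextInfo {
  std::string node;     // op to rewrite
  std::string fwd;      // forward op that must appear in its context
  std::size_t max_hop;  // search depth bound
};

struct RewriteInfo {
  std::string name;      // original op name
  std::string new_name;  // replacement op name
  std::function<bool(const std::string&, const ContextInfo*)> rewrite_rule;
  const ContextInfo* context;  // nullptr when no context is needed
};

int main() {
  auto always = [](const std::string&, const ContextInfo*) { return true; };
  std::vector<RewriteInfo> rinfo = {
      {"Relu", "_MklRelu", always, nullptr},
      {"MaxPool", "_MklMaxPool", always, nullptr},
  };
  std::string op = "MaxPool";
  // First matching entry with a passing predicate decides the rewrite.
  for (const auto& ri : rinfo) {
    if (op == ri.name && ri.rewrite_rule(op, ri.context)) {
      std::cout << op << " -> " << ri.new_name << "\n";
      break;
    }
  }
}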
@@ -393,16 +432,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string new_node;  // Name of the node after merge
   } MergeInfo;
 
-  /// Structure to specify the context information used in a node rewrite rule
-  typedef struct {
-    string node;     // Name of the node to be rewritten
-    string fwd;      // Name of the node in the forward pass that this node
-                     // corresponds to
-    size_t max_hop;  // Maximum number of hops the fwd is located
-                     // from this node. If the fwd is farther than max_hop
-                     // then we do not rewrite the node.
-  } ContextInfo;
-
   /// Structure to store all constant strings
   /// NOTE: names are alphabetically sorted.
   struct {
@@ -427,10 +456,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string mkl_conv2d_with_bias_backprop_bias;
     string relu;
     string relu_grad;
-    string split;
     string reshape;
+    string split;
   } csinfo_;
 
+ private:
   /// Maintain info about nodes to rewrite
   std::vector<RewriteInfo> rinfo_;
 
@@ -441,7 +471,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   std::vector<MergeInfo> minfo_;
 
   /// Maintain info about nodes to rewrite
-  static std::vector<ContextInfo> cinfo_;
+  static std::vector<ContextInfo*> cinfo_;
 
+  /// Context variables used in referencing rules
+  static ContextInfo biasaddgrad_matmul_context_;
+  static ContextInfo biasaddgrad_conv2dwithbias_context_;
+
   /// Hash table to maintain nodes visited in the graph.
   std::unordered_set<const Node*> visited_nodes_;
@@ -464,19 +498,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // Clear all visited nodes
   inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }
 
-  // Is this a graph node that can accept variable number of inputs?
-  // Return true if yes, false otherwise.
-  //
-  // Concat, Split are vararg nodes.
-  inline bool IsVarArgNode(Node* n) {
-    if (n->type_string() == csinfo_.concat ||
-        n->type_string() == csinfo_.concatv2 ||
-        n->type_string() == csinfo_.split) {
-      return true;
-    }
-    return false;
-  }
-
   // Is OpDef::ArgDef a list type? It could be N * T or list(type).
   // Refer to opdef.proto for details of list type.
   inline bool ArgIsList(const OpDef::ArgDef& arg) const {
@@ -510,6 +531,39 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return string(kMklOpPrefix) + name;
   }
 
+  // Can op represented by node 'n' run on DEVICE_CPU?
+  // Op can run on CPU with MKL if the runtime assigned device or the
+  // user requested device contains device CPU, or both are empty.
+  bool CanOpRunOnCPUDevice(const Node* n) {
+    bool result = true;
+    string reason;
+
+    // Substring that should be checked for in device name for CPU device.
+    const char* const kCPUDeviceSubStr = "cpu";
+
+    // If Op has been specifically assigned to a non-CPU device, then No.
+    if (!n->assigned_device_name().empty() &&
+        !StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) {
+      result = false;
+      reason = "Op has been assigned a runtime device that is not CPU.";
+    }
+
+    // If user has specifically assigned this op to a non-CPU device, then No.
+    if (!n->def().device().empty() &&
+        !StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) {
+      result = false;
+      reason = "User has assigned a device that is not CPU.";
+    }
+
+    if (result == false) {
+      VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
+              << n->type_string() << ", reason: " << reason;
+    }
+
+    // Otherwise Yes.
+    return result;
+  }
+
   // Return a node that can be merged with input node 'n'
   //
   // @return pointer to the node if we can find such a
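For illustration only (not part of the commit): a standalone C++ sketch of the CPU-placement gate added above. A node is eligible for the Mkl rewrite only if every device string attached to it is either empty or contains "cpu"; std::string::find stands in for StringPiece::contains.

#include <iostream>
#include <string>

// Returns true when neither device string pins the op to a non-CPU device.
bool CanRunOnCPU(const std::string& assigned_device,
                 const std::string& requested_device) {
  auto non_cpu = [](const std::string& d) {
    return !d.empty() && d.find("cpu") == std::string::npos;
  };
  return !non_cpu(assigned_device) && !non_cpu(requested_device);
}

int main() {
  std::cout << CanRunOnCPU("/job:local/device:cpu:0", "") << "\n";  // 1
  std::cout << CanRunOnCPU("/device:gpu:0", "") << "\n";            // 0
}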
@@ -538,13 +592,46 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
 
   // Default rewrite rule to be used in scenario 1 for rewrite.
   // @return - true (since we want to always rewrite)
-  static bool AlwaysRewrite(const Node* n) { return true; }
-
-  // Rewrite rule that uses context-information for matching
+  static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) {
+    return true;
+  }
+
+  // Check if we are performing pooling on depth or batch. If it is, then we
+  // do not rewrite MaxPool node to Mkl version.
+  // @return - true (if it is not a depth/batch wise pooling case);
+  //           false otherwise.
+  static bool NonDepthBatchWisePoolRewrite(const Node* n,
+                                           const ContextInfo* c) {
+    CHECK_NOTNULL(n);
+
+    string data_format_str;
+    TensorFormat data_format;
+    std::vector<int32> ksize, strides;
+    CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
+    CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
+    CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
+             true);
+    CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
+
+    // Condition that specifies non-batch-wise and non-depth-wise pooling.
+    if (GetTensorDim(ksize, data_format, 'N') == 1 &&
+        GetTensorDim(strides, data_format, 'N') == 1 &&
+        GetTensorDim(ksize, data_format, 'C') == 1 &&
+        GetTensorDim(strides, data_format, 'C') == 1) {
+      return true;
+    }
+
+    return false;
+  }
+
+  // Rewrite rule that uses context-information for matching,
   // used in scenario 2.
   //
   // @input - Node 'n' for which to search for matching context
-  // @return - true if matching context is found; false otherwise.
-  static bool ContextMatchRewrite(const Node* n);
+  // @input - The context 'c' under which to rewrite
+  // @return - true if we can rewrite node under context 'c';
+  //           false otherwise.
+  static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
 
   // Helper function that searches the matching contextinfo for the node.
   // Implements depth-first search in the data dependence graph for the
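For illustration only (not part of the commit): a standalone C++ sketch of the depth/batch-wise pooling test added above, hardcoding NHWC layout instead of using TensorFlow's TensorFormat helpers. Pooling is rewritten to the Mkl op only when both the batch ('N') and channel ('C') entries of ksize and strides are 1.

#include <iostream>
#include <vector>

// NHWC layout: index 0 is 'N' (batch), index 3 is 'C' (channels).
bool NonDepthBatchWisePool(const std::vector<int>& ksize,
                           const std::vector<int>& strides) {
  return ksize[0] == 1 && strides[0] == 1 && ksize[3] == 1 && strides[3] == 1;
}

int main() {
  // Spatial pooling: eligible for the Mkl rewrite.
  std::cout << NonDepthBatchWisePool({1, 2, 2, 1}, {1, 2, 2, 1}) << "\n";  // 1
  // Depthwise pooling (C dimension pooled): falls back to the default op.
  std::cout << NonDepthBatchWisePool({1, 1, 1, 2}, {1, 1, 1, 2}) << "\n";  // 0
}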
@@ -598,6 +685,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // node that we are constructing.
   //
   // @input g - input graph,
+  // @input orig_node - Original node that we are rewriting
   // @input inputs - inputs to old node that we are using for constructing
   //                 new inputs,
   // @input input_idx - the index in the 'inputs' vector pointing to the
@@ -608,11 +696,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // @output output_nodes - the list of new nodes creating Mkl tensors
   //
   // @return None
-  void GetNodesProducingMklTensorList(
-      std::unique_ptr<Graph>* g,
-      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
-      int* input_idx, int list_length,
-      std::vector<NodeBuilder::NodeOut>* output_nodes);
+  void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
+      Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+      int* input_idx, int list_length,
+      std::vector<NodeBuilder::NodeOut>* output_nodes);
 
   // Get a node that will feed an Mkl tensor to the new
   // node that we are constructing. The output node could be (1) 'n'
@@ -620,6 +707,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // if 'n' is not an Mkl layer.
   //
   // @input g - input graph,
+  // @input orig_node - Original node that we are rewriting,
   // @input n - Node based on which we are creating Mkl node,
   // @input n_output_slot - the output slot of node 'n'
   //                        which is feeding to the node that we are constructing
@@ -627,9 +715,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // @output mkl_node_output_slot - the slot number of mkl_node that
   //                                will feed the tensor
   // @return None
-  void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n,
-                                 int n_output_slot, Node** mkl_node,
-                                 int* mkl_node_output_slot);
+  void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
+      Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
 
   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@@ -695,13 +782,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                                 Node* orig_node);
 };
 
-std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;
+MklLayoutRewritePass::ContextInfo
+    MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
+MklLayoutRewritePass::ContextInfo
+    MklLayoutRewritePass::biasaddgrad_matmul_context_;
+std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
 
-// We register Mkl rewrite pass for phase 1 in post rewrite group.
+// We register Mkl rewrite pass for phase 1 in post partitioning group.
 // We register it here so that we get a complete picture of all users of Mkl
 // nodes. Do not change the ordering of the Mkl passes.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1,
-                      MklLayoutRewritePass);
+const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
+    OptimizationPassRegistry::POST_PARTITIONING;
+REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
 
 //////////////////////////////////////////////////////////////////////////
 // Helper functions for creating new node
@@ -737,27 +829,14 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList(
 
   while (list_length != 0) {
     CHECK_GT(list_length, 0);
-    CHECK_LE(*input_idx, inputs.size());
+    CHECK_LT(*input_idx, inputs.size());
     Node* n = inputs[*input_idx].first;
     int slot = inputs[*input_idx].second;
-    const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
-    // If input node 'n' is producing a list/array output at output
-    // slot 'slot' then we need to find out the length of that list/array.
-    if (ArgIsList(arg)) {
-      int N = GetTensorListLength(arg, n);
-      CHECK_LE(N, list_length);
-      for (int j = 0; j < N; j++) {
-        output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
-      }
-      (*input_idx)++;
-      list_length -= N;
-    } else {
-      // But if input node 'n' is just producing a single tensor at
-      // output slot 'slot' then we just add that single node.
-      output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
-      (*input_idx)++;
-      list_length--;
-    }
+    // If input node 'n' is just producing a single tensor at
+    // output slot 'slot' then we just add that single node.
+    output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
+    (*input_idx)++;
+    list_length--;
   }
 }
@@ -775,20 +854,39 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
   TensorShape dummy_shape({8});
   dummy_shape.AsProto(proto.mutable_tensor_shape());
   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                   .Attr("value", proto)
                   .Attr("dtype", dt)
                   .Device(orig_node->def().device())  // We place this node on
                                                       // the same device as the
                                                       // device of the original
                                                       // node.
                   .Finalize(&**g, out));
+
+  // If number of inputs to the original node is > 0, then we add
+  // control dependency between 1st input (index 0) of the original node and
+  // the dummy Mkl node. This is needed because control-flow ops such as Enter,
+  // Merge, etc, require frame_name of the dummy Mkl node to be same as the
+  // rewritten node. Adding control edge between 1st input of the original node
+  // and the dummy Mkl node ensures that the dummy node is in the same frame
+  // as the original node. Choosing 1st input is not necessary - any input of
+  // the original node is fine because all the inputs of a node are always in
+  // the same frame.
+  if (orig_node->num_inputs() > 0) {
+    Node* orig_input0 = nullptr;
+    TF_CHECK_OK(orig_node->input_node(0,
+                                      const_cast<const Node**>(&orig_input0)));
+    CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
+  }
+
   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
 }
 
 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
     std::unique_ptr<Graph>* g,
-    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
-    int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
+    Node* orig_node,
+    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+    int* input_idx, int list_length,
+    std::vector<NodeBuilder::NodeOut>* output_nodes) {
   CHECK_LT(*input_idx, inputs.size());
   CHECK_GT(list_length, 0);
   CHECK_NOTNULL(output_nodes);
|
|||||||
|
|
||||||
while (list_length != 0) {
|
while (list_length != 0) {
|
||||||
CHECK_GT(list_length, 0);
|
CHECK_GT(list_length, 0);
|
||||||
CHECK_LE(*input_idx, inputs.size());
|
CHECK_LT(*input_idx, inputs.size());
|
||||||
Node* n = inputs[*input_idx].first;
|
Node* n = inputs[*input_idx].first;
|
||||||
int slot = inputs[*input_idx].second;
|
int slot = inputs[*input_idx].second;
|
||||||
const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
|
// If 'n' is producing a single tensor, then create a single Mkl tensor
|
||||||
// We need to check first if the input edge is going to carry a
|
// node.
|
||||||
// single tensor or a list of tensors. If it is a list of tensors,
|
Node* mkl_node = nullptr;
|
||||||
// then we need to create list of Mkl dummy nodes.
|
int mkl_node_output_slot = 0;
|
||||||
if (ArgIsList(arg)) {
|
GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
|
||||||
// If input node 'n' is producing a list/array output at output
|
&mkl_node_output_slot);
|
||||||
// slot 'slot' then we need to find out the length of that list/array.
|
output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
|
||||||
int N = GetTensorListLength(arg, n);
|
mkl_node_output_slot));
|
||||||
CHECK_LE(N, list_length);
|
(*input_idx)++;
|
||||||
Node* mkl_node = nullptr;
|
list_length--;
|
||||||
int mkl_node_output_slot = 0;
|
|
||||||
// If it is a list, then create a list of Mkl dummy nodes.
|
|
||||||
for (int j = 0; j < N; j++) {
|
|
||||||
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
|
|
||||||
output_nodes->push_back(
|
|
||||||
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
|
|
||||||
}
|
|
||||||
(*input_idx)++;
|
|
||||||
list_length -= N;
|
|
||||||
} else {
|
|
||||||
// If it is not a list, then create a single Mkl tensor node.
|
|
||||||
Node* mkl_node = nullptr;
|
|
||||||
int mkl_node_output_slot = 0;
|
|
||||||
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
|
|
||||||
output_nodes->push_back(
|
|
||||||
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
|
|
||||||
(*input_idx)++;
|
|
||||||
list_length--;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -835,9 +914,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
 // node that we are constructing. An input node could be (1) 'n'
 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
 // if 'n' is not an Mkl layer.
-void MklLayoutRewritePass::GetNodeProducingMklTensor(
-    std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node,
-    int* mkl_node_output_slot) {
+void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
+    Node* orig_node, Node* n,
+    int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
   CHECK_NOTNULL(n);
   CHECK_NOTNULL(mkl_node);
   CHECK_NOTNULL(mkl_node_output_slot);
@@ -860,7 +939,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(
   // to create a dummy node that will feed a dummy Mkl tensor to this node.
   // DummyMklTensor node has no input and generates only 1 output
   // (dummy Mkl tensor) as output slot number 0.
-  GetDummyMklTensorNode(g, mkl_node, n);
+  GetDummyMklTensorNode(g, mkl_node, orig_node);
   CHECK_NOTNULL(*mkl_node);
   *mkl_node_output_slot = 0;
 }
@@ -926,16 +1005,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
     if (ArgIsList(arg)) {
       std::vector<NodeBuilder::NodeOut> new_node_inputs;
       int N = GetTensorListLength(arg, old_node);
-      GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
-                                     &new_node_inputs);
+      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
+                                     N, &new_node_inputs);
       nb->Input(new_node_inputs);
       nn_slot_idx++;
     } else {
       Node* mkl_node = nullptr;
       int mkl_node_output_slot = 0;
-      GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
-                                old_node_inputs[iidx].second, &mkl_node,
-                                &mkl_node_output_slot);
+      GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
+                                old_node_inputs[iidx].second,
+                                &mkl_node, &mkl_node_output_slot);
       nb->Input(mkl_node, mkl_node_output_slot);
       iidx++;
       nn_slot_idx++;
@@ -1020,13 +1099,30 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
   TensorShape dummy_shape({1});
   dummy_shape.AsProto(proto.mutable_tensor_shape());
   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                   .Attr("value", proto)
                   .Attr("dtype", dt)
                   .Device(orig_node->def().device())  // We place this node on
                                                       // same the device as the
                                                       // device of the original
                                                       // node.
                   .Finalize(&**g, out));
+
+  // If number of inputs to the original node is > 0, then we add
+  // control dependency between 1st input (index 0) of the original node and
+  // the dummy Mkl node. This is needed because control-flow ops such as Enter,
+  // Merge, etc, require frame_name of the dummy Mkl node to be same as the
+  // rewritten node. Adding control edge between 1st input of the original node
+  // and the dummy Mkl node ensures that the dummy node is in the same frame
+  // as the original node. Choosing 1st input is not necessary - any input of
+  // the original node is fine because all the inputs of a node are always in
+  // the same frame.
+  if (orig_node->num_inputs() > 0) {
+    Node* orig_input0 = nullptr;
+    TF_CHECK_OK(orig_node->input_node(0,
+                                      const_cast<const Node**>(&orig_input0)));
+    CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
+  }
+
   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
 }
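For illustration only (not part of the commit): a standalone C++ sketch of the frame-propagation idea behind the control edges added in the two hunks above. The dummy Mkl constant gets a control edge from the original node's first input, so a frame assignment that walks edges places the dummy in the same execution frame as the node it feeds. The graph types below are simplified stand-ins.

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> control_inputs;
  std::string frame;  // filled in by PropagateFrame
};

// A node inherits the frame of any of its inputs; control edges count too.
void PropagateFrame(Node* n) {
  for (Node* in : n->control_inputs) {
    if (!in->frame.empty()) n->frame = in->frame;
  }
}

int main() {
  Node input0{"enter_out", {}, "loop_frame"};
  Node dummy{"DMT", {}, ""};
  dummy.control_inputs.push_back(&input0);  // the added control edge
  PropagateFrame(&dummy);
  std::cout << dummy.frame << "\n";  // loop_frame
}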
@@ -1235,6 +1331,19 @@ void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node,
   nb->Attr("T", T);
 }
 
+void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
+                                            NodeBuilder* nb) {
+  DataType T;
+  DataType Tshape;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
+  // Add attributes to new node.
+  nb->Attr("T", T);
+  nb->Attr("Tshape", Tshape);
+}
+
 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
                                           NodeBuilder* nb) {
   DataType T;
@@ -1303,20 +1412,6 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
   nb->Attr("is_training", is_training);
 }
 
-void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
-                                            NodeBuilder* nb) {
-  DataType T;
-  DataType Tshape;
-
-  // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
-
-  // Add attributes to new node.
-  nb->Attr("T", T);
-  nb->Attr("Tshape", Tshape);
-}
-
 //////////////////////////////////////////////////////////////////////////
 // Helper functions related to node merge pass
 //////////////////////////////////////////////////////////////////////////
@@ -1353,8 +1448,9 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
       continue;
     }
 
+    const int B_in = b->num_inputs();
     gtl::InlinedVector<Node*, 4> b_control_edges;
-    gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
+    gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
     FillInputs(b, &b_control_edges, &b_in);
 
     // Shouldn't merge if a and b have different control edges.
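For illustration only (not part of the commit): a minimal C++ sketch of the sizing bug fixed above. The vector holding node b's inputs was sized with node a's input count (N_in); sizing it with b's own count avoids filling or reading past the valid range when the two nodes have different arities. Types are simplified stand-ins.

#include <utility>
#include <vector>

struct Node { int num_inputs; };

std::vector<std::pair<const Node*, int>> CollectInputs(const Node& b) {
  const int B_in = b.num_inputs;                        // b's own arity
  std::vector<std::pair<const Node*, int>> b_in(B_in);  // not a's N_in
  return b_in;
}

int main() {
  Node b{3};
  return CollectInputs(b).size() == 3 ? 0 : 1;
}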
@@ -1438,7 +1534,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
   CHECK_EQ(succ->in_edges().size(), 2);
   Node* oper3_mkl = nullptr;  // Mkl tensor corresponding to oper3
   int oper3_mkl_slot = 0;     // For dummy MKL tensor node, output slot is 0.
-  GetDummyMklTensorNode(g, &oper3_mkl, succ);  // Get dummy Mkl tensor node
+  GetDummyMklTensorNode(g, &oper3_mkl, pred);  // Get dummy Mkl tensor node
   // as BiasAdd does not have Mkl tensor as input.
   CHECK_NOTNULL(oper3_mkl);
@@ -1483,9 +1579,38 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
   // Set the Mkl layer label for this op.
   new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
 
+  // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
+  // node are already copied in BuildNode. We handle control edges now.
+  for (const Edge* e : pred->in_edges()) {
+    if (e->IsControlEdge()) {
+      CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+    }
+  }
+  for (const Edge* e : succ->in_edges()) {
+    if (e->IsControlEdge()) {
+      CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+    }
+  }
+
   // Incoming edges are fixed, we will fix the outgoing edges now.
+  // First, we will fix outgoing control edges from 'pred' node.
+  // We don't need to handle outgoing data edges from 'pred' node
+  // because pred has only 1 output going to succ node (we enforced
+  // this check for merge already).
+  for (const Edge* e : pred->out_edges()) {
+    if (e->IsControlEdge()) {
+      CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+    }
+  }
+
+  // Second, we will fix outgoing control and data edges from 'succ' node.
   for (const Edge* e : succ->out_edges()) {
-    (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
+    if (e->IsControlEdge()) {
+      CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+    } else {
+      CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(),
+                                  e->dst_input()));
+    }
   }
 
   // Copy device assigned to old node to new node.
@@ -1550,18 +1675,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
           "data_format or T attribute or devices of BiasAddGrad and "
           "Conv2D do not match. Will skip node rewrite optimization");
     }
+  } else if (orig_node->type_string() == csinfo_.bias_add_grad &&
+             ri->new_name == csinfo_.matmul) {
+    // When BiasAddGrad has MatMul in context, we do not do any rewrite
+    // and leave BiasAddGrad as it is. But we check for this condition
+    // when we check for node rewrite rule. So we should not even come
+    // here for MatMul. So we will fail now.
+    return Status(
+        error::Code::INVALID_ARGUMENT,
+        "No rewrite is required for BiasAddGrad for MatMul context.");
   }
   }
 
   // Get all inputs.
-  const int num = orig_node->in_edges().size();
-  // Check the number of inputs against the user-specified value for non-vararg
-  // nodes.
-  if (!IsVarArgNode(orig_node)) {
-    CHECK_EQ(num, ri->num_ins);
-  }
+  const int num_inputs = orig_node->in_edges().size();
   gtl::InlinedVector<Node*, 4> control_edges;
-  gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num);
+  gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
   FillInputs(orig_node, &control_edges, &inputs);
 
   // Build new node. We use same name as original node, but change the op name.
@@ -1596,8 +1725,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
   CHECK_NOTNULL(new_node);
 
-  // Incoming edges from 'orig_node' node to new 'new_node' node are already
-  // copied in BuildNode. Copy outgoing edges from 'orig_node' node to new
+  // Incoming data edges from 'orig_node' node to new 'new_node' node are
+  // already copied in BuildNode. We need to handle control edges now.
+  for (const Edge* e : orig_node->in_edges()) {
+    if (e->IsControlEdge()) {
+      CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+    }
+  }
+
+  // Copy outgoing edges from 'orig_node' node to new
   // 'new_node' node, since the output also follows same ordering among
   // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
   // tensors appropriately. Specifically, nth output of the original node
@@ -1605,15 +1741,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
   // of the tensors. For the contiguous ordering of the tensors, it will be n.
   // GetTensorDataIndex provides this mapping function.
   for (const Edge* e : orig_node->out_edges()) {
-    // We need to handle control-edges by using their original slot number.
-    // Generally, -1 is reserved for control slot.
-    if (e->src_output() < 0) {
-      (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
+    if (e->IsControlEdge()) {
+      CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
     } else {
-      (*g)->AddEdge(
-          new_node,
-          GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
-          e->dst(), e->dst_input());
+      CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
+                                                               e->src()->num_outputs()),
+                                  e->dst(), e->dst_input()));
     }
   }
 
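For illustration only (not part of the commit): how the data-slot mapping used in the loop above behaves under the two tensor orderings described in the comments. The function below is a simplified stand-in written from those comments, not a copy of TensorFlow's GetTensorDataIndex.

#include <iostream>

// In interleaved ordering, data tensor n sits at slot 2*n (its Mkl metadata
// tensor follows at 2*n + 1); in contiguous ordering it stays at slot n.
int TensorDataIndex(int n, bool interleaved) {
  return interleaved ? 2 * n : n;
}

int main() {
  std::cout << TensorDataIndex(1, true) << "\n";   // 2
  std::cout << TensorDataIndex(1, false) << "\n";  // 1
}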
@@ -1640,8 +1773,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
   bool is_matching_cinfo_found = false;
   std::vector<const ContextInfo*> mci;
   for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
-    if (n->type_string() == ci->node) {
-      mci.push_back(&*ci);
+    if (n->type_string() == (*ci)->node) {
+      mci.push_back(*ci);
       is_matching_cinfo_found = true;
     }
   }
@@ -1701,9 +1834,10 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
   return nullptr;
 }

-bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) {
+bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n,
+                                               const ContextInfo* c) {
   const Node* fwd_node = nullptr;
-  return SearchMatchingContext(n, &fwd_node) != nullptr;
+  return SearchMatchingContext(n, &fwd_node) == c;
 }

 const MklLayoutRewritePass::RewriteInfo*
@@ -1719,18 +1853,29 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
     return nullptr;
   }

-  if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
-    return nullptr;
+  // BiasAddGrad is not an Mkl layer, so we make an exception for it.
+  if (n->type_string() != csinfo_.bias_add_grad) {
+    if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
+      return nullptr;
+    }
   }

   // We support 2 types of node rewrites:
-  // 1. Rewriting BiasAddGrad depending on its context.
+  // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context.
   // 2. Rewriting an op to Mkl op always
   // We return true if any of these 2 conditions is met.

   // Find matching RewriteInfo and then check that rewrite rule applies.
   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
-    if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
+    if (n->type_string().compare(ri->name) == 0 &&
+        ri->rewrite_rule(n, ri->context)) {
+      // If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context,
+      // then we just return directly.
+      if (n->type_string() == csinfo_.bias_add_grad &&
+          ri->context->fwd == csinfo_.matmul &&
+          ri->new_name == csinfo_.bias_add_grad) {
+        return nullptr;
+      }
       return &*ri;
     }
   }
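For orientation, the table consulted in CheckForNodeRewrite pairs an op name with a rewrite predicate and an optional context. A sketch of the shape of those entries, with field names mirroring the hunk; the concrete definitions live in mkl_layout_pass.cc and may differ in detail:

    class Node;  // TensorFlow graph node (tensorflow/core/graph/graph.h).
    struct ContextInfoSketch {
      string node;  // Op rewritten based on context, e.g. BiasAddGrad.
      string fwd;   // Forward op expected in its context, e.g. MklConv2DWithBias.
    };
    struct RewriteInfoSketch {
      string name;      // Original op name to match.
      string new_name;  // Op name after rewrite, e.g. _MklConv2D.
      // Predicate deciding whether this rewrite applies to node n in context c.
      bool (*rewrite_rule)(const Node* n, const ContextInfoSketch* c);
      const ContextInfoSketch* context;  // Null for unconditional rewrites.
    };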
@@ -1753,7 +1898,8 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
   GetReversePostOrder(**g, &order);  // This will give us topological sort.

   for (Node* n : order) {
-    if (!n->IsOp()) {
+    // If node is not an op or it cannot run on CPU device, then skip.
+    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
       continue;
     }

@@ -1801,18 +1947,31 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
   return MklLayoutRewritePass().RunPass(g);
 }

-Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
-  if (options.graph == nullptr) {
+Status MklLayoutRewritePass::Run(
+    const GraphOptimizationPassOptions& options) {
+  if (options.graph == nullptr && options.partition_graphs == nullptr) {
     return Status::OK();
   }

-  // Get the ownership of graph
-  std::unique_ptr<Graph>* g = std::move(options.graph);
+  auto process_graph = [&](std::unique_ptr<Graph>* g) {
+    // Get the ownership of a graph
+    std::unique_ptr<Graph>* ng = std::move(g);
+    RunPass(ng);
+    // Return the ownership of a graph back
+    g->reset(ng->release());
+  };

-  RunPass(g);
-  // Return the ownership of graph back
-  options.graph->reset(g->release());
+  if (kMklLayoutRewritePassGroup !=
+      OptimizationPassRegistry::POST_PARTITIONING) {
+    // For any pre-partitioning phase, a graph is stored in options.graph.
+    process_graph(options.graph);
+  } else {
+    // For post partitioning phase, graphs are stored in
+    // options.partition_graphs.
+    for (auto& pg : *options.partition_graphs) {
+      process_graph(&pg.second);
+    }
+  }

   return Status::OK();
 }
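The pattern adopted above generalizes to any pass that can run before or after partitioning: pre-partitioning passes receive one whole graph in options.graph, while post-partitioning passes receive one graph per partition in options.partition_graphs. A self-contained sketch of that dispatch (ExamplePass and Rewrite are hypothetical names):

    class ExamplePass : public GraphOptimizationPass {
     public:
      Status Run(const GraphOptimizationPassOptions& options) override {
        if (options.graph != nullptr) {
          Rewrite(options.graph);  // Pre-partitioning: the whole graph.
        } else if (options.partition_graphs != nullptr) {
          for (auto& pg : *options.partition_graphs) {
            Rewrite(&pg.second);  // Post-partitioning: one graph per partition.
          }
        }
        return Status::OK();
      }

     private:
      void Rewrite(std::unique_ptr<Graph>* g) { /* pass-specific logic */ }
    };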
@@ -39,7 +39,11 @@ limitations under the License.
 namespace tensorflow {
 namespace {

-static void InitGraph(const string& s, Graph* graph) {
+const char kCPUDevice[] = "/job:a/replica:0/task:0/cpu:0";
+const char kGPUDevice[] = "/job:a/replica:0/task:0/gpu:0";
+
+static void InitGraph(const string& s, Graph* graph,
+                      const string& device = kCPUDevice) {
   GraphDef graph_def;

   auto parser = protobuf::TextFormat::Parser();
@@ -47,14 +51,18 @@ static void InitGraph(const string& s, Graph* graph) {
   CHECK(parser.MergeFromString(s, &graph_def)) << s;
   GraphConstructorOptions opts;
   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));

+  for (Node* node : graph->nodes()) {
+    node->set_assigned_device_name(device);
+  }
 }

 class MklLayoutPassTest : public ::testing::Test {
  public:
   MklLayoutPassTest() : graph_(OpRegistry::Global()) {}

-  void InitGraph(const string& s) {
-    ::tensorflow::InitGraph(s, &graph_);
+  void InitGraph(const string& s, const string& device = kCPUDevice) {
+    ::tensorflow::InitGraph(s, &graph_, device);
     original_ = CanonicalGraphString(&graph_);
   }

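Assigning a device name to every test node matters because RunPass (see the hunk above) now skips nodes that cannot run on CPU. A plausible body for that CanOpRunOnCPUDevice check, purely as a sketch — the real implementation is not shown in this diff:

    bool CanOpRunOnCPUDeviceSketch(const Node* n) {
      const string& d = n->assigned_device_name();
      // Rewrite only nodes that are unassigned or explicitly placed on CPU,
      // matching device strings such as "/job:a/replica:0/task:0/cpu:0".
      return d.empty() || d.find("cpu:") != string::npos;
    }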
@@ -114,7 +122,8 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
 REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
 REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
+REGISTER_OP("_MklInput2").Output("o: uint8")
+    .Output("o1: uint8").SetIsStateful();

 /////////////////////////////////////////////////////////////////////
 // Unit tests related to node merge optiimization
@@ -162,8 +171,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
            " input: ['E', 'Y']}");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
-            "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
-            "DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
+            "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;"
+            "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;"
+            "N->E:4;Y->Z:1");
 }

 // C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
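A note on reading the expected strings in these tests: CanonicalGraphString lists nodes as Name(Op) before the '|', then edges as src[:out_slot]->dst[:in_slot], where ':control' marks a control edge. The new 'A:control->DMT/_0:control' entries therefore assert that the rewrite wires a control edge from an input to each dummy Mkl tensor (DMT/*) it creates. Decoding one expectation as a worked example:

    // "A(Input);B(_MklRelu);C(Mul);DMT/_0(Const)|A->B;A->C;"
    // "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"
    //
    //   A->B                         data edge, A output 0 to B input 0
    //   B->C:1                       data edge, B output 0 to C input 1
    //   DMT/_0->B:1                  dummy Mkl metadata const feeds B slot 1
    //   A:control->DMT/_0:control    control edge keeping the const with A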
@@ -194,8 +204,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
            " input: ['E', 'Y']}");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
-            "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
-            "DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
+            "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;"
+            "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;"
+            "M:1->E:3;N:1->E:4;Y->Z:1");
 }

 // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
@@ -226,8 +237,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
             "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
-            "A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
-            "E->Z;Y->Z:1");
+            "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
+            "DMT/_2->E:5;E->Z;Y->Z:1");
 }

 // Graph contains only _MklConv2D, no AddBias.
@@ -330,9 +342,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
            "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
 }

-// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
-// of BiasAddGrad into BackpropBias
-#if 0
 // Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
 // rewrite tests

@@ -361,18 +370,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
            " input: ['E'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
-            "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
-            "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
-            "M->D:3;N->D:4;O->D:5");
+            "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);"
+            "N(_MklInput);O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;"
+            "DMT/_0->F:1;E->F;E:control->DMT/_0:control;M->D:3;N->D:4;"
+            "O->D:5");
 }
-#endif

-// No _MklConv2D in context, but Conv2D in context.
-// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
-// for BiasAddGrad should happen.
+// No _MklConv2DWithBias in context, but _MklConv2D in context.
+// No rewrite for BiasAddGrad should happen.
 // C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
 // C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
       "node { name: 'B' op: 'Input'}"
@@ -507,8 +515,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
       "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['B', 'C'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
-            "A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
+            "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);"
+            "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
+            "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
+            "DMT/_1->C:3");
 }

 // 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
@@ -535,7 +545,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
       " input: ['C', 'D'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
-            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
+            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;"
+            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
             "C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
 }

@@ -558,6 +570,50 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
             "A->C;B->C:1;B->D;C->D:1");
 }

+TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Int32Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Conv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'B', 'C']}"
+      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'D'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
+            "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
+            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
+            "DMT/_1->D:4;DMT/_2->D:5");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Int32Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Conv2DBackpropInput'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['B', 'A', 'C']}"
+      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'D'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
+            "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
+            "A->D:1;A->E;B->D;B:control->DMT/_0:control;"
+            "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
+            "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
 // Concat Op test: Concat with no Mkl layer feeding it
 TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
   InitGraph(
@@ -572,13 +628,14 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
       "node { name: 'D' op: 'Concat'"
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " attr { key: 'N' value { i: 2 } }"
-      " input: ['A', 'B']}"
+      " input: ['A', 'B:0', 'B:1']}"
       "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['C', 'D'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
-            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
-            "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;A:control->DMT/_0:control;"
+            "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
+            "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
 }

 // Concat with 2 Mkl layers feeding it
@@ -616,9 +673,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
             "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
-            "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
+            "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;"
+            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+            "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
             "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
-            "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
+            "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;"
+            "G:control->DMT/_4:control;H->I:1");
 }

 // Concat with 1 Mkl and 1 non-Mkl layer feeding it
@@ -651,12 +711,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
             "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
-            "H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
+            "H(_MklConcat);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
+            "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
             "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
-            "G->H;H->I:1");
+            "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
 }

-#if 0
 // ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
   InitGraph(
@@ -676,11 +736,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
       "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['C', 'D'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
-            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
-            "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+            "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;"
+            "B:control->DMT/_0:control;B:control->DMT/_1:control;"
+            "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
+            "DMT/_1->D:4;DMT/_2->D:5");
 }
-#endif

 // ConcatV2 with 2 Mkl layers feeding it
 TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
@@ -718,9 +779,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
             "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
-            "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
+            "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;"
+            "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
+            "C:control->DMT/_0:control;C:control->DMT/_1:control;"
             "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
-            "DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
+            "DMT/_4->H:5;E->H;E:1->H:3;E:control->DMT/_4:control;F->H:1;"
+            "F:1->H:4;G->H:2;H->I:1");
 }

 // ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
@@ -754,11 +818,175 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
             "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
-            "H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
-            "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
+            "H(_MklConcatV2);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
+            "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
+            "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;"
+            "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
             "G->H:2;H->I:1");
 }

+TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Relu'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(_MklRelu);C(Mul);DMT/_0(Const)|A->B;A->C;"
+            "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'ReluGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B'] }"
+      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
+            "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
+            "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Relu'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'ReluGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B'] }"
+      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(_MklRelu);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
+            "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
+            "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
+            "DMT/_1->C:2");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'AvgPool'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+      " attr { key: 'padding' value { s: 'VALID' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(_MklAvgPool);C(Mul);DMT/_0(Const)|A->B;A->C;"
+            "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Int32Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'AvgPoolGrad' "
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+      " attr { key: 'padding' value { s: 'VALID' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+      " input: ['A', 'B'] }"
+      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['B', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
+            "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
+            "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
+            "DMT/_1->C:3");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'I' op: 'Int32Input'}"
+      "node { name: 'B' op: 'AvgPool'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+      " attr { key: 'padding' value { s: 'VALID' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'AvgPoolGrad' "
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+      " attr { key: 'padding' value { s: 'VALID' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+      " input: ['I', 'B'] }"
+      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
+            "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
+            "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
+            "I:control->DMT/_1:control");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'Input'}"
+      "node { name: 'F' op: 'FusedBatchNormGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'epsilon' value { f: 0.0001 } }"
+      " attr { key: 'is_training' value { b: true } }"
+      " input: ['A', 'B', 'C', 'D', 'E'] }"
+      "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'F'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
+            "F(_MklFusedBatchNormGrad);G(Mul)|A->F;A->G;"
+            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+            "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
+            "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
+            "E->F:4;F->G:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'Input'}"
+      "node { name: 'F' op: 'FusedBatchNorm'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'epsilon' value { f: 0.0001 } }"
+      " attr { key: 'is_training' value { b: true } }"
+      " input: ['A', 'B', 'C', 'D', 'E'] }"
+      "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'F'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
+            "F(_MklFusedBatchNorm);G(Mul)|A->F;A->G;"
+            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+            "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
+            "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
+            "E->F:4;F->G:1");
+}
+
 /////////////////////////////////////////////////////////////////////
 // Unit tests related to rewriting node for workspace edges
 /////////////////////////////////////////////////////////////////////
@@ -802,13 +1030,13 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
       "node { name: 'H' op: 'Input'}"
       "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['H', 'G'] }");
-  EXPECT_EQ(
-      DoMklLayoutOptimizationPass(),
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
       "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
-      "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
-      "A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
-      "C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
-      "DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
+      "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
+      "I(Mul)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
+      "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
+      "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
+      "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
 }

 /* Test LRN->LRNGrad replacement by workspace nodes. */
@@ -838,8 +1066,9 @@ TEST_F(MklLayoutPassTest, LRN_Positive) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
             "DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
-            "A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
-            "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
+            "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
+            "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
+            "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
 }

 /* Test LRN->LRNGrad replacement when only one of them is present. */
@@ -858,7 +1087,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) {
       " input: ['A', 'B'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
-            "A->B;A->C;B->C:1;DMT/_0->B:1");
+            "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
 }

 /* Test LRN->LRNGrad replacement when only one of them is present. */
@@ -880,8 +1109,10 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
             "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
-            "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
-            "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+            "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
+            "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
 }

 /* Test LRN->LRNGrad negative case, where single LRN feeds
@@ -920,9 +1151,13 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
             "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
-            "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
-            "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
-            "D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
+            "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;"
+            "A:control->DMT/_0:control;B->E:2;"
+            "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
+            "C:control->DMT/_1:control;C:control->DMT/_2:control;"
+            "C:control->DMT/_3:control;C:control->DMT/_4:control;"
+            "C:control->DMT/_5:control;C:control->DMT/_6:control;"
+            "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
             "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
 }

@@ -951,8 +1186,9 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
             "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
-            "A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
-            "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
+            "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
+            "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
+            "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
 }

 // Test MaxPool>MaxPoolGrad replacement when only one of them is present.
@@ -972,7 +1208,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
       " input: ['A', 'B'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
-            "A->B;A->C;B->C:1;DMT/_0->B:1");
+            "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
 }

 // Test MaxPoolGrad replacement when only one of them is present.
@@ -995,8 +1231,374 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
             "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
-            "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
-            "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+            "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+            "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
+            "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+}
+
+// Test MaxPool handling for batch-wise pooling (NCHW)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'MaxPool'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'VALID' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
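These Negative tests pin down the new depthwise/batchwise detection: pooling whose window or stride touches the batch or channel axis stays on the default Eigen kernel. A sketch of the gating predicate these cases exercise (the name is hypothetical; the real check lives in the pooling rewrite rule):

    bool PoolsOverSpatialDimsOnlySketch(const std::vector<int>& ksize,
                                        const std::vector<int>& strides,
                                        bool is_nhwc) {
      const int batch = 0;
      const int depth = is_nhwc ? 3 : 1;  // Channel axis per data_format.
      // Reject batch-wise and depth-wise pooling; MKL handles spatial only.
      return ksize[batch] == 1 && strides[batch] == 1 &&
             ksize[depth] == 1 && strides[depth] == 1;
    }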
|
// Test MaxPool handling for batch-wise pooling (NCHW)
|
||||||
|
// No rewrite should take place in such case
|
||||||
|
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'MaxPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }");
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MaxPool handling for depth-wise pooling (NHWC)
|
||||||
|
// No rewrite should take place in such case
|
||||||
|
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'MaxPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }");
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MaxPool handling for depth-wise pooling (NCHW)
|
||||||
|
// No rewrite should take place in such case
|
||||||
|
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'MaxPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }");
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MaxPool handling for batch-wise pooling (NHWC)
|
||||||
|
// No rewrite should take place in such case
|
||||||
|
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'MaxPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }");
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MaxPool handling for batch-wise pooling (NHWC)
|
||||||
|
// No rewrite should take place in such case
|
||||||
|
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'MaxPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }");
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MaxPool handling for depth-wise pooling (NHWC)
|
||||||
|
// No rewrite should take place in such case
|
||||||
|
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'MaxPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }");
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MaxPool handling for depth-wise pooling (NHWC)
|
||||||
|
// No rewrite should take place in such case
|
||||||
|
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'MaxPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }");
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Single Conv2D Op on GPU device
|
||||||
|
// No rewrite should happen
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Input'}"
|
||||||
|
"node { name: 'C' op: 'Conv2D'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||||
|
" input: ['A', 'B']}"
|
||||||
|
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['B', 'C'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(Input);C(Conv2D);D(Mul)|A->C;B->C:1;B->D;C->D:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Input'}"
|
||||||
|
"node { name: 'C' op: 'Input'}"
|
||||||
|
"node { name: 'M' op: '_MklInput'}"
|
||||||
|
"node { name: 'N' op: '_MklInput'}"
|
||||||
|
"node { name: 'O' op: '_MklInput'}"
|
||||||
|
"node { name: 'D' op: '_MklConv2DWithBias'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||||
|
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
|
||||||
|
"node { name: 'E' op: 'Sub'"
|
||||||
|
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['D', 'A']}"
|
||||||
|
"node { name: 'F' op: 'BiasAddGrad'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" input: ['E'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
|
||||||
|
"E(Sub);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
|
||||||
|
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
|
||||||
|
"M->D:3;N->D:4;O->D:5");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Int32Input'}"
|
||||||
|
"node { name: 'C' op: 'Input'}"
|
||||||
|
"node { name: 'D' op: 'Conv2DBackpropFilter'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||||
|
" input: ['A', 'B', 'C']}"
|
||||||
|
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'D'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Mul)|"
|
||||||
|
"A->D;A->E;B->D:1;C->D:2;D->E:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Relu'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(Relu);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Input'}"
|
||||||
|
"node { name: 'C' op: 'ReluGrad'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }"
|
||||||
|
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'C'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(Input);C(ReluGrad);D(Mul)|A->C;A->D;B->C:1;C->D:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'MaxPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'AvgPool'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||||
|
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" input: ['A'] }"
|
||||||
|
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'B'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(AvgPool);C(Mul)|A->B;A->C;B->C:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concat Op test: Concat with no Mkl layer feeding it
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Const' "
|
||||||
|
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||||
|
" attr { key: 'value' value { "
|
||||||
|
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||||
|
" int_val: 0 } } } }"
|
||||||
|
"node { name: 'B' op: 'InputList'"
|
||||||
|
" attr { key: 'N' value { i: 2 } }}"
|
||||||
|
"node { name: 'C' op: 'Input'}"
|
||||||
|
"node { name: 'D' op: 'Concat'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'N' value { i: 2 } }"
|
||||||
|
" input: ['A', 'B:0', 'B:1']}"
|
||||||
|
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['C', 'D'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Const);B(InputList);C(Input);D(Concat);E(Mul)|A->D;"
|
||||||
|
"B->D:1;B:1->D:2;C->E;D->E:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Const' "
|
||||||
|
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||||
|
" attr { key: 'value' value { "
|
||||||
|
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||||
|
" int_val: 0 } } } }"
|
||||||
|
"node { name: 'B' op: 'InputList'"
|
||||||
|
" attr { key: 'N' value { i: 2 } }}"
|
||||||
|
"node { name: 'C' op: 'Input'}"
|
||||||
|
"node { name: 'D' op: 'ConcatV2'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'Tidx' value { type: DT_INT32 } }"
|
||||||
|
" attr { key: 'N' value { i: 2 } }"
|
||||||
|
" input: ['B:0', 'B:1', 'A']}"
|
||||||
|
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['C', 'D'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Const);B(InputList);C(Input);D(ConcatV2);E(Mul)|"
|
||||||
|
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Input'}"
|
||||||
|
"node { name: 'C' op: 'Input'}"
|
||||||
|
"node { name: 'D' op: 'Input'}"
|
||||||
|
"node { name: 'E' op: 'Input'}"
|
||||||
|
"node { name: 'F' op: 'FusedBatchNorm'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'epsilon' value { f: 0.0001 } }"
|
||||||
|
" attr { key: 'is_training' value { b: true } }"
|
||||||
|
" input: ['A', 'B', 'C', 'D', 'E'] }"
|
||||||
|
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['A', 'F'] }", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(Input);C(Input);D(Input);E(Input);"
|
||||||
|
"F(FusedBatchNorm);G(Mul)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
|
||||||
|
"E->F:4;F->G:1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
|
||||||
|
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||||
|
InitGraph(
|
||||||
|
"node { name: 'A' op: 'Input'}"
|
||||||
|
"node { name: 'B' op: 'Input'}"
|
||||||
|
"node { name: 'M' op: '_MklInput'}"
|
||||||
|
"node { name: 'N' op: '_MklInput'}"
|
||||||
|
"node { name: 'C' op: '_MklConv2D'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||||
|
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||||
|
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||||
|
" input: ['A', 'B', 'M', 'N']}"
|
||||||
|
"node { name: 'D' op: 'Input'}"
|
||||||
|
"node { name: 'E' op: 'BiasAdd'"
|
||||||
|
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||||
|
" input: ['C', 'D'] }"
|
||||||
|
"node { name: 'Y' op: 'Input'}"
|
||||||
|
"node { name: 'Z' op: 'Sub'"
|
||||||
|
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||||
|
" input: ['E', 'Y']}", kGPUDevice);
|
||||||
|
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||||
|
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
|
||||||
|
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->C;"
|
||||||
|
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
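Note for readers of these tests: each expected string encodes the optimized graph, with every node serialized as `name(op)` and every edge as `src->dst[:input_port]`, joined by semicolons, with `|` separating the node list from the edge list. A minimal standalone sketch of that signature format (hypothetical `Node`/`Edge` types, not the actual test fixture's implementation):

#include <iostream>
#include <string>
#include <vector>

// Toy graph: nodes are (name, op); edges carry the destination input port.
struct Node { std::string name, op; };
struct Edge { std::string src, dst; int dst_port; };

std::string Signature(const std::vector<Node>& nodes,
                      const std::vector<Edge>& edges) {
  std::string s;
  for (const Node& n : nodes) s += n.name + "(" + n.op + ");";
  if (!s.empty()) s.back() = '|';  // separate node list from edge list
  for (size_t i = 0; i < edges.size(); ++i) {
    if (i) s += ";";
    s += edges[i].src + "->" + edges[i].dst;
    if (edges[i].dst_port > 0) s += ":" + std::to_string(edges[i].dst_port);
  }
  return s;
}

int main() {
  // Reproduces the AvgPool expectation above.
  std::cout << Signature({{"A", "Input"}, {"B", "AvgPool"}, {"C", "Mul"}},
                         {{"A", "B", 0}, {"A", "C", 0}, {"B", "C", 1}})
            << "\n";  // A(Input);B(AvgPool);C(Mul)|A->B;A->C;B->C:1
}

The `:1` suffix in `B->C:1`, for example, says the edge feeds input port 1 of `C`.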
@@ -98,12 +98,13 @@ class MklToTfConversionPass : public GraphOptimizationPass {
   Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);
 };
 
-// We register MklToTf insertion for phase 1 in post-partition grouping.
-// We register this pass after partitioning so that we get a complete
-// picture of inputs and outputs of the nodes in the graphs.
+// We register MklToTf insertion for phase 2 in post-partition grouping
+// because we register MklLayoutRewritePass for phase 1 in post-partition
+// grouping. We register this pass after partitioning so that we get a
+// complete picture of inputs and outputs of the nodes in the graphs.
 const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
     OptimizationPassRegistry::POST_PARTITIONING;
-REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 1, MklToTfConversionPass);
+REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
 
 Status MklToTfConversionPass::InsertConversionNodeOnEdge(
     std::unique_ptr<Graph>* g, Edge* e) {
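The phase bump matters because passes in the same grouping run in ascending phase order: the layout rewrite (phase 1) must introduce the `_Mkl*` nodes before this conversion pass (phase 2) scans for MKL-to-TF edges. A minimal sketch of that ordering contract, assuming a registry keyed by phase within one grouping (toy code, not TensorFlow's `OptimizationPassRegistry`):

#include <functional>
#include <iostream>
#include <map>
#include <string>

using Pass = std::function<void()>;

int main() {
  // Toy post-partitioning grouping: lower phase runs first.
  std::multimap<int, std::pair<std::string, Pass>> post_partitioning;
  post_partitioning.emplace(2, std::make_pair("MklToTfConversionPass", [] {
    std::cout << "insert MklToTf conversion nodes\n";
  }));
  post_partitioning.emplace(1, std::make_pair("MklLayoutRewritePass", [] {
    std::cout << "rewrite ops to _Mkl* variants\n";
  }));
  for (auto& [phase, p] : post_partitioning) {
    std::cout << "phase " << phase << ": " << p.first << " -> ";
    p.second();
  }
}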
@@ -121,10 +122,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
   string data_format;
 
   TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
-  TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype));
-  if (src_datatype != dst_datatype) {
-    string err_msg = "T attribute of " + src->name() + " and " + dst->name() +
-                     " do not match. Will not insert" +
+  bool dst_dtype_found =
+      GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK();
+  // We compare source and destination datatypes only when both are found.
+  if (dst_dtype_found && (src_datatype != dst_datatype)) {
+    string err_msg = "T attribute of " + src->name() + " and " +
+                     dst->name() + " do not match. Will not insert" +
                      " MklToTf node in such case.";
     return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
   }
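The change above makes the destination's `T` attribute optional: the mismatch error fires only when both lookups succeed, since a non-MKL destination may legitimately lack the attribute. The same guard pattern in a standalone sketch, with `std::optional` standing in for the attribute lookup (hypothetical helper, not the TensorFlow API):

#include <iostream>
#include <optional>

enum class DataType { DT_FLOAT, DT_INT32 };

// Stand-in for an attribute lookup that may fail: nullopt when the node
// has no "T" attribute.
std::optional<DataType> LookupT(bool has_attr, DataType value) {
  if (!has_attr) return std::nullopt;
  return value;
}

bool DtypesCompatible(std::optional<DataType> src,
                      std::optional<DataType> dst) {
  // Compare source and destination datatypes only when both are found;
  // a genuine mismatch would map to INVALID_ARGUMENT in the pass.
  if (src && dst && *src != *dst) return false;
  return true;
}

int main() {
  std::cout << DtypesCompatible(LookupT(true, DataType::DT_FLOAT),
                                LookupT(false, DataType::DT_FLOAT))
            << "\n";  // 1: dst attr missing, so no mismatch error
}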
@@ -202,18 +205,19 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
             << src->type_string() << " and " << dst->type_string();
 
     // Let's get source and destination data type.
-    DataType src_datatype = DT_INVALID;
-    if (GetNodeAttr(src->def(), "T", &src_datatype) != Status::OK()) {
-      continue;
-    }
     // We cannot check datatype on destination node because destination node
     // may not be Mkl node.
-    DataType dst_datatype = DT_INVALID;
-    GetNodeAttr(dst->def(), "T", &dst_datatype);
+    DataType src_datatype;
+    DataType dst_datatype;
+    bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) ==
+                              Status::OK() &&
+                          IsMklSupportedOp(src->type_string(), src_datatype));
+    bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) ==
+                              Status::OK() &&
+                          IsMklSupportedOp(dst->type_string(), dst_datatype));
 
-    // Check if src is Mkl-compliant, while dst is not Mkl-compliant.
-    if (IsMklSupportedOp(src->type_string(), src_datatype) &&
-        !IsMklSupportedOp(dst->type_string(), dst_datatype)) {
+    // Check if src is Mkl-compliant, while dst is not Mkl-compliant.
+    if (src_is_mkl_op && !dst_is_mkl_op) {
       VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
               << " and " << dst->name() << " for inserting conversion nodes";
       candidate_edges.push_back(const_cast<Edge*>(e));
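Folding the attribute lookup and the `IsMklSupportedOp` test into `src_is_mkl_op` / `dst_is_mkl_op` makes the selection rule explicit: a conversion node is needed exactly on edges that leave an MKL op and enter a non-MKL op. A toy version of that edge scan, under the simplifying assumption that "MKL op" is just a `_Mkl` name prefix (the real check also involves the datatype):

#include <iostream>
#include <string>
#include <vector>

struct Edge { std::string src, dst; };

// Hypothetical stand-in for IsMklSupportedOp.
bool IsMklOp(const std::string& op) { return op.rfind("_Mkl", 0) == 0; }

int main() {
  std::vector<Edge> edges = {{"_MklConv2D", "Sub"},       // needs conversion
                             {"_MklConv2D", "_MklRelu"},  // stays in MKL layout
                             {"Input", "_MklConv2D"}};    // TF->MKL: no MklToTf
  for (const Edge& e : edges) {
    bool src_is_mkl_op = IsMklOp(e.src);
    bool dst_is_mkl_op = IsMklOp(e.dst);
    if (src_is_mkl_op && !dst_is_mkl_op)
      std::cout << "insert MklToTf on " << e.src << "->" << e.dst << "\n";
  }
}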
@@ -149,7 +149,7 @@ TEST_F(MklToTfConversionPass, Positive) {
         " input: ['C', 'D']}");
     EXPECT_EQ(DoRunMklToTfConversionPass(),
               "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
-              "_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
+              "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
               "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
   } else {
     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
@@ -172,7 +172,7 @@ TEST_F(MklToTfConversionPass, Positive) {
         " input: ['C', 'D']}");
     EXPECT_EQ(DoRunMklToTfConversionPass(),
               "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
-              "_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
+              "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
               "C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
   }
 }
@@ -265,6 +265,7 @@ class MklConcatOp : public OpKernel {
           s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1;
     }
     mkl_context.MklCreateInputLayouts(context, input_shapes);
+    OP_REQUIRES_OK(context, context->status());
 
     CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N,
                                  &mkl_context.lt_inputs[0]),

@@ -316,12 +317,14 @@ class MklConcatOp : public OpKernel {
 
     mkl_context.mkl_tmp_tensors.resize(N);
     mkl_context.MklPrepareConcatInputs(context, input_tensors);
+    OP_REQUIRES_OK(context, context->status());
 
     // Execute primitive.
     CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res),
              E_SUCCESS);
 
     mkl_context.MklCleanup();
+    OP_REQUIRES_OK(context, context->status());
   }
 
  private:
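The three `OP_REQUIRES_OK(context, context->status())` checks turn failures recorded inside the helper routines into early returns before the next MKL primitive call. A sketch of the pattern with a hypothetical context type (the real macro comes from TensorFlow's op framework; this just shows the control flow):

#include <iostream>
#include <string>

// Hypothetical stand-in for OpKernelContext: helpers record a status
// instead of returning one, so the caller must re-check it.
struct Ctx {
  bool ok = true;
  std::string msg;
  void SetError(const std::string& m) { ok = false; msg = m; }
};

#define REQUIRE_OK(ctx)                                 \
  if (!(ctx).ok) {                                      \
    std::cout << "abort kernel: " << (ctx).msg << "\n"; \
    return;                                             \
  }

void CreateInputLayouts(Ctx& ctx) { ctx.SetError("bad concat dim"); }

void Compute(Ctx& ctx) {
  CreateInputLayouts(ctx);
  REQUIRE_OK(ctx);  // without this, we would call into MKL with bad state
  std::cout << "dnnConcatCreate would run here\n";
}

int main() { Ctx c; Compute(c); }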
@@ -38,9 +38,9 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
+#include "tensorflow/core/util/mkl_util.h"
 #include "third_party/mkl/include/mkl_dnn.h"
 #include "third_party/mkl/include/mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
 
@@ -252,7 +252,7 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel {
 };
 
 #define REGISTER_CPU_KERNELS(T)                                   \
   REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias")  \
                               .Device(DEVICE_CPU)                 \
                               .TypeConstraint<T>("T")             \
                               .Label(mkl_op_registry::kMklOpLabel), \
@@ -37,9 +37,9 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
+#include "tensorflow/core/util/mkl_util.h"
 #include "third_party/mkl/include/mkl_dnn.h"
 #include "third_party/mkl/include/mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
 
@@ -266,8 +266,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
     int input_offsets[2];
     size_t conv_strides[2];
     MklShape input_shape, grad_filter_shape, out_backprop_shape;
-    dnnPrimitive_t prim_conv_bwdfilter, convert_bwdfilter;
-    dnnLayout_t lt_input, lt_grad_filter, lt_out_backprop;
+    dnnPrimitive_t prim_conv_bwdfilter = nullptr;
+    dnnPrimitive_t convert_bwdfilter = nullptr;
+    dnnLayout_t lt_input = nullptr;
+    dnnLayout_t lt_grad_filter = nullptr;
+    dnnLayout_t lt_out_backprop = nullptr;
     void* conv_res[dnnResourceNumber];
 
     void MklCleanup() {

@@ -409,7 +412,7 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
 };
 
 #define REGISTER_MKL_FILTER_KERNELS(T)                      \
   REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter")  \
                               .Device(DEVICE_CPU)           \
                               .TypeConstraint<T>("T")       \
                               .Label(mkl_op_registry::kMklOpLabel), \
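Default-initializing every `dnnPrimitive_t` / `dnnLayout_t` member to `nullptr` makes `MklCleanup()` safe on early-exit paths, where only some of the handles were ever created. The idiom, sketched with a hypothetical handle type and delete function in place of the MKL ones:

#include <cstdio>

using Handle = void*;  // stand-in for dnnPrimitive_t / dnnLayout_t

void DeleteHandle(Handle h) { std::printf("delete %p\n", h); }

struct BackpropFilterContext {
  // nullptr-initialized so cleanup can run even if setup bailed out early.
  Handle prim_conv_bwdfilter = nullptr;
  Handle convert_bwdfilter = nullptr;
  Handle lt_input = nullptr;
  Handle lt_grad_filter = nullptr;

  void MklCleanup() {
    if (prim_conv_bwdfilter) DeleteHandle(prim_conv_bwdfilter);
    if (convert_bwdfilter) DeleteHandle(convert_bwdfilter);
    if (lt_input) DeleteHandle(lt_input);
    if (lt_grad_filter) DeleteHandle(lt_grad_filter);
  }
};

int main() {
  BackpropFilterContext ctx;  // imagine setup failed before creating handles
  ctx.MklCleanup();           // no-ops safely instead of freeing garbage
}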
@@ -23,8 +23,6 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 #include <algorithm>
 #include <vector>
-#include "third_party/mkl/include/mkl_dnn.h"
-#include "third_party/mkl/include/mkl_dnn_types.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"

@@ -42,6 +40,8 @@ limitations under the License.
 #include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
 
 namespace tensorflow {
 
@@ -342,7 +342,7 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
 };
 
 #define REGISTER_MKL_CPU_KERNELS(T)                        \
   REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput")  \
                               .Device(DEVICE_CPU)          \
                               .TypeConstraint<T>("T")      \
                               .Label(mkl_op_registry::kMklOpLabel), \
@@ -36,9 +36,9 @@ limitations under the License.
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
+#include "tensorflow/core/util/mkl_util.h"
 #include "third_party/mkl/include/mkl_dnn.h"
 #include "third_party/mkl/include/mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
 
@@ -98,19 +98,18 @@ class MklConv2DOp : public OpKernel {
                                         filter.shape().DebugString()));
 
     for (int i = 0; i < 3; i++) {
-      OP_REQUIRES(
-          context,
-          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
-          errors::InvalidArgument("filter too large"));
+      OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
+                                           std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("filter too large"));
     }
 
     const int64 input_depth =
         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C')
                             : GetTensorDim(input, data_format_, 'C');
-    OP_REQUIRES(context, input_depth == filter.dim_size(2),
-                errors::InvalidArgument(
-                    "input and filter must have the same depth: ", input_depth,
-                    " vs ", filter.dim_size(2)));
+    OP_REQUIRES(
+        context, input_depth == filter.dim_size(2),
+        errors::InvalidArgument("input and filter must have the same depth: ",
+                                input_depth, " vs ", filter.dim_size(2)));
     // The last dimension for filter is out_depth.
     const int out_depth = static_cast<int>(filter.dim_size(3));

@@ -119,10 +118,9 @@ class MklConv2DOp : public OpKernel {
     const int64 input_rows_raw =
         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H')
                             : GetTensorDim(input, data_format_, 'H');
-    OP_REQUIRES(
-        context,
-        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
-        errors::InvalidArgument("Input rows too large"));
+    OP_REQUIRES(context, FastBoundsCheck(input_rows_raw,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("Input rows too large"));
     const int input_rows = static_cast<int>(input_rows_raw);
     const int filter_rows = static_cast<int>(filter.dim_size(0));

@@ -131,10 +129,9 @@ class MklConv2DOp : public OpKernel {
     const int64 input_cols_raw =
         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W')
                             : GetTensorDim(input, data_format_, 'W');
-    OP_REQUIRES(
-        context,
-        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
-        errors::InvalidArgument("Input cols too large"));
+    OP_REQUIRES(context, FastBoundsCheck(input_cols_raw,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("Input cols too large"));
     const int input_cols = static_cast<int>(input_cols_raw);
     const int filter_cols = static_cast<int>(filter.dim_size(1));

@@ -142,10 +139,9 @@ class MklConv2DOp : public OpKernel {
     const int64 input_batch_raw =
         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N')
                             : GetTensorDim(input, data_format_, 'N');
-    OP_REQUIRES(
-        context,
-        FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()),
-        errors::InvalidArgument("batch is too large"));
+    OP_REQUIRES(context, FastBoundsCheck(input_batch_raw,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("batch is too large"));
     const int batch = static_cast<int>(input_batch_raw);
 
     // For now we take the stride from the second and third dimensions only (we

@@ -438,12 +434,12 @@ class MklConv2DOp : public OpKernel {
 };
 
 #define REGISTER_MKL_CPU(T)                               \
   REGISTER_KERNEL_BUILDER(Name("_MklConv2D")              \
                               .Device(DEVICE_CPU)         \
                               .TypeConstraint<T>("T")     \
                               .Label(mkl_op_registry::kMklOpLabel), \
                           MklConv2DOp<CPUDevice, T, false>);        \
   REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias")      \
                               .Device(DEVICE_CPU)         \
                               .TypeConstraint<T>("T")     \
                               .Label(mkl_op_registry::kMklOpLabel), \
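All four reflowed checks follow the same shape: validate that an `int64` dimension fits in `int` before the narrowing `static_cast`. A standalone equivalent of that bounds check (the real `FastBoundsCheck` lives in TensorFlow's framework utilities; this mirrors its one-comparison trick):

#include <cstdint>
#include <iostream>
#include <limits>

// True iff 0 <= index < limit, done in a single unsigned comparison
// (negative indexes wrap to huge unsigned values and fail the test).
inline bool FastBoundsCheck(std::int64_t index, std::int64_t limit) {
  return static_cast<std::uint64_t>(index) < static_cast<std::uint64_t>(limit);
}

int main() {
  std::int64_t input_rows_raw = 1 << 20;
  if (!FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max())) {
    std::cerr << "Input rows too large\n";
    return 1;
  }
  const int input_rows = static_cast<int>(input_rows_raw);  // now safe
  std::cout << "input_rows = " << input_rows << "\n";
}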
@@ -104,6 +104,15 @@ class MklLRNOp : public OpKernel {
       return;
     }
 
+    // TODO(inteltf) MKL will support depth radius not equal to 2 in the future
+    if (depth_radius_ != 2) {
+      Tensor converted_tensor =
+          ConvertMklToTF<T>(context, input, mkl_context.input_shape);
+      mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
+                                    beta_, converted_tensor);
+      return;
+    }
+
     if (input_in_mkl_format) {
       // MKL supports normalization over channel dimension only
       if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==

@@ -112,8 +121,10 @@ class MklLRNOp : public OpKernel {
             static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout());
         workspace_enabled_ = true;
       } else {
+        Tensor converted_tensor =
+            ConvertMklToTF<T>(context, input, mkl_context.input_shape);
         mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
-                                      beta_, input);
+                                      beta_, converted_tensor);
         return;
       }
     }

@@ -267,7 +278,7 @@ class MklLRNOp : public OpKernel {
   }
 
   // Fallback implementation - Taken from lrn_op.cc
-  // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
+  // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
   // copy.
   void MklDefaultToEigen(OpKernelContext* context, int depth_radius_,
                          float bias_, float alpha_, float beta_,
|
|||||||
mkl_context.MklDefaultToEigen(context);
|
mkl_context.MklDefaultToEigen(context);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (depth_radius_ != 2) {
|
||||||
|
mkl_context.MklDefaultToEigen(context);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (ingrad_in_mkl_format || inimage_in_mkl_format) {
|
if (ingrad_in_mkl_format || inimage_in_mkl_format) {
|
||||||
const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
|
const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
|
||||||
? &mkl_context.ingrad_shape
|
? &mkl_context.ingrad_shape
|
||||||
@@ -489,14 +506,11 @@ class MklLRNGradOp : public OpKernel {
     MklShape ingrad_shape, inimage_shape, outimage_shape;
     dnnPrimitive_t lrn_bwd = nullptr;
     dnnPrimitive_t convert_input = nullptr;
-    /* dnnPrimitive_t convert_output; */
     dnnLayout_t lt_input = nullptr;
     dnnLayout_t lt_output = nullptr;
    dnnLayout_t lt_bdw_input = nullptr;
     dnnLayout_t lt_workspace = nullptr;
     dnnLayout_t lt_internal_input = nullptr;
-    /* dnnLayout_t lt_internal_workspace;
-       dnnLayout_t lt_internal_output; */
     void* res_lrn_bwd[dnnResourceNumber];
 
     // prepare mkl input
@@ -619,14 +633,36 @@ class MklLRNGradOp : public OpKernel {
   // copy.
   void MklDefaultToEigen(OpKernelContext* context) {
     // CHECK(false);
-    Tensor in_grads = MklGetInput(context, 0);
-    Tensor in_image = MklGetInput(context, 1);
-    Tensor out_image = MklGetInput(context, 2);
+    Tensor in_grads;
+    Tensor in_image;
+    Tensor out_image;
 
     GetMklShape(context, 0, &ingrad_shape);
     GetMklShape(context, 1, &inimage_shape);
     GetMklShape(context, 2, &outimage_shape);
 
+    if (ingrad_shape.IsMklTensor()) {
+      in_grads =
+          ConvertMklToTF<T>(context, MklGetInput(context, 0), ingrad_shape);
+    } else {
+      in_grads = MklGetInput(context, 0);
+    }
+
+    if (inimage_shape.IsMklTensor()) {
+      in_image =
+          ConvertMklToTF<T>(context, MklGetInput(context, 1), inimage_shape);
+    } else {
+      in_image = MklGetInput(context, 1);
+    }
+
+    if (outimage_shape.IsMklTensor()) {
+      out_image =
+          ConvertMklToTF<T>(context, MklGetInput(context, 2), outimage_shape);
+    } else {
+      out_image = MklGetInput(context, 2);
+    }
+
     const int64 batch = static_cast<int64>(in_grads.dim_size(0));
     const int64 rows = static_cast<int64>(in_grads.dim_size(1));
     const int64 cols = static_cast<int64>(in_grads.dim_size(2));
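The rewritten fallback no longer assumes its three inputs are plain TF tensors; each one is converted out of MKL layout only if its `MklShape` says it is an MKL tensor. The per-input decision is the same each time, so here it is once in isolation (hypothetical helper names, toy types):

#include <iostream>

struct Tensor { bool mkl = false; };
struct MklShape {
  bool is_mkl = false;
  bool IsMklTensor() const { return is_mkl; }
};

// Stand-in for ConvertMklToTF: reorders an MKL-layout tensor into
// standard TensorFlow layout.
Tensor ConvertMklToTf(const Tensor&) { return Tensor{false}; }

Tensor GetTfTensor(const Tensor& input, const MklShape& shape) {
  // Convert only when needed; a TF-layout input passes through as-is.
  return shape.IsMklTensor() ? ConvertMklToTf(input) : input;
}

int main() {
  Tensor in_grads = GetTfTensor(Tensor{true}, MklShape{true});
  Tensor in_image = GetTfTensor(Tensor{false}, MklShape{false});
  std::cout << in_grads.mkl << " " << in_image.mkl << "\n";  // 0 0
}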
@@ -199,15 +199,13 @@ class MklMatMulOp : public OpKernel {
   }
 };
 
 #define REGISTER_CPU(T)                                                 \
   REGISTER_KERNEL_BUILDER(                                              \
       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),         \
-      MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);  \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("MKL"), \
-      MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>)
+      MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
 
-// TODO:Consider template specialization when adding/removing additional types
+// TODO(inteltf) Consider template specialization when adding/removing
+// additional types
 TF_CALL_float(REGISTER_CPU);
 TF_CALL_double(REGISTER_CPU);
 TF_CALL_complex64(REGISTER_CPU);
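Dropping the second, `"MKL"`-labeled registration removes what appears to be a duplicate: with MKL enabled this kernel already claims the unlabeled `MatMul` slot on CPU, so a labeled twin only adds a second registry entry for the same functor. In registry terms (a toy sketch keyed on (op name, label), not TensorFlow's kernel registry):

#include <iostream>
#include <map>
#include <string>

int main() {
  // Toy kernel registry keyed by (op name, label).
  std::map<std::pair<std::string, std::string>, std::string> registry;
  registry[{"MatMul", ""}] = "MklMatMulOp<CPUDevice, float>";
  // registry[{"MatMul", "MKL"}] = ...;  // redundant entry removed above
  for (const auto& [key, kernel] : registry)
    std::cout << key.first << " label='" << key.second << "' -> " << kernel
              << "\n";
}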
@@ -276,11 +276,6 @@ class MklMaxPoolingGradOp : public OpKernel {
     mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
         static_cast<const void*>(output_tensor->flat<T>().data()));
 
-    int64 output_size = output_tensor->NumElements();
-    for (int64 i = 0; i < output_size; ++i) {
-      (static_cast<float*>(mkl_context.pooling_res[dnnResourceDiffSrc]))[i] = 0;
-    }
-
     CHECK_EQ(
         dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
         E_SUCCESS);
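This is one of the max-pooling-grad perf fixes from the commit message: the element-by-element zeroing of the `diffSrc` buffer was pure overhead, presumably because the `dnnExecute_F32` call immediately below writes the whole buffer anyway. The general anti-pattern, reduced to a sketch (hypothetical stand-in for the MKL execute call):

#include <cstring>
#include <vector>

// Stand-in for dnnExecute_F32: fully overwrites its output buffer.
void ExecutePoolingBwd(float* out, std::size_t n) {
  std::memset(out, 0, n * sizeof(float));  // pretend-computed gradients
}

int main() {
  std::vector<float> diff_src(1 << 20);
  // Removed by this commit (redundant work on every call):
  //   for (int64 i = 0; i < output_size; ++i) diff_src[i] = 0;
  ExecutePoolingBwd(diff_src.data(), diff_src.size());  // overwrites it all
}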
@@ -16,17 +16,17 @@ limitations under the License.
 // See docs in ../ops/nn_ops.cc.
 #ifdef INTEL_MKL
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
-#include "third_party/mkl/include/mkl_dnn.h"
-#include "third_party/mkl/include/mkl_dnn_types.h"
 #include "tensorflow/core/platform/default/logging.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
 
 namespace tensorflow {
 
@@ -194,45 +194,29 @@ class MklReluGradOp : public OpKernel {
 
     void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
     void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
+    dnnPrimitive_t cv_input_to_grad = NULL;
+    Tensor mkl_tmp_buf_tensor;
+    void* mkl_buffer_convert = nullptr;
 
-    CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
-                 &mkl_lt_internal_grad, prim_relu_bwd, dnnResourceDiffDst),
-             E_SUCCESS);
-
-    CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
-                                              prim_relu_bwd, dnnResourceSrc),
-             E_SUCCESS);
-
-    if (!dnnLayoutCompare_F32(mkl_lt_internal_grad, lt_grad)) {
-      AllocTmpBuffer(context, mkl_tmp_grad_buf_tensor, mkl_lt_internal_grad,
-                     &relu_res[dnnResourceDiffDst]);
-      CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_grad, lt_grad,
-                                       mkl_lt_internal_grad),
-               E_SUCCESS);
-      CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_grad, user_g,
-                                        relu_res[dnnResourceDiffDst]),
-               E_SUCCESS);
-      dnnDelete_F32(cv_user_to_reluB_grad);
-    } else {
-      relu_res[dnnResourceDiffDst] = user_g;
-    }
-
-    if (!dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input)) {
-      AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
-                     &relu_res[dnnResourceSrc]);
-      CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_input, lt_input,
-                                       mkl_lt_internal_input),
-               E_SUCCESS);
-      CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_input, user_i,
-                                        relu_res[dnnResourceSrc]),
-               E_SUCCESS);
-      dnnDelete_F32(cv_user_to_reluB_input);
+    // if input and grad are not in the same layout, do a conversion between
+    // them.
+    if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
+      AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad,
+                     &mkl_buffer_convert);
+      CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
+               E_SUCCESS);
+      CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i,
+                                        mkl_buffer_convert),
+               E_SUCCESS);
+      relu_res[dnnResourceSrc] = mkl_buffer_convert;
+      dnnDelete_F32(cv_input_to_grad);
     } else {
       relu_res[dnnResourceSrc] = user_i;
     }
 
-    dnnLayoutDelete_F32(mkl_lt_internal_input);
-    dnnLayoutDelete_F32(mkl_lt_internal_grad);
+    relu_res[dnnResourceDiffDst] = user_g;
   }
 
   void MklCreateInputLayouts(OpKernelContext* context) {
@@ -331,7 +315,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
   mkl_context.MklCreateInputLayouts(context);
   float negative_slope = 0.0;
   CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
-                                     mkl_context.lt_grad, mkl_context.lt_input,
+                                     mkl_context.lt_grad, mkl_context.lt_grad,
                                      negative_slope),
            E_SUCCESS);
   Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
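Taken together with the `dnnReLUCreateBackward_F32` change just above (both layout arguments are now `lt_grad`), the rewrite appears to replace two conversions into primitive-internal layouts with at most one: if the input's user layout differs from the gradient's, the input alone is converted into the gradient's layout, and the gradient is used as-is. A sketch of that compare-then-convert step, with hypothetical layout helpers standing in for `dnnLayoutCompare_F32` / `dnnConversionExecute_F32`:

#include <iostream>
#include <string>

using Layout = std::string;  // stand-in for dnnLayout_t

bool LayoutsEqual(const Layout& a, const Layout& b) { return a == b; }

const float* ConvertToLayout(const float* buf, const Layout& from,
                             const Layout& to) {
  std::cout << "convert " << from << " -> " << to << "\n";
  return buf;  // a real version would reorder into a temp buffer
}

int main() {
  Layout lt_input = "nchw", lt_grad = "nChw8c";
  float a[8] = {0}, g[8] = {0};
  // New scheme: the backward primitive is described entirely in the grad
  // layout, so only the input may need one conversion.
  const float* src = a;
  if (!LayoutsEqual(lt_input, lt_grad))
    src = ConvertToLayout(a, lt_input, lt_grad);
  const float* diff_dst = g;  // used as-is
  (void)src;
  (void)diff_dst;
  std::cout << "run relu backward\n";
}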
@@ -380,12 +364,12 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
 /* Register DNN kernels for supported operations and supported types - right now
  * it is only Relu and f32*/
 #define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)     \
   REGISTER_KERNEL_BUILDER(Name("_MklRelu")                  \
                               .Device(DEVICE_CPU)           \
                               .TypeConstraint<type>("T")    \
                               .Label(mkl_op_registry::kMklOpLabel), \
                           MklReluOp<CPUDevice, type>);      \
   REGISTER_KERNEL_BUILDER(Name("_MklReluGrad")              \
                               .Device(DEVICE_CPU)           \
                               .TypeConstraint<type>("T")    \
                               .Label(mkl_op_registry::kMklOpLabel), \
@@ -106,7 +106,7 @@ class MklToTfOp : public OpKernel {
 ///////////////////////////////////////////////////////////
 
 #define REGISTER_CPU(T)                                     \
-  REGISTER_KERNEL_BUILDER(Name("MklToTf")                   \
+  REGISTER_KERNEL_BUILDER(Name("_MklToTf")                  \
                               .Device(DEVICE_CPU)           \
                               .TypeConstraint<T>("T")       \
                               .Label(mkl_op_registry::kMklOpLabel), \
@@ -28,6 +28,12 @@ set -e
 CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' )
 shift 1
 
+# Enable support for MKL, for Linux only.
+if [[ $(uname) == "Linux" ]]; then
+  export TF_NEED_MKL=1
+  export TF_DOWNLOAD_MKL=1
+fi
+
 # Enable support for Google Cloud Platform (GCP)
 export TF_NEED_GCP=1
 # Enable support for HDFS
@@ -46,6 +46,7 @@ apt-get install -y --no-install-recommends \
     git \
     libcurl4-openssl-dev \
     libtool \
+    mlocate \
    openjdk-8-jdk \
     openjdk-8-jre-headless \
     pkg-config \

@@ -63,6 +64,9 @@ apt-get install -y --no-install-recommends \
     zip \
     zlib1g-dev
 
+# populate the database
+updatedb
+
 if [[ "$1" != "--without_cmake" ]]; then
   apt-get install -y --no-install-recommends \
     cmake
@@ -28,6 +28,7 @@ export TF_NEED_GCP=0
 export TF_NEED_HDFS=0
 export TF_NEED_CUDA=0
 export TF_NEED_OPENCL=0
+export TF_NEED_MKL=0
 export COMPUTECPP_PATH="/usr/local"
 
 export PATH="/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"

@@ -29,6 +29,7 @@ export PYTHON_BIN_PATH="/usr/bin/python"
 export TF_NEED_GCP=0
 export TF_NEED_HDFS=0
 export TF_NEED_OPENCL=0
+export TF_NEED_MKL=0
 export COMPUTECPP_PATH="/usr/local"
 
 export PATH="/usr/local/cuda/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"
third_party/grpc.BUILD (2 changes, vendored)

@@ -178,6 +178,7 @@ cc_library(
     ],
     deps = [
     ],
+    linkopts = ["-lpthread"],
 )
 
 cc_library(

@@ -1787,6 +1788,7 @@ cc_library(
         ":grpc_unsecure",
         "//external:protobuf_clib",
     ],
+    linkopts = ["-lpthread"],
 )
 
 cc_library(
|
3
third_party/jemalloc.BUILD
vendored
3
third_party/jemalloc.BUILD
vendored
@ -94,6 +94,9 @@ cc_library(
|
|||||||
"@%ws%//tensorflow:linux_ppc64le": [
|
"@%ws%//tensorflow:linux_ppc64le": [
|
||||||
"-lpthread",
|
"-lpthread",
|
||||||
],
|
],
|
||||||
|
"@%ws%//tensorflow:linux_x86_64": [
|
||||||
|
"-lpthread",
|
||||||
|
],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
|
1
third_party/llvm/llvm.BUILD
vendored
1
third_party/llvm/llvm.BUILD
vendored
@ -1695,6 +1695,7 @@ cc_library(
|
|||||||
":demangle",
|
":demangle",
|
||||||
"@zlib_archive//:zlib",
|
"@zlib_archive//:zlib",
|
||||||
],
|
],
|
||||||
|
linkopts = ["-lpthread", "-ldl"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
1
third_party/mkl/BUILD
vendored
1
third_party/mkl/BUILD
vendored
@ -16,6 +16,7 @@ load(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "intel_binary_blob",
|
name = "intel_binary_blob",
|
||||||
srcs = if_mkl([
|
srcs = if_mkl([
|
||||||
|
"libdl.so.2",
|
||||||
"libmklml_intel.so",
|
"libmklml_intel.so",
|
||||||
"libiomp5.so",
|
"libiomp5.so",
|
||||||
]),
|
]),
|
||||||