Enable MKL in configure and various bug fixes (#9580)
* relu grad and maxpooling grad fixes for perf * Graph layout pass and conversion pass changes This commit makes following changes: - Enables support for ReluGrad and BiasAddGrad - Adds support for detecting depthwise/batchwise pooling - Adds more unit tests for Graph rewrite pass - Improvements to handling control-flow edges - Bug fixes * Defaulting to Eigen when LRN depth_radius!=2 * Fixed mkl_conv_grad_filter.cc for conv_ops_tests.py * Style fix to mkl_matmul and remove unnecessary 'MKL' label on matmul kernel * Style fixes based on clang-format to mkl_conv_* and mkl_matmul * Bug fixes * Adding OP_REQUIRES_OK check in Concat * Making some style changes * Enabled the configuration of MKL settings * relu grad and maxpooling grad fixes for perf * Graph layout pass and conversion pass changes This commit makes following changes: - Enables support for ReluGrad and BiasAddGrad - Adds support for detecting depthwise/batchwise pooling - Adds more unit tests for Graph rewrite pass - Improvements to handling control-flow edges - Bug fixes * Defaulting to Eigen when LRN depth_radius!=2 * Fixed mkl_conv_grad_filter.cc for conv_ops_tests.py * Style fix to mkl_matmul and remove unnecessary 'MKL' label on matmul kernel * Style fixes based on clang-format to mkl_conv_* and mkl_matmul * Bug fixes * Adding OP_REQUIRES_OK check in Concat * Making some style changes * Enabled the configuration of MKL settings * Fixing graph unit tests with Mkl op name change to _Mkl; Fixed missing _ in MklToTf op * Fixed missing libdl.so.2 in BUILD file * Fixes for unit test build failures. * Changes in mkl_conv_grad_filter_ops.cc for Google code style * Fixes to remove dead code * removed the dead code and added a TODO for mkl implementation to handle this case in the future * Fixed buildifier sanity check error * Adding support for google's CI automation * Updated link to new MKL version * Fix for missing locate command in CI * Adding updatedb to populate the database after installing mlocate * Fixed buildifier issue * setting tf_need_mkl=0 in libtf files * Added third_party/mkl/* to .gitignore * Added third_party/eigen3/mkl_include to .gitignore * In configured, set MKL-enabling options only for Linux.
This commit is contained in:
parent
3273cf4f4d
commit
27dd167c5f
2
.gitignore
vendored
2
.gitignore
vendored
@ -4,6 +4,8 @@ node_modules
|
||||
/.bazelrc
|
||||
/.tf_configure.bazelrc
|
||||
/bazel-*
|
||||
/third_party/eigen3/mkl_include
|
||||
/third_party/mkl/*
|
||||
/third_party/py/numpy/numpy_include
|
||||
/tools/python_bin_path.sh
|
||||
/tools/git/gen
|
||||
|
57
configure
vendored
57
configure
vendored
@ -180,25 +180,35 @@ fi
|
||||
setup_python
|
||||
|
||||
## Set up MKL related environment settings
|
||||
if false; then # Disable building with MKL for now
|
||||
while [ "$TF_NEED_MKL" == "" ]; do
|
||||
while [ "$TF_NEED_MKL" == "" ]; do
|
||||
fromuser=""
|
||||
read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
|
||||
read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
|
||||
fromuser="1"
|
||||
case $INPUT in
|
||||
[Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;;
|
||||
[Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
|
||||
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
|
||||
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
|
||||
* ) echo "Invalid selection: " $INPUT;;
|
||||
esac
|
||||
done
|
||||
done
|
||||
|
||||
OSNAME=`uname -s`
|
||||
OSNAME=`uname -s`
|
||||
|
||||
if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
|
||||
if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
|
||||
fromuser=""
|
||||
read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT
|
||||
fromuser="1"
|
||||
case $INPUT in
|
||||
[Yy]* ) TF_DOWNLOAD_MKL=1;;
|
||||
[Nn]* ) TF_DOWNLOAD_MKL=0;;
|
||||
"" ) TF_DOWNLOAD_MKL=1;;
|
||||
* ) echo "Invalid selection: " $INPUT; exit 1;;
|
||||
esac
|
||||
|
||||
if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then
|
||||
DST=`dirname $0`
|
||||
ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170209.tgz
|
||||
GITHUB_RELEASE_TAG=v0.5
|
||||
ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz
|
||||
GITHUB_RELEASE_TAG=v0.7
|
||||
MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME"
|
||||
if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then
|
||||
wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL
|
||||
@ -208,6 +218,19 @@ if false; then # Disable building with MKL for now
|
||||
MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name
|
||||
MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
|
||||
|
||||
else
|
||||
default_mkl_path=/opt/intel/mklml
|
||||
fromuser=""
|
||||
read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH
|
||||
fromuser="1"
|
||||
if [ -z "$MKL_INSTALL_PATH" ]; then
|
||||
MKL_INSTALL_PATH=$default_mkl_path
|
||||
fi
|
||||
# Result returned from "read" will be used unexpanded. That make "~" unuseable.
|
||||
# Going through one more level of expansion to handle that.
|
||||
MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
|
||||
fi
|
||||
|
||||
if [ "$OSNAME" == "Linux" ]; then
|
||||
# Full MKL configuration
|
||||
MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version?
|
||||
@ -226,8 +249,17 @@ if false; then # Disable building with MKL for now
|
||||
ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/
|
||||
ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
|
||||
ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
|
||||
loc=$(locate -e libdl.so.2 | sed -n 1p)
|
||||
ln -sf $loc third_party/mkl/libdl.so.2
|
||||
elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then
|
||||
ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/
|
||||
ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/
|
||||
ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
|
||||
ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
|
||||
loc=$(locate -e libdl.so.2 | sed -n 1p)
|
||||
ln -sf $loc third_party/mkl/libdl.so.2
|
||||
else
|
||||
echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist";
|
||||
echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists";
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@ -241,9 +273,8 @@ cat > third_party/mkl/mkl.config <<EOF
|
||||
MKL_INSTALL_PATH=$MKL_INSTALL_PATH
|
||||
EOF
|
||||
|
||||
fi # TF_NEED_MKL
|
||||
################## MKL
|
||||
fi # Disable building with MKL for now
|
||||
fi # TF_NEED_MKL
|
||||
## End MKL setup
|
||||
|
||||
## Set up architecture-dependent optimization flags.
|
||||
if [ -z "$CC_OPT_FLAGS" ]; then
|
||||
|
@ -98,6 +98,7 @@ cc_library(
|
||||
name = "simple_orc_jit",
|
||||
srcs = ["simple_orc_jit.cc"],
|
||||
hdrs = ["simple_orc_jit.h"],
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
":compiler_functor",
|
||||
":cpu_runtime",
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h"
|
||||
|
@ -1282,7 +1282,10 @@ cc_library(
|
||||
] + tf_additional_verbs_lib_defines(),
|
||||
linkopts = select({
|
||||
"//tensorflow:freebsd": [],
|
||||
"//conditions:default": ["-ldl"],
|
||||
"//conditions:default": [
|
||||
"-ldl",
|
||||
"-lpthread",
|
||||
],
|
||||
}),
|
||||
deps = tf_additional_lib_deps() + [
|
||||
":lib_hash_crc32c_accelerate_internal",
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#include "tensorflow/core/graph/mkl_layout_pass.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
@ -282,49 +283,70 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
csinfo_.mkl_conv2d_with_bias_backprop_bias =
|
||||
"_MklConv2DWithBiasBackpropBias";
|
||||
csinfo_.relu = "Relu";
|
||||
csinfo_.reshape = "Reshape";
|
||||
csinfo_.relu_grad = "ReluGrad";
|
||||
csinfo_.reshape = "Reshape";
|
||||
csinfo_.split = "Split";
|
||||
|
||||
// NOTE: names are alphabetically sorted.
|
||||
rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1,
|
||||
CopyAttrsPooling, AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.avg_pool,
|
||||
GetMklOpName(csinfo_.avg_pool),
|
||||
CopyAttrsPooling, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.avg_pool_grad,
|
||||
GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling,
|
||||
AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0,
|
||||
CopyAttrsConcat, AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0,
|
||||
CopyAttrsConcatV2, AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2,
|
||||
CopyAttrsConv2D, AlwaysRewrite});
|
||||
GetMklOpName(csinfo_.avg_pool_grad),
|
||||
CopyAttrsPooling, AlwaysRewrite, nullptr});
|
||||
// BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending
|
||||
// on if context contains Conv2D.
|
||||
rinfo_.push_back({csinfo_.bias_add_grad,
|
||||
csinfo_.mkl_conv2d_with_bias_backprop_bias,
|
||||
CopyAttrsBiasAddGrad, ContextMatchRewrite,
|
||||
&biasaddgrad_conv2dwithbias_context_});
|
||||
// BiasAddGrad gets written into BiasAddGrad depending on if context
|
||||
// contains MatMul.
|
||||
rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul,
|
||||
CopyAttrsBiasAddGrad, ContextMatchRewrite,
|
||||
&biasaddgrad_matmul_context_});
|
||||
rinfo_.push_back({csinfo_.concat,
|
||||
GetMklOpName(csinfo_.concat),
|
||||
CopyAttrsConcat, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.concatv2,
|
||||
GetMklOpName(csinfo_.concatv2),
|
||||
CopyAttrsConcatV2, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.conv2d,
|
||||
GetMklOpName(csinfo_.conv2d),
|
||||
CopyAttrsConv2D, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.conv2d_grad_filter,
|
||||
GetMklOpName(csinfo_.conv2d_grad_filter), 3,
|
||||
CopyAttrsConv2D, AlwaysRewrite});
|
||||
GetMklOpName(csinfo_.conv2d_grad_filter),
|
||||
CopyAttrsConv2D, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.conv2d_grad_input,
|
||||
GetMklOpName(csinfo_.conv2d_grad_input), 3,
|
||||
CopyAttrsConv2D, AlwaysRewrite});
|
||||
GetMklOpName(csinfo_.conv2d_grad_input),
|
||||
CopyAttrsConv2D, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.fused_batch_norm,
|
||||
GetMklOpName(csinfo_.fused_batch_norm), 5,
|
||||
CopyAttrsFusedBatchNorm, AlwaysRewrite});
|
||||
GetMklOpName(csinfo_.fused_batch_norm),
|
||||
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.fused_batch_norm_grad,
|
||||
GetMklOpName(csinfo_.fused_batch_norm_grad), 5,
|
||||
CopyAttrsFusedBatchNorm, AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN,
|
||||
AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3,
|
||||
CopyAttrsLRN, AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1,
|
||||
CopyAttrsPooling, AlwaysRewrite});
|
||||
GetMklOpName(csinfo_.fused_batch_norm_grad),
|
||||
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.lrn,
|
||||
GetMklOpName(csinfo_.lrn),
|
||||
CopyAttrsLRN, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.lrn_grad,
|
||||
GetMklOpName(csinfo_.lrn_grad),
|
||||
CopyAttrsLRN, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.max_pool,
|
||||
GetMklOpName(csinfo_.max_pool),
|
||||
CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.max_pool_grad,
|
||||
GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling,
|
||||
AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
|
||||
CopyAttrsRelu, AlwaysRewrite});
|
||||
rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2,
|
||||
CopyAttrsReshape, AlwaysRewrite});
|
||||
|
||||
// TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet.
|
||||
GetMklOpName(csinfo_.max_pool_grad),
|
||||
CopyAttrsPooling, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.relu,
|
||||
GetMklOpName(csinfo_.relu),
|
||||
CopyAttrsRelu, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.relu_grad,
|
||||
GetMklOpName(csinfo_.relu_grad),
|
||||
CopyAttrsRelu, AlwaysRewrite, nullptr});
|
||||
rinfo_.push_back({csinfo_.reshape,
|
||||
GetMklOpName(csinfo_.reshape),
|
||||
CopyAttrsReshape, AlwaysRewrite, nullptr});
|
||||
|
||||
// Add info about which ops to add workspace edge to and the slots.
|
||||
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
|
||||
@ -338,8 +360,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// maxhops in backward data-flow graph. Since input of forward nodes
|
||||
// (Conv2D) directly goes to backward nodes, we do not expect the
|
||||
// hop-distance would be more than few nodes.
|
||||
cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
|
||||
kNodeMergeContextMaxDepth});
|
||||
biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
|
||||
kNodeMergeContextMaxDepth};
|
||||
|
||||
biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
|
||||
csinfo_.mkl_conv2d_with_bias,
|
||||
kNodeMergeContextMaxDepth};
|
||||
|
||||
cinfo_.push_back(&biasaddgrad_matmul_context_);
|
||||
cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
|
||||
}
|
||||
|
||||
// Standard interface to run pass
|
||||
@ -354,7 +383,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// @return true, if and only if graph is mutated; false otherwise.
|
||||
bool RunPass(std::unique_ptr<Graph>* g);
|
||||
|
||||
private:
|
||||
/// Structure to specify the context information used in a node rewrite rule
|
||||
typedef struct {
|
||||
string node; // Name of the node to be rewritten
|
||||
string fwd; // Name of the node in the forward pass that this node
|
||||
// corresponds to
|
||||
size_t max_hop; // Maximum number of hops the fwd is located
|
||||
// from this node. If the fwd is farther than max_hop
|
||||
// then we do not rewrite the node.
|
||||
} ContextInfo;
|
||||
|
||||
/// Structure to specify the name of an original node, its new name after
|
||||
/// rewrite, the number of inputs to the original node, the function to
|
||||
/// be used to copy attributes for the op, and the rule (if any) which
|
||||
@ -362,11 +400,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
typedef struct {
|
||||
string name; // Original name of op of the node in the graph
|
||||
string new_name; // New name of the op of the node in the graph
|
||||
int num_ins; // The number of inputs to the original op type
|
||||
// A function handler to copy attributes from an old node to a new node.
|
||||
std::function<void(const Node*, NodeBuilder*)> copy_attrs;
|
||||
std::function<bool(const Node*)> rewrite_rule; // A rule under which to
|
||||
// rewrite this node.
|
||||
// A rule under which to rewrite this node
|
||||
std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule;
|
||||
// ContextInfo, if any, to be used for rewrite
|
||||
ContextInfo* context;
|
||||
} RewriteInfo;
|
||||
|
||||
/// Structure to specify a forward op, a backward op, and the slot numbers
|
||||
@ -393,16 +432,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
string new_node; // Name of the node after merge
|
||||
} MergeInfo;
|
||||
|
||||
/// Structure to specify the context information used in a node rewrite rule
|
||||
typedef struct {
|
||||
string node; // Name of the node to be rewritten
|
||||
string fwd; // Name of the node in the forward pass that this node
|
||||
// corresponds to
|
||||
size_t max_hop; // Maximum number of hops the fwd is located
|
||||
// from this node. If the fwd is farther than max_hop
|
||||
// then we do not rewrite the node.
|
||||
} ContextInfo;
|
||||
|
||||
/// Structure to store all constant strings
|
||||
/// NOTE: names are alphabetically sorted.
|
||||
struct {
|
||||
@ -427,10 +456,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
string mkl_conv2d_with_bias_backprop_bias;
|
||||
string relu;
|
||||
string relu_grad;
|
||||
string split;
|
||||
string reshape;
|
||||
string split;
|
||||
} csinfo_;
|
||||
|
||||
private:
|
||||
/// Maintain info about nodes to rewrite
|
||||
std::vector<RewriteInfo> rinfo_;
|
||||
|
||||
@ -441,7 +471,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
std::vector<MergeInfo> minfo_;
|
||||
|
||||
/// Maintain info about nodes to rewrite
|
||||
static std::vector<ContextInfo> cinfo_;
|
||||
static std::vector<ContextInfo*> cinfo_;
|
||||
|
||||
/// Context variables used in referencing rules
|
||||
static ContextInfo biasaddgrad_matmul_context_;
|
||||
static ContextInfo biasaddgrad_conv2dwithbias_context_;
|
||||
|
||||
/// Hash table to maintain nodes visited in the graph.
|
||||
std::unordered_set<const Node*> visited_nodes_;
|
||||
@ -464,19 +498,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// Clear all visited nodes
|
||||
inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }
|
||||
|
||||
// Is this a graph node that can accept variable number of inputs?
|
||||
// Return true if yes, false otherwise.
|
||||
//
|
||||
// Concat, Split are vararg nodes.
|
||||
inline bool IsVarArgNode(Node* n) {
|
||||
if (n->type_string() == csinfo_.concat ||
|
||||
n->type_string() == csinfo_.concatv2 ||
|
||||
n->type_string() == csinfo_.split) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Is OpDef::ArgDef a list type? It could be N * T or list(type).
|
||||
// Refer to opdef.proto for details of list type.
|
||||
inline bool ArgIsList(const OpDef::ArgDef& arg) const {
|
||||
@ -510,6 +531,39 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
return string(kMklOpPrefix) + name;
|
||||
}
|
||||
|
||||
// Can op represented by node 'n' run on DEVICE_CPU?
|
||||
// Op can run on CPU with MKL if the runtime assigned device or the
|
||||
// user requested device contains device CPU, or both are empty.
|
||||
bool CanOpRunOnCPUDevice(const Node* n) {
|
||||
bool result = true;
|
||||
string reason;
|
||||
|
||||
// Substring that should be checked for in device name for CPU device.
|
||||
const char* const kCPUDeviceSubStr = "cpu";
|
||||
|
||||
// If Op has been specifically assigned to a non-CPU device, then No.
|
||||
if (!n->assigned_device_name().empty() &&
|
||||
!StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) {
|
||||
result = false;
|
||||
reason = "Op has been assigned a runtime device that is not CPU.";
|
||||
}
|
||||
|
||||
// If user has specifically assigned this op to a non-CPU device, then No.
|
||||
if (!n->def().device().empty() &&
|
||||
!StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) {
|
||||
result = false;
|
||||
reason = "User has assigned a device that is not CPU.";
|
||||
}
|
||||
|
||||
if (result == false) {
|
||||
VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
|
||||
<< n->type_string() << ", reason: " << reason;
|
||||
}
|
||||
|
||||
// Otherwise Yes.
|
||||
return result;
|
||||
}
|
||||
|
||||
// Return a node that can be merged with input node 'n'
|
||||
//
|
||||
// @return pointer to the node if we can find such a
|
||||
@ -538,13 +592,46 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
|
||||
// Default rewrite rule to be used in scenario 1 for rewrite.
|
||||
// @return - true (since we want to always rewrite)
|
||||
static bool AlwaysRewrite(const Node* n) { return true; }
|
||||
// Rewrite rule that uses context-information for matching
|
||||
static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if we are performing pooling on depth or batch. If it is, then we
|
||||
// do not rewrite MaxPool node to Mkl version.
|
||||
// @return - true (if it is not a depth/batch wise pooling case);
|
||||
// false otherwise.
|
||||
static bool NonDepthBatchWisePoolRewrite(const Node* n,
|
||||
const ContextInfo* c) {
|
||||
CHECK_NOTNULL(n);
|
||||
|
||||
string data_format_str;
|
||||
TensorFormat data_format;
|
||||
std::vector<int32> ksize, strides;
|
||||
CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
|
||||
CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
|
||||
CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
|
||||
true);
|
||||
CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
|
||||
|
||||
// Condition that specifies non-batch-wise and non-depth-wise pooling.
|
||||
if (GetTensorDim(ksize, data_format, 'N') == 1 &&
|
||||
GetTensorDim(strides, data_format, 'N') == 1 &&
|
||||
GetTensorDim(ksize, data_format, 'C') == 1 &&
|
||||
GetTensorDim(strides, data_format, 'C') == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Rewrite rule that uses context-information for matching,
|
||||
// used in scenario 2.
|
||||
//
|
||||
// @input - Node 'n' for which to search for matching context
|
||||
// @return - true if matching context is found; false otherwise.
|
||||
static bool ContextMatchRewrite(const Node* n);
|
||||
// @input - The context 'c' under which to rewrite
|
||||
// @return - true if we can rewrite node under context 'c';
|
||||
// false otherwise.
|
||||
static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
|
||||
|
||||
// Helper function that searches the matching contextinfo for the node.
|
||||
// Implements depth-first search in the data dependence graph for the
|
||||
@ -598,6 +685,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// node that we are constructing.
|
||||
//
|
||||
// @input g - input graph,
|
||||
// @input orig_node - Original node that we are rewriting
|
||||
// @input inputs - inputs to old node that we are using for constructing
|
||||
// new inputs,
|
||||
// @input input_idx - the index in the 'inputs' vector pointing to the
|
||||
@ -608,9 +696,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// @output output_nodes - the list of new nodes creating Mkl tensors
|
||||
//
|
||||
// @return None
|
||||
void GetNodesProducingMklTensorList(
|
||||
std::unique_ptr<Graph>* g,
|
||||
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
|
||||
void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
|
||||
Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
|
||||
int* input_idx, int list_length,
|
||||
std::vector<NodeBuilder::NodeOut>* output_nodes);
|
||||
|
||||
@ -620,6 +707,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// if 'n' is not an Mkl layer.
|
||||
//
|
||||
// @input g - input graph,
|
||||
// @input orig_node - Original node that we are rewriting,
|
||||
// @input n - Node based on which we are creating Mkl node,
|
||||
// @input n_output_slot - the output slot of node 'n'
|
||||
// which is feeding to the node that we are constructing
|
||||
@ -627,9 +715,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// @output mkl_node_output_slot - the slot number of mkl_node that
|
||||
// will feed the tensor
|
||||
// @return None
|
||||
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n,
|
||||
int n_output_slot, Node** mkl_node,
|
||||
int* mkl_node_output_slot);
|
||||
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
|
||||
Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
|
||||
|
||||
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
|
||||
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
|
||||
@ -695,13 +782,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
Node* orig_node);
|
||||
};
|
||||
|
||||
std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;
|
||||
MklLayoutRewritePass::ContextInfo
|
||||
MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
|
||||
MklLayoutRewritePass::ContextInfo
|
||||
MklLayoutRewritePass::biasaddgrad_matmul_context_;
|
||||
std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
|
||||
|
||||
// We register Mkl rewrite pass for phase 1 in post rewrite group.
|
||||
// We register Mkl rewrite pass for phase 1 in post partitioning group.
|
||||
// We register it here so that we get a complete picture of all users of Mkl
|
||||
// nodes. Do not change the ordering of the Mkl passes.
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1,
|
||||
MklLayoutRewritePass);
|
||||
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
|
||||
OptimizationPassRegistry::POST_PARTITIONING;
|
||||
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Helper functions for creating new node
|
||||
@ -737,28 +829,15 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList(
|
||||
|
||||
while (list_length != 0) {
|
||||
CHECK_GT(list_length, 0);
|
||||
CHECK_LE(*input_idx, inputs.size());
|
||||
CHECK_LT(*input_idx, inputs.size());
|
||||
Node* n = inputs[*input_idx].first;
|
||||
int slot = inputs[*input_idx].second;
|
||||
const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
|
||||
// If input node 'n' is producing a list/array output at output
|
||||
// slot 'slot' then we need to find out the length of that list/array.
|
||||
if (ArgIsList(arg)) {
|
||||
int N = GetTensorListLength(arg, n);
|
||||
CHECK_LE(N, list_length);
|
||||
for (int j = 0; j < N; j++) {
|
||||
output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
|
||||
}
|
||||
(*input_idx)++;
|
||||
list_length -= N;
|
||||
} else {
|
||||
// But if input node 'n' is just producing a single tensor at
|
||||
// If input node 'n' is just producing a single tensor at
|
||||
// output slot 'slot' then we just add that single node.
|
||||
output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
|
||||
(*input_idx)++;
|
||||
list_length--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(nhasabni) We should move this to mkl_util.h.
|
||||
@ -782,13 +861,32 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
|
||||
// device of the original
|
||||
// node.
|
||||
.Finalize(&**g, out));
|
||||
|
||||
// If number of inputs to the original node is > 0, then we add
|
||||
// control dependency between 1st input (index 0) of the original node and
|
||||
// the dummy Mkl node. This is needed because control-flow ops such as Enter,
|
||||
// Merge, etc, require frame_name of the dummy Mkl node to be same as the
|
||||
// rewritten node. Adding control edge between 1st input of the original node
|
||||
// and the dummy Mkl node ensures that the dummy node is in the same frame
|
||||
// as the original node. Choosing 1st input is not necessary - any input of
|
||||
// the original node is fine because all the inputs of a node are always in
|
||||
// the same frame.
|
||||
if (orig_node->num_inputs() > 0) {
|
||||
Node* orig_input0 = nullptr;
|
||||
TF_CHECK_OK(orig_node->input_node(0,
|
||||
const_cast<const Node**>(&orig_input0)));
|
||||
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
|
||||
}
|
||||
|
||||
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
|
||||
}
|
||||
|
||||
void MklLayoutRewritePass::GetNodesProducingMklTensorList(
|
||||
std::unique_ptr<Graph>* g,
|
||||
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
|
||||
int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
|
||||
Node* orig_node,
|
||||
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
|
||||
int* input_idx, int list_length,
|
||||
std::vector<NodeBuilder::NodeOut>* output_nodes) {
|
||||
CHECK_LT(*input_idx, inputs.size());
|
||||
CHECK_GT(list_length, 0);
|
||||
CHECK_NOTNULL(output_nodes);
|
||||
@ -796,48 +894,29 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
|
||||
|
||||
while (list_length != 0) {
|
||||
CHECK_GT(list_length, 0);
|
||||
CHECK_LE(*input_idx, inputs.size());
|
||||
CHECK_LT(*input_idx, inputs.size());
|
||||
Node* n = inputs[*input_idx].first;
|
||||
int slot = inputs[*input_idx].second;
|
||||
const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
|
||||
// We need to check first if the input edge is going to carry a
|
||||
// single tensor or a list of tensors. If it is a list of tensors,
|
||||
// then we need to create list of Mkl dummy nodes.
|
||||
if (ArgIsList(arg)) {
|
||||
// If input node 'n' is producing a list/array output at output
|
||||
// slot 'slot' then we need to find out the length of that list/array.
|
||||
int N = GetTensorListLength(arg, n);
|
||||
CHECK_LE(N, list_length);
|
||||
// If 'n' is producing a single tensor, then create a single Mkl tensor
|
||||
// node.
|
||||
Node* mkl_node = nullptr;
|
||||
int mkl_node_output_slot = 0;
|
||||
// If it is a list, then create a list of Mkl dummy nodes.
|
||||
for (int j = 0; j < N; j++) {
|
||||
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
|
||||
output_nodes->push_back(
|
||||
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
|
||||
}
|
||||
(*input_idx)++;
|
||||
list_length -= N;
|
||||
} else {
|
||||
// If it is not a list, then create a single Mkl tensor node.
|
||||
Node* mkl_node = nullptr;
|
||||
int mkl_node_output_slot = 0;
|
||||
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
|
||||
output_nodes->push_back(
|
||||
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
|
||||
GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
|
||||
&mkl_node_output_slot);
|
||||
output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
|
||||
mkl_node_output_slot));
|
||||
(*input_idx)++;
|
||||
list_length--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get an input node that will feed Mkl tensor to the new
|
||||
// node that we are constructing. An input node could be (1) 'n'
|
||||
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
|
||||
// if 'n' is not an Mkl layer.
|
||||
void MklLayoutRewritePass::GetNodeProducingMklTensor(
|
||||
std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node,
|
||||
int* mkl_node_output_slot) {
|
||||
void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
|
||||
Node* orig_node, Node* n,
|
||||
int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
|
||||
CHECK_NOTNULL(n);
|
||||
CHECK_NOTNULL(mkl_node);
|
||||
CHECK_NOTNULL(mkl_node_output_slot);
|
||||
@ -860,7 +939,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(
|
||||
// to create a dummy node that will feed a dummy Mkl tensor to this node.
|
||||
// DummyMklTensor node has no input and generates only 1 output
|
||||
// (dummy Mkl tensor) as output slot number 0.
|
||||
GetDummyMklTensorNode(g, mkl_node, n);
|
||||
GetDummyMklTensorNode(g, mkl_node, orig_node);
|
||||
CHECK_NOTNULL(*mkl_node);
|
||||
*mkl_node_output_slot = 0;
|
||||
}
|
||||
@ -926,16 +1005,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
|
||||
if (ArgIsList(arg)) {
|
||||
std::vector<NodeBuilder::NodeOut> new_node_inputs;
|
||||
int N = GetTensorListLength(arg, old_node);
|
||||
GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
|
||||
&new_node_inputs);
|
||||
GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
|
||||
N, &new_node_inputs);
|
||||
nb->Input(new_node_inputs);
|
||||
nn_slot_idx++;
|
||||
} else {
|
||||
Node* mkl_node = nullptr;
|
||||
int mkl_node_output_slot = 0;
|
||||
GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
|
||||
old_node_inputs[iidx].second, &mkl_node,
|
||||
&mkl_node_output_slot);
|
||||
GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
|
||||
old_node_inputs[iidx].second,
|
||||
&mkl_node, &mkl_node_output_slot);
|
||||
nb->Input(mkl_node, mkl_node_output_slot);
|
||||
iidx++;
|
||||
nn_slot_idx++;
|
||||
@ -1027,6 +1106,23 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
|
||||
// device of the original
|
||||
// node.
|
||||
.Finalize(&**g, out));
|
||||
|
||||
// If number of inputs to the original node is > 0, then we add
|
||||
// control dependency between 1st input (index 0) of the original node and
|
||||
// the dummy Mkl node. This is needed because control-flow ops such as Enter,
|
||||
// Merge, etc, require frame_name of the dummy Mkl node to be same as the
|
||||
// rewritten node. Adding control edge between 1st input of the original node
|
||||
// and the dummy Mkl node ensures that the dummy node is in the same frame
|
||||
// as the original node. Choosing 1st input is not necessary - any input of
|
||||
// the original node is fine because all the inputs of a node are always in
|
||||
// the same frame.
|
||||
if (orig_node->num_inputs() > 0) {
|
||||
Node* orig_input0 = nullptr;
|
||||
TF_CHECK_OK(orig_node->input_node(0,
|
||||
const_cast<const Node**>(&orig_input0)));
|
||||
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
|
||||
}
|
||||
|
||||
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
|
||||
}
|
||||
|
||||
@ -1235,6 +1331,19 @@ void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node,
|
||||
nb->Attr("T", T);
|
||||
}
|
||||
|
||||
void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
|
||||
NodeBuilder* nb) {
|
||||
DataType T;
|
||||
DataType Tshape;
|
||||
|
||||
// Get all attributes from old node.
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
|
||||
// Add attributes to new node.
|
||||
nb->Attr("T", T);
|
||||
nb->Attr("Tshape", Tshape);
|
||||
}
|
||||
|
||||
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
|
||||
NodeBuilder* nb) {
|
||||
DataType T;
|
||||
@ -1303,20 +1412,6 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
|
||||
nb->Attr("is_training", is_training);
|
||||
}
|
||||
|
||||
void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
|
||||
NodeBuilder* nb) {
|
||||
DataType T;
|
||||
DataType Tshape;
|
||||
|
||||
// Get all attributes from old node.
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
|
||||
|
||||
// Add attributes to new node.
|
||||
nb->Attr("T", T);
|
||||
nb->Attr("Tshape", Tshape);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Helper functions related to node merge pass
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
@ -1353,8 +1448,9 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int B_in = b->num_inputs();
|
||||
gtl::InlinedVector<Node*, 4> b_control_edges;
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
|
||||
FillInputs(b, &b_control_edges, &b_in);
|
||||
|
||||
// Shouldn't merge if a and b have different control edges.
|
||||
@ -1438,7 +1534,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
|
||||
CHECK_EQ(succ->in_edges().size(), 2);
|
||||
Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3
|
||||
int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0.
|
||||
GetDummyMklTensorNode(g, &oper3_mkl, succ); // Get dummy Mkl tensor node
|
||||
GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get dummy Mkl tensor node
|
||||
// as BiasAdd does not have Mkl tensor as input.
|
||||
CHECK_NOTNULL(oper3_mkl);
|
||||
|
||||
@ -1483,9 +1579,38 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
|
||||
// Set the Mkl layer label for this op.
|
||||
new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
|
||||
|
||||
// Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
|
||||
// node are already copied in BuildNode. We handle control edges now.
|
||||
for (const Edge* e : pred->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
|
||||
}
|
||||
}
|
||||
for (const Edge* e : succ->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
|
||||
}
|
||||
}
|
||||
|
||||
// Incoming edges are fixed, we will fix the outgoing edges now.
|
||||
// First, we will fix outgoing control edges from 'pred' node.
|
||||
// We don't need to handle outgoing data edges from 'pred' node
|
||||
// because pred has only 1 output going to succ node (we enforced
|
||||
// this check for merge already).
|
||||
for (const Edge* e : pred->out_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
|
||||
}
|
||||
}
|
||||
|
||||
// Second, we will fix outgoing control and data edges from 'succ' node.
|
||||
for (const Edge* e : succ->out_edges()) {
|
||||
(*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
|
||||
if (e->IsControlEdge()) {
|
||||
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
|
||||
} else {
|
||||
CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(),
|
||||
e->dst_input()));
|
||||
}
|
||||
}
|
||||
|
||||
// Copy device assigned to old node to new node.
|
||||
@ -1550,18 +1675,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
|
||||
"data_format or T attribute or devices of BiasAddGrad and "
|
||||
"Conv2D do not match. Will skip node rewrite optimization");
|
||||
}
|
||||
} else if (orig_node->type_string() == csinfo_.bias_add_grad &&
|
||||
ri->new_name == csinfo_.matmul) {
|
||||
// When BiasAddGrad has MatMul in context, we do not do any rewrite
|
||||
// and leave BiasAddGrad as it is. But we check for this condition
|
||||
// when we check for node rewrite rule. So we should not even come
|
||||
// here for MatMul. So we will fail now.
|
||||
return Status(
|
||||
error::Code::INVALID_ARGUMENT,
|
||||
"No rewrite is required for BiasAddGrad for MatMul context.");
|
||||
}
|
||||
}
|
||||
|
||||
// Get all inputs.
|
||||
const int num = orig_node->in_edges().size();
|
||||
// Check the number of inputs against the user-specified value for non-vararg
|
||||
// nodes.
|
||||
if (!IsVarArgNode(orig_node)) {
|
||||
CHECK_EQ(num, ri->num_ins);
|
||||
}
|
||||
const int num_inputs = orig_node->in_edges().size();
|
||||
gtl::InlinedVector<Node*, 4> control_edges;
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num);
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
|
||||
FillInputs(orig_node, &control_edges, &inputs);
|
||||
|
||||
// Build new node. We use same name as original node, but change the op name.
|
||||
@ -1596,8 +1725,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
|
||||
TF_CHECK_OK(nb.Finalize(&**g, &new_node));
|
||||
CHECK_NOTNULL(new_node);
|
||||
|
||||
// Incoming edges from 'orig_node' node to new 'new_node' node are already
|
||||
// copied in BuildNode. Copy outgoing edges from 'orig_node' node to new
|
||||
// Incoming data edges from 'orig_node' node to new 'new_node' node are
|
||||
// already copied in BuildNode. We need to handle control edges now.
|
||||
for (const Edge* e : orig_node->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
|
||||
}
|
||||
}
|
||||
|
||||
// Copy outgoing edges from 'orig_node' node to new
|
||||
// 'new_node' node, since the output also follows same ordering among
|
||||
// Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
|
||||
// tensors appropriately. Specifically, nth output of the original node
|
||||
@ -1605,15 +1741,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
|
||||
// of the tensors. For the contiguous ordering of the tensors, it will be n.
|
||||
// GetTensorDataIndex provides this mapping function.
|
||||
for (const Edge* e : orig_node->out_edges()) {
|
||||
// We need to handle control-edges by using their original slot number.
|
||||
// Generally, -1 is reserved for control slot.
|
||||
if (e->src_output() < 0) {
|
||||
(*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
|
||||
if (e->IsControlEdge()) {
|
||||
CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
|
||||
} else {
|
||||
(*g)->AddEdge(
|
||||
new_node,
|
||||
GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
|
||||
e->dst(), e->dst_input());
|
||||
CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
|
||||
e->src()->num_outputs()),
|
||||
e->dst(), e->dst_input()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1640,8 +1773,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
|
||||
bool is_matching_cinfo_found = false;
|
||||
std::vector<const ContextInfo*> mci;
|
||||
for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
|
||||
if (n->type_string() == ci->node) {
|
||||
mci.push_back(&*ci);
|
||||
if (n->type_string() == (*ci)->node) {
|
||||
mci.push_back(*ci);
|
||||
is_matching_cinfo_found = true;
|
||||
}
|
||||
}
|
||||
@ -1701,9 +1834,10 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) {
|
||||
bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n,
|
||||
const ContextInfo* c) {
|
||||
const Node* fwd_node = nullptr;
|
||||
return SearchMatchingContext(n, &fwd_node) != nullptr;
|
||||
return SearchMatchingContext(n, &fwd_node) == c;
|
||||
}
|
||||
|
||||
const MklLayoutRewritePass::RewriteInfo*
|
||||
@ -1719,18 +1853,29 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// BiasAddGrad is not an Mkl layer, so we make an exception for it.
|
||||
if (n->type_string() != csinfo_.bias_add_grad) {
|
||||
if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// We support 2 types of node rewrites:
|
||||
// 1. Rewriting BiasAddGrad depending on its context.
|
||||
// 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context.
|
||||
// 2. Rewriting an op to Mkl op always
|
||||
// We return true if any of these 2 conditions is met.
|
||||
|
||||
// Find matching RewriteInfo and then check that rewrite rule applies.
|
||||
for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
|
||||
if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
|
||||
if (n->type_string().compare(ri->name) == 0 &&
|
||||
ri->rewrite_rule(n, ri->context)) {
|
||||
// If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context,
|
||||
// then we just return directly.
|
||||
if (n->type_string() == csinfo_.bias_add_grad &&
|
||||
ri->context->fwd == csinfo_.matmul &&
|
||||
ri->new_name == csinfo_.bias_add_grad) {
|
||||
return nullptr;
|
||||
}
|
||||
return &*ri;
|
||||
}
|
||||
}
|
||||
@ -1753,7 +1898,8 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
|
||||
GetReversePostOrder(**g, &order); // This will give us topological sort.
|
||||
|
||||
for (Node* n : order) {
|
||||
if (!n->IsOp()) {
|
||||
// If node is not an op or it cannot run on CPU device, then skip.
|
||||
if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1801,18 +1947,31 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
|
||||
return MklLayoutRewritePass().RunPass(g);
|
||||
}
|
||||
|
||||
Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
|
||||
if (options.graph == nullptr) {
|
||||
Status MklLayoutRewritePass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
if (options.graph == nullptr && options.partition_graphs == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the ownership of graph
|
||||
std::unique_ptr<Graph>* g = std::move(options.graph);
|
||||
auto process_graph = [&](std::unique_ptr<Graph>* g) {
|
||||
// Get the ownership of a graph
|
||||
std::unique_ptr<Graph>* ng = std::move(g);
|
||||
RunPass(ng);
|
||||
// Return the ownership of a graph back
|
||||
g->reset(ng->release());
|
||||
};
|
||||
|
||||
RunPass(g);
|
||||
|
||||
// Return the ownership of graph back
|
||||
options.graph->reset(g->release());
|
||||
if (kMklLayoutRewritePassGroup !=
|
||||
OptimizationPassRegistry::POST_PARTITIONING) {
|
||||
// For any pre-partitioning phase, a graph is stored in options.graph.
|
||||
process_graph(options.graph);
|
||||
} else {
|
||||
// For post partitioning phase, graphs are stored in
|
||||
// options.partition_graphs.
|
||||
for (auto& pg : *options.partition_graphs) {
|
||||
process_graph(&pg.second);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -39,7 +39,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
static void InitGraph(const string& s, Graph* graph) {
|
||||
const char kCPUDevice[] = "/job:a/replica:0/task:0/cpu:0";
|
||||
const char kGPUDevice[] = "/job:a/replica:0/task:0/gpu:0";
|
||||
|
||||
static void InitGraph(const string& s, Graph* graph,
|
||||
const string& device = kCPUDevice) {
|
||||
GraphDef graph_def;
|
||||
|
||||
auto parser = protobuf::TextFormat::Parser();
|
||||
@ -47,14 +51,18 @@ static void InitGraph(const string& s, Graph* graph) {
|
||||
CHECK(parser.MergeFromString(s, &graph_def)) << s;
|
||||
GraphConstructorOptions opts;
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
|
||||
|
||||
for (Node* node : graph->nodes()) {
|
||||
node->set_assigned_device_name(device);
|
||||
}
|
||||
}
|
||||
|
||||
class MklLayoutPassTest : public ::testing::Test {
|
||||
public:
|
||||
MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
|
||||
|
||||
void InitGraph(const string& s) {
|
||||
::tensorflow::InitGraph(s, &graph_);
|
||||
void InitGraph(const string& s, const string& device = kCPUDevice) {
|
||||
::tensorflow::InitGraph(s, &graph_, device);
|
||||
original_ = CanonicalGraphString(&graph_);
|
||||
}
|
||||
|
||||
@ -114,7 +122,8 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
|
||||
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
|
||||
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
|
||||
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
|
||||
REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
|
||||
REGISTER_OP("_MklInput2").Output("o: uint8")
|
||||
.Output("o1: uint8").SetIsStateful();
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Unit tests related to node merge optiimization
|
||||
@ -162,8 +171,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
|
||||
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
|
||||
"DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
|
||||
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;"
|
||||
"A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;"
|
||||
"N->E:4;Y->Z:1");
|
||||
}
|
||||
|
||||
// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
|
||||
@ -194,8 +204,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
|
||||
"M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
|
||||
"DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
|
||||
"M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;"
|
||||
"A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;"
|
||||
"M:1->E:3;N:1->E:4;Y->Z:1");
|
||||
}
|
||||
|
||||
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
|
||||
@ -226,8 +237,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
|
||||
"A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
|
||||
"E->Z;Y->Z:1");
|
||||
"A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
|
||||
"A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
|
||||
"DMT/_2->E:5;E->Z;Y->Z:1");
|
||||
}
|
||||
|
||||
// Graph contains only _MklConv2D, no AddBias.
|
||||
@ -330,9 +342,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
|
||||
"N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
|
||||
}
|
||||
|
||||
// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
|
||||
// of BiasAddGrad into BackpropBias
|
||||
#if 0
|
||||
// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
|
||||
// rewrite tests
|
||||
|
||||
@ -361,18 +370,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
|
||||
" input: ['E'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
|
||||
"E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
|
||||
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
|
||||
"M->D:3;N->D:4;O->D:5");
|
||||
"E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);"
|
||||
"N(_MklInput);O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;"
|
||||
"DMT/_0->F:1;E->F;E:control->DMT/_0:control;M->D:3;N->D:4;"
|
||||
"O->D:5");
|
||||
}
|
||||
#endif
|
||||
|
||||
// No _MklConv2D in context, but Conv2D in context.
|
||||
// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
|
||||
// for BiasAddGrad should happen.
|
||||
// No _MklConv2DWithBias in context, but _MklConv2D in context.
|
||||
// No rewrite for BiasAddGrad should happen.
|
||||
// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
|
||||
// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
@ -507,8 +515,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['B', 'C'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
|
||||
"A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);"
|
||||
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
|
||||
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
|
||||
"DMT/_1->C:3");
|
||||
}
|
||||
|
||||
// 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
|
||||
@ -535,7 +545,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;"
|
||||
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
|
||||
"A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
|
||||
"C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
|
||||
}
|
||||
|
||||
@ -558,6 +570,50 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
|
||||
"A->C;B->C:1;B->D;C->D:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Int32Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Conv2DBackpropFilter'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B', 'C']}"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
|
||||
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
|
||||
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
|
||||
"A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
|
||||
"DMT/_1->D:4;DMT/_2->D:5");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Int32Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Conv2DBackpropInput'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['B', 'A', 'C']}"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
|
||||
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
|
||||
"A->D:1;A->E;B->D;B:control->DMT/_0:control;"
|
||||
"B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
|
||||
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
|
||||
}
|
||||
|
||||
// Concat Op test: Concat with no Mkl layer feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
|
||||
InitGraph(
|
||||
@ -572,13 +628,14 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
|
||||
"node { name: 'D' op: 'Concat'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['A', 'B']}"
|
||||
" input: ['A', 'B:0', 'B:1']}"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
|
||||
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;A:control->DMT/_0:control;"
|
||||
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
|
||||
"B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
|
||||
}
|
||||
|
||||
// Concat with 2 Mkl layers feeding it
|
||||
@ -616,9 +673,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
|
||||
"F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
|
||||
"F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;"
|
||||
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
|
||||
"B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
|
||||
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
|
||||
"DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
|
||||
"DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;"
|
||||
"G:control->DMT/_4:control;H->I:1");
|
||||
}
|
||||
|
||||
// Concat with 1 Mkl and 1 non-Mkl layer feeding it
|
||||
@ -651,12 +711,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
|
||||
"H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
|
||||
"H(_MklConcat);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
|
||||
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
|
||||
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
|
||||
"G->H;H->I:1");
|
||||
"G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
|
||||
}
|
||||
|
||||
#if 0
|
||||
// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
|
||||
InitGraph(
|
||||
@ -676,11 +736,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
|
||||
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
|
||||
"A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;"
|
||||
"B:control->DMT/_0:control;B:control->DMT/_1:control;"
|
||||
"B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
|
||||
"DMT/_1->D:4;DMT/_2->D:5");
|
||||
}
|
||||
#endif
|
||||
|
||||
// ConcatV2 with 2 Mkl layers feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
|
||||
@ -718,9 +779,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
|
||||
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
|
||||
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;"
|
||||
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
|
||||
"C:control->DMT/_0:control;C:control->DMT/_1:control;"
|
||||
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
|
||||
"DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
|
||||
"DMT/_4->H:5;E->H;E:1->H:3;E:control->DMT/_4:control;F->H:1;"
|
||||
"F:1->H:4;G->H:2;H->I:1");
|
||||
}
|
||||
|
||||
// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
|
||||
@ -754,11 +818,175 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
|
||||
"H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
|
||||
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
|
||||
"H(_MklConcatV2);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
|
||||
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
|
||||
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;"
|
||||
"E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
|
||||
"G->H:2;H->I:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Relu'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklRelu);C(Mul);DMT/_0(Const)|A->B;A->C;"
|
||||
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'ReluGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }"
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'C'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
|
||||
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
|
||||
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Relu'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'ReluGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }"
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'C'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklRelu);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
|
||||
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
|
||||
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
|
||||
"DMT/_1->C:2");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'AvgPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklAvgPool);C(Mul);DMT/_0(Const)|A->B;A->C;"
|
||||
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Int32Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'AvgPoolGrad' "
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
|
||||
" input: ['A', 'B'] }"
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['B', 'C'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
|
||||
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
|
||||
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
|
||||
"DMT/_1->C:3");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'I' op: 'Int32Input'}"
|
||||
"node { name: 'B' op: 'AvgPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'AvgPoolGrad' "
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
|
||||
" input: ['I', 'B'] }"
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'C'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
|
||||
"DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
|
||||
"B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
|
||||
"I:control->DMT/_1:control");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'FusedBatchNormGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'epsilon' value { f: 0.0001 } }"
|
||||
" attr { key: 'is_training' value { b: true } }"
|
||||
" input: ['A', 'B', 'C', 'D', 'E'] }"
|
||||
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'F'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
|
||||
"F(_MklFusedBatchNormGrad);G(Mul)|A->F;A->G;"
|
||||
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
|
||||
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
|
||||
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
|
||||
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
|
||||
"E->F:4;F->G:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'FusedBatchNorm'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'epsilon' value { f: 0.0001 } }"
|
||||
" attr { key: 'is_training' value { b: true } }"
|
||||
" input: ['A', 'B', 'C', 'D', 'E'] }"
|
||||
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'F'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
|
||||
"F(_MklFusedBatchNorm);G(Mul)|A->F;A->G;"
|
||||
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
|
||||
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
|
||||
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
|
||||
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
|
||||
"E->F:4;F->G:1");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Unit tests related to rewriting node for workspace edges
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
@ -802,13 +1030,13 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
|
||||
"node { name: 'H' op: 'Input'}"
|
||||
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['H', 'G'] }");
|
||||
EXPECT_EQ(
|
||||
DoMklLayoutOptimizationPass(),
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
|
||||
"A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
|
||||
"C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
|
||||
"DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
|
||||
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
|
||||
"I(Mul)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
|
||||
"B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
|
||||
"C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
|
||||
"E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
|
||||
}
|
||||
|
||||
/* Test LRN->LRNGrad replacement by workspace nodes. */
|
||||
@ -838,8 +1066,9 @@ TEST_F(MklLayoutPassTest, LRN_Positive) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
|
||||
"A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
|
||||
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
|
||||
"A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
|
||||
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
|
||||
"D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
|
||||
}
|
||||
|
||||
/* Test LRN->LRNGrad replacement when only one of them is present. */
|
||||
@ -858,7 +1087,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) {
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
|
||||
"A->B;A->C;B->C:1;DMT/_0->B:1");
|
||||
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
|
||||
}
|
||||
|
||||
/* Test LRN->LRNGrad replacement when only one of them is present. */
|
||||
@ -880,8 +1109,10 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
|
||||
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
|
||||
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
|
||||
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
|
||||
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
|
||||
"A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
|
||||
"DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
|
||||
}
|
||||
|
||||
/* Test LRN->LRNGrad negative case, where single LRN feeds
|
||||
@ -920,9 +1151,13 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
|
||||
"DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
|
||||
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
|
||||
"D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
|
||||
"DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;"
|
||||
"A:control->DMT/_0:control;B->E:2;"
|
||||
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
|
||||
"C:control->DMT/_1:control;C:control->DMT/_2:control;"
|
||||
"C:control->DMT/_3:control;C:control->DMT/_4:control;"
|
||||
"C:control->DMT/_5:control;C:control->DMT/_6:control;"
|
||||
"D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
|
||||
"DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
|
||||
}
|
||||
|
||||
@ -951,8 +1186,9 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
|
||||
"A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
|
||||
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
|
||||
"A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
|
||||
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
|
||||
"D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
|
||||
}
|
||||
|
||||
// Test MaxPool>MaxPoolGrad replacement when only one of them is present.
|
||||
@ -972,7 +1208,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
|
||||
"A->B;A->C;B->C:1;DMT/_0->B:1");
|
||||
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
|
||||
}
|
||||
|
||||
// Test MaxPoolGrad replacement when only one of them is present.
|
||||
@ -995,8 +1231,374 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
|
||||
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
|
||||
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
|
||||
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
|
||||
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
|
||||
"A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
|
||||
"DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
|
||||
}
|
||||
|
||||
// Test MaxPool handling for batch-wise pooling (NCHW)
|
||||
// No rewrite should take place in such case
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
// Test MaxPool handling for batch-wise pooling (NCHW)
|
||||
// No rewrite should take place in such case
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
// Test MaxPool handling for depth-wise pooling (NHWC)
|
||||
// No rewrite should take place in such case
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
// Test MaxPool handling for depth-wise pooling (NCHW)
|
||||
// No rewrite should take place in such case
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
// Test MaxPool handling for batch-wise pooling (NHWC)
|
||||
// No rewrite should take place in such case
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
// Test MaxPool handling for batch-wise pooling (NHWC)
|
||||
// No rewrite should take place in such case
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
// Test MaxPool handling for depth-wise pooling (NHWC)
|
||||
// No rewrite should take place in such case
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
// Test MaxPool handling for depth-wise pooling (NHWC)
|
||||
// No rewrite should take place in such case
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Single Conv2D Op on GPU device
|
||||
// No rewrite should happen
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['B', 'C'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Conv2D);D(Mul)|A->C;B->C:1;B->D;C->D:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'O' op: '_MklInput'}"
|
||||
"node { name: 'D' op: '_MklConv2DWithBias'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
|
||||
"node { name: 'E' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['D', 'A']}"
|
||||
"node { name: 'F' op: 'BiasAddGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['E'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
|
||||
"E(Sub);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
|
||||
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
|
||||
"M->D:3;N->D:4;O->D:5");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Int32Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Conv2DBackpropFilter'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B', 'C']}"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'D'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Mul)|"
|
||||
"A->D;A->E;B->D:1;C->D:2;D->E:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Relu'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Relu);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'ReluGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }"
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'C'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(ReluGrad);D(Mul)|A->C;A->D;B->C:1;C->D:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'AvgPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NHWC' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(AvgPool);C(Mul)|A->B;A->C;B->C:1");
|
||||
}
|
||||
|
||||
// Concat Op test: Concat with no Mkl layer feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'B' op: 'InputList'"
|
||||
" attr { key: 'N' value { i: 2 } }}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Concat'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['A', 'B:0', 'B:1']}"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Const);B(InputList);C(Input);D(Concat);E(Mul)|A->D;"
|
||||
"B->D:1;B:1->D:2;C->E;D->E:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'B' op: 'InputList'"
|
||||
" attr { key: 'N' value { i: 2 } }}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'ConcatV2'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'Tidx' value { type: DT_INT32 } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['B:0', 'B:1', 'A']}"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Const);B(InputList);C(Input);D(ConcatV2);E(Mul)|"
|
||||
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'FusedBatchNorm'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'epsilon' value { f: 0.0001 } }"
|
||||
" attr { key: 'is_training' value { b: true } }"
|
||||
" input: ['A', 'B', 'C', 'D', 'E'] }"
|
||||
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'F'] }", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);E(Input);"
|
||||
"F(FusedBatchNorm);G(Mul)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
|
||||
"E->F:4;F->G:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B', 'M', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['C', 'D'] }"
|
||||
"node { name: 'Y' op: 'Input'}"
|
||||
"node { name: 'Z' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['E', 'Y']}", kGPUDevice);
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
|
||||
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->C;"
|
||||
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
@ -98,12 +98,13 @@ class MklToTfConversionPass : public GraphOptimizationPass {
|
||||
Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);
|
||||
};
|
||||
|
||||
// We register MklToTf insertion for phase 1 in post-partition grouping.
|
||||
// We register this pass after partitioning so that we get a complete
|
||||
// picture of inputs and outputs of the nodes in the graphs.
|
||||
// We register MklToTf insertion for phase 2 in post-partition grouping
|
||||
// because we register MklLayoutRewritePass for phase 1 in post-partition
|
||||
// grouping. We register this pass after partitioning so that we get a
|
||||
// complete picture of inputs and outputs of the nodes in the graphs.
|
||||
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
|
||||
OptimizationPassRegistry::POST_PARTITIONING;
|
||||
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 1, MklToTfConversionPass);
|
||||
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
|
||||
|
||||
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
|
||||
std::unique_ptr<Graph>* g, Edge* e) {
|
||||
@ -121,10 +122,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
|
||||
string data_format;
|
||||
|
||||
TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
|
||||
TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype));
|
||||
if (src_datatype != dst_datatype) {
|
||||
string err_msg = "T attribute of " + src->name() + " and " + dst->name() +
|
||||
" do not match. Will not insert" +
|
||||
bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) ==
|
||||
Status::OK();
|
||||
// We compare source and destination datatypes only when both are found.
|
||||
if (dst_dtype_found && (src_datatype != dst_datatype)) {
|
||||
string err_msg = "T attribute of " + src->name() + " and " +
|
||||
dst->name() + " do not match. Will not insert" +
|
||||
" MklToTf node in such case.";
|
||||
return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
|
||||
}
|
||||
@ -202,18 +205,19 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
|
||||
<< src->type_string() << " and " << dst->type_string();
|
||||
|
||||
// Let's get source and destination data type.
|
||||
DataType src_datatype = DT_INVALID;
|
||||
if (GetNodeAttr(src->def(), "T", &src_datatype) != Status::OK()) {
|
||||
continue;
|
||||
}
|
||||
// We cannot check datatype on destination node because destination node
|
||||
// may not be Mkl node.
|
||||
DataType dst_datatype = DT_INVALID;
|
||||
GetNodeAttr(dst->def(), "T", &dst_datatype);
|
||||
DataType src_datatype;
|
||||
DataType dst_datatype;
|
||||
bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) ==
|
||||
Status::OK() &&
|
||||
IsMklSupportedOp(src->type_string(), src_datatype));
|
||||
bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) ==
|
||||
Status::OK() &&
|
||||
IsMklSupportedOp(dst->type_string(), dst_datatype));
|
||||
|
||||
// Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
|
||||
if (IsMklSupportedOp(src->type_string(), src_datatype) &&
|
||||
!IsMklSupportedOp(dst->type_string(), dst_datatype)) {
|
||||
if (src_is_mkl_op && !dst_is_mkl_op) {
|
||||
VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
|
||||
<< " and " << dst->name() << " for inserting conversion nodes";
|
||||
candidate_edges.push_back(const_cast<Edge*>(e));
|
||||
|
@ -149,7 +149,7 @@ TEST_F(MklToTfConversionPass, Positive) {
|
||||
" input: ['C', 'D']}");
|
||||
EXPECT_EQ(DoRunMklToTfConversionPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
|
||||
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
|
||||
"Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
|
||||
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
|
||||
} else {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
@ -172,7 +172,7 @@ TEST_F(MklToTfConversionPass, Positive) {
|
||||
" input: ['C', 'D']}");
|
||||
EXPECT_EQ(DoRunMklToTfConversionPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
|
||||
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
|
||||
"Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
|
||||
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
|
||||
}
|
||||
}
|
||||
|
@ -265,6 +265,7 @@ class MklConcatOp : public OpKernel {
|
||||
s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1;
|
||||
}
|
||||
mkl_context.MklCreateInputLayouts(context, input_shapes);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
|
||||
CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N,
|
||||
&mkl_context.lt_inputs[0]),
|
||||
@ -316,12 +317,14 @@ class MklConcatOp : public OpKernel {
|
||||
|
||||
mkl_context.mkl_tmp_tensors.resize(N);
|
||||
mkl_context.MklPrepareConcatInputs(context, input_tensors);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
|
||||
// Execute primitive.
|
||||
CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res),
|
||||
E_SUCCESS);
|
||||
|
||||
mkl_context.MklCleanup();
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -38,9 +38,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/use_cudnn.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -37,9 +37,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/use_cudnn.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -266,8 +266,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
|
||||
int input_offsets[2];
|
||||
size_t conv_strides[2];
|
||||
MklShape input_shape, grad_filter_shape, out_backprop_shape;
|
||||
dnnPrimitive_t prim_conv_bwdfilter, convert_bwdfilter;
|
||||
dnnLayout_t lt_input, lt_grad_filter, lt_out_backprop;
|
||||
dnnPrimitive_t prim_conv_bwdfilter = nullptr;
|
||||
dnnPrimitive_t convert_bwdfilter = nullptr;
|
||||
dnnLayout_t lt_input = nullptr;
|
||||
dnnLayout_t lt_grad_filter = nullptr;
|
||||
dnnLayout_t lt_out_backprop = nullptr;
|
||||
void* conv_res[dnnResourceNumber];
|
||||
|
||||
void MklCleanup() {
|
||||
|
@ -23,8 +23,6 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -42,6 +40,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
#include "tensorflow/core/util/use_cudnn.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -36,9 +36,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -98,19 +98,18 @@ class MklConv2DOp : public OpKernel {
|
||||
filter.shape().DebugString()));
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
|
||||
OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("filter too large"));
|
||||
}
|
||||
|
||||
const int64 input_depth =
|
||||
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C')
|
||||
: GetTensorDim(input, data_format_, 'C');
|
||||
OP_REQUIRES(context, input_depth == filter.dim_size(2),
|
||||
errors::InvalidArgument(
|
||||
"input and filter must have the same depth: ", input_depth,
|
||||
" vs ", filter.dim_size(2)));
|
||||
OP_REQUIRES(
|
||||
context, input_depth == filter.dim_size(2),
|
||||
errors::InvalidArgument("input and filter must have the same depth: ",
|
||||
input_depth, " vs ", filter.dim_size(2)));
|
||||
// The last dimension for filter is out_depth.
|
||||
const int out_depth = static_cast<int>(filter.dim_size(3));
|
||||
|
||||
@ -119,9 +118,8 @@ class MklConv2DOp : public OpKernel {
|
||||
const int64 input_rows_raw =
|
||||
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H')
|
||||
: GetTensorDim(input, data_format_, 'H');
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
|
||||
OP_REQUIRES(context, FastBoundsCheck(input_rows_raw,
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("Input rows too large"));
|
||||
const int input_rows = static_cast<int>(input_rows_raw);
|
||||
const int filter_rows = static_cast<int>(filter.dim_size(0));
|
||||
@ -131,9 +129,8 @@ class MklConv2DOp : public OpKernel {
|
||||
const int64 input_cols_raw =
|
||||
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W')
|
||||
: GetTensorDim(input, data_format_, 'W');
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
|
||||
OP_REQUIRES(context, FastBoundsCheck(input_cols_raw,
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("Input cols too large"));
|
||||
const int input_cols = static_cast<int>(input_cols_raw);
|
||||
const int filter_cols = static_cast<int>(filter.dim_size(1));
|
||||
@ -142,9 +139,8 @@ class MklConv2DOp : public OpKernel {
|
||||
const int64 input_batch_raw =
|
||||
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N')
|
||||
: GetTensorDim(input, data_format_, 'N');
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()),
|
||||
OP_REQUIRES(context, FastBoundsCheck(input_batch_raw,
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("batch is too large"));
|
||||
const int batch = static_cast<int>(input_batch_raw);
|
||||
|
||||
|
@ -104,6 +104,15 @@ class MklLRNOp : public OpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(inteltf) MKL will support depth radius not equal to 2 in the future
|
||||
if (depth_radius_ != 2) {
|
||||
Tensor converted_tensor =
|
||||
ConvertMklToTF<T>(context, input, mkl_context.input_shape);
|
||||
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
|
||||
beta_, converted_tensor);
|
||||
return;
|
||||
}
|
||||
|
||||
if (input_in_mkl_format) {
|
||||
// MKL supports normalization over channel dimension only
|
||||
if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
|
||||
@ -112,8 +121,10 @@ class MklLRNOp : public OpKernel {
|
||||
static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout());
|
||||
workspace_enabled_ = true;
|
||||
} else {
|
||||
Tensor converted_tensor =
|
||||
ConvertMklToTF<T>(context, input, mkl_context.input_shape);
|
||||
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
|
||||
beta_, input);
|
||||
beta_, converted_tensor);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -267,7 +278,7 @@ class MklLRNOp : public OpKernel {
|
||||
}
|
||||
|
||||
// Fallback implementation - Taken from lrn_op.cc
|
||||
// TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
|
||||
// TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
|
||||
// copy.
|
||||
void MklDefaultToEigen(OpKernelContext* context, int depth_radius_,
|
||||
float bias_, float alpha_, float beta_,
|
||||
@ -378,6 +389,12 @@ class MklLRNGradOp : public OpKernel {
|
||||
mkl_context.MklDefaultToEigen(context);
|
||||
return;
|
||||
}
|
||||
|
||||
if (depth_radius_ != 2) {
|
||||
mkl_context.MklDefaultToEigen(context);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ingrad_in_mkl_format || inimage_in_mkl_format) {
|
||||
const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
|
||||
? &mkl_context.ingrad_shape
|
||||
@ -489,14 +506,11 @@ class MklLRNGradOp : public OpKernel {
|
||||
MklShape ingrad_shape, inimage_shape, outimage_shape;
|
||||
dnnPrimitive_t lrn_bwd = nullptr;
|
||||
dnnPrimitive_t convert_input = nullptr;
|
||||
/* dnnPrimitive_t convert_output; */
|
||||
dnnLayout_t lt_input = nullptr;
|
||||
dnnLayout_t lt_output = nullptr;
|
||||
dnnLayout_t lt_bdw_input = nullptr;
|
||||
dnnLayout_t lt_workspace = nullptr;
|
||||
dnnLayout_t lt_internal_input = nullptr;
|
||||
/* dnnLayout_t lt_internal_workspace;
|
||||
dnnLayout_t lt_internal_output; */
|
||||
void* res_lrn_bwd[dnnResourceNumber];
|
||||
|
||||
// prepare mkl input
|
||||
@ -619,14 +633,36 @@ class MklLRNGradOp : public OpKernel {
|
||||
// copy.
|
||||
void MklDefaultToEigen(OpKernelContext* context) {
|
||||
// CHECK(false);
|
||||
Tensor in_grads = MklGetInput(context, 0);
|
||||
Tensor in_image = MklGetInput(context, 1);
|
||||
Tensor out_image = MklGetInput(context, 2);
|
||||
|
||||
Tensor in_grads;
|
||||
Tensor in_image;
|
||||
Tensor out_image;
|
||||
|
||||
GetMklShape(context, 0, &ingrad_shape);
|
||||
GetMklShape(context, 1, &inimage_shape);
|
||||
GetMklShape(context, 2, &outimage_shape);
|
||||
|
||||
if (ingrad_shape.IsMklTensor()) {
|
||||
in_grads =
|
||||
ConvertMklToTF<T>(context, MklGetInput(context, 0), ingrad_shape);
|
||||
} else {
|
||||
in_grads = MklGetInput(context, 0);
|
||||
}
|
||||
|
||||
if (inimage_shape.IsMklTensor()) {
|
||||
in_image =
|
||||
ConvertMklToTF<T>(context, MklGetInput(context, 1), inimage_shape);
|
||||
} else {
|
||||
in_image = MklGetInput(context, 1);
|
||||
}
|
||||
|
||||
if (outimage_shape.IsMklTensor()) {
|
||||
out_image =
|
||||
ConvertMklToTF<T>(context, MklGetInput(context, 2), outimage_shape);
|
||||
} else {
|
||||
out_image = MklGetInput(context, 2);
|
||||
}
|
||||
|
||||
const int64 batch = static_cast<int64>(in_grads.dim_size(0));
|
||||
const int64 rows = static_cast<int64>(in_grads.dim_size(1));
|
||||
const int64 cols = static_cast<int64>(in_grads.dim_size(2));
|
||||
|
@ -202,12 +202,10 @@ class MklMatMulOp : public OpKernel {
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("MKL"), \
|
||||
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>)
|
||||
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
|
||||
|
||||
// TODO:Consider template specialization when adding/removing additional types
|
||||
// TODO(inteltf) Consider template specialization when adding/removing
|
||||
// additional types
|
||||
TF_CALL_float(REGISTER_CPU);
|
||||
TF_CALL_double(REGISTER_CPU);
|
||||
TF_CALL_complex64(REGISTER_CPU);
|
||||
|
@ -276,11 +276,6 @@ class MklMaxPoolingGradOp : public OpKernel {
|
||||
mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
|
||||
static_cast<const void*>(output_tensor->flat<T>().data()));
|
||||
|
||||
int64 output_size = output_tensor->NumElements();
|
||||
for (int64 i = 0; i < output_size; ++i) {
|
||||
(static_cast<float*>(mkl_context.pooling_res[dnnResourceDiffSrc]))[i] = 0;
|
||||
}
|
||||
|
||||
CHECK_EQ(
|
||||
dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
|
||||
E_SUCCESS);
|
||||
|
@ -16,17 +16,17 @@ limitations under the License.
|
||||
// See docs in ../ops/nn_ops.cc.
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/platform/default/logging.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -194,45 +194,29 @@ class MklReluGradOp : public OpKernel {
|
||||
|
||||
void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
|
||||
void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
|
||||
dnnPrimitive_t cv_input_to_grad = NULL;
|
||||
Tensor mkl_tmp_buf_tensor;
|
||||
void* mkl_buffer_convert = nullptr;
|
||||
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
|
||||
&mkl_lt_internal_grad, prim_relu_bwd, dnnResourceDiffDst),
|
||||
// if input and grad are not in the same layout, do a conversion between
|
||||
// them.
|
||||
if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
|
||||
AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad,
|
||||
&mkl_buffer_convert);
|
||||
CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
|
||||
E_SUCCESS);
|
||||
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
|
||||
prim_relu_bwd, dnnResourceSrc),
|
||||
CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i,
|
||||
mkl_buffer_convert),
|
||||
E_SUCCESS);
|
||||
|
||||
if (!dnnLayoutCompare_F32(mkl_lt_internal_grad, lt_grad)) {
|
||||
AllocTmpBuffer(context, mkl_tmp_grad_buf_tensor, mkl_lt_internal_grad,
|
||||
&relu_res[dnnResourceDiffDst]);
|
||||
CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_grad, lt_grad,
|
||||
mkl_lt_internal_grad),
|
||||
E_SUCCESS);
|
||||
CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_grad, user_g,
|
||||
relu_res[dnnResourceDiffDst]),
|
||||
E_SUCCESS);
|
||||
dnnDelete_F32(cv_user_to_reluB_grad);
|
||||
} else {
|
||||
relu_res[dnnResourceDiffDst] = user_g;
|
||||
}
|
||||
|
||||
if (!dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input)) {
|
||||
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
|
||||
&relu_res[dnnResourceSrc]);
|
||||
CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_input, lt_input,
|
||||
mkl_lt_internal_input),
|
||||
E_SUCCESS);
|
||||
CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_input, user_i,
|
||||
relu_res[dnnResourceSrc]),
|
||||
E_SUCCESS);
|
||||
dnnDelete_F32(cv_user_to_reluB_input);
|
||||
relu_res[dnnResourceSrc] = mkl_buffer_convert;
|
||||
dnnDelete_F32(cv_input_to_grad);
|
||||
} else {
|
||||
relu_res[dnnResourceSrc] = user_i;
|
||||
}
|
||||
|
||||
dnnLayoutDelete_F32(mkl_lt_internal_input);
|
||||
dnnLayoutDelete_F32(mkl_lt_internal_grad);
|
||||
relu_res[dnnResourceDiffDst] = user_g;
|
||||
|
||||
}
|
||||
|
||||
void MklCreateInputLayouts(OpKernelContext* context) {
|
||||
@ -331,7 +315,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
|
||||
mkl_context.MklCreateInputLayouts(context);
|
||||
float negative_slope = 0.0;
|
||||
CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
|
||||
mkl_context.lt_grad, mkl_context.lt_input,
|
||||
mkl_context.lt_grad, mkl_context.lt_grad,
|
||||
negative_slope),
|
||||
E_SUCCESS);
|
||||
Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
|
||||
|
@ -106,7 +106,7 @@ class MklToTfOp : public OpKernel {
|
||||
///////////////////////////////////////////////////////////
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklToTf") \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklToTf") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
|
@ -28,6 +28,12 @@ set -e
|
||||
CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' )
|
||||
shift 1
|
||||
|
||||
# Enable support for MKL, for Linux only.
|
||||
if [[ $(uname) == "Linux" ]]; then
|
||||
export TF_NEED_MKL=1
|
||||
export TF_DOWNLOAD_MKL=1
|
||||
fi
|
||||
|
||||
# Enable support for Google Cloud Platform (GCP)
|
||||
export TF_NEED_GCP=1
|
||||
# Enable support for HDFS
|
||||
|
@ -46,6 +46,7 @@ apt-get install -y --no-install-recommends \
|
||||
git \
|
||||
libcurl4-openssl-dev \
|
||||
libtool \
|
||||
mlocate \
|
||||
openjdk-8-jdk \
|
||||
openjdk-8-jre-headless \
|
||||
pkg-config \
|
||||
@ -63,6 +64,9 @@ apt-get install -y --no-install-recommends \
|
||||
zip \
|
||||
zlib1g-dev
|
||||
|
||||
# populate the database
|
||||
updatedb
|
||||
|
||||
if [[ "$1" != "--without_cmake" ]]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
cmake
|
||||
|
@ -28,6 +28,7 @@ export TF_NEED_GCP=0
|
||||
export TF_NEED_HDFS=0
|
||||
export TF_NEED_CUDA=0
|
||||
export TF_NEED_OPENCL=0
|
||||
export TF_NEED_MKL=0
|
||||
export COMPUTECPP_PATH="/usr/local"
|
||||
|
||||
export PATH="/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"
|
||||
|
@ -29,6 +29,7 @@ export PYTHON_BIN_PATH="/usr/bin/python"
|
||||
export TF_NEED_GCP=0
|
||||
export TF_NEED_HDFS=0
|
||||
export TF_NEED_OPENCL=0
|
||||
export TF_NEED_MKL=0
|
||||
export COMPUTECPP_PATH="/usr/local"
|
||||
|
||||
export PATH="/usr/local/cuda/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"
|
||||
|
2
third_party/grpc.BUILD
vendored
2
third_party/grpc.BUILD
vendored
@ -178,6 +178,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
],
|
||||
linkopts = ["-lpthread"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -1787,6 +1788,7 @@ cc_library(
|
||||
":grpc_unsecure",
|
||||
"//external:protobuf_clib",
|
||||
],
|
||||
linkopts = ["-lpthread"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
3
third_party/jemalloc.BUILD
vendored
3
third_party/jemalloc.BUILD
vendored
@ -94,6 +94,9 @@ cc_library(
|
||||
"@%ws%//tensorflow:linux_ppc64le": [
|
||||
"-lpthread",
|
||||
],
|
||||
"@%ws%//tensorflow:linux_x86_64": [
|
||||
"-lpthread",
|
||||
],
|
||||
"//conditions:default": [
|
||||
],
|
||||
}),
|
||||
|
1
third_party/llvm/llvm.BUILD
vendored
1
third_party/llvm/llvm.BUILD
vendored
@ -1695,6 +1695,7 @@ cc_library(
|
||||
":demangle",
|
||||
"@zlib_archive//:zlib",
|
||||
],
|
||||
linkopts = ["-lpthread", "-ldl"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
1
third_party/mkl/BUILD
vendored
1
third_party/mkl/BUILD
vendored
@ -16,6 +16,7 @@ load(
|
||||
cc_library(
|
||||
name = "intel_binary_blob",
|
||||
srcs = if_mkl([
|
||||
"libdl.so.2",
|
||||
"libmklml_intel.so",
|
||||
"libiomp5.so",
|
||||
]),
|
||||
|
Loading…
Reference in New Issue
Block a user