[Intel MKL] Adding MklTanh op

This commit is contained in:
Niranjan Hasabnis 2020-05-12 15:26:43 -07:00
parent 13ce8851cb
commit 489926629d
5 changed files with 325 additions and 28 deletions

View File

@ -675,18 +675,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back(
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
// Disable these two MKL operators for now due to some test failures caused
// by these two ops
/*
rinfo_.push_back({csinfo_.tanh,
mkl_op_registry::GetMklOpName(csinfo_.tanh),
rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.tanh_grad,
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
*/
rinfo_.push_back(
{csinfo_.tanh_grad, mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back(
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});

View File

@ -2949,6 +2949,66 @@ TEST_F(MklLayoutPassTest, NodeRewrite_LeakyReluLeakyReluGrad_Positive) {
"DMT/_1->C:2");
}
// clang-format off
// Checks that a standalone Tanh node is rewritten by the layout pass.
// Graph: A(INPUT) feeds B(Tanh); C(Zeta) consumes A and B.
// Expected: B becomes _MklTanh and receives a Const node DMT/_0 on input
// slot 1, with a control edge from A to DMT/_0; Zeta is left as is.
// NOTE(review): DMT/_0 is presumably the dummy Mkl metadata tensor inserted
// by the pass, and REGISTER_TEST_ALL_TYPES (defined outside this view)
// presumably instantiates the test once per supported T -- confirm.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Tanh'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklTanh);C(Zeta);DMT/_0(Const)|A->B;A->C;" \
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Tanh_Positive);
#undef REGISTER_TEST
// Checks that a standalone TanhGrad node is rewritten by the layout pass.
// Graph: A(INPUT) and B(INPUT) feed C(TanhGrad); D(Zeta) consumes A and C.
// Expected: C becomes _MklTanhGrad and receives two Const nodes, DMT/_0 on
// input slot 2 and DMT/_1 on input slot 3, both with a control edge from A;
// Zeta is left as is.
// NOTE(review): the DMT nodes are presumably dummy Mkl metadata tensors, one
// per data input -- confirm against the pass.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'TanhGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklTanhGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
#undef REGISTER_TEST
// Checks the rewrite when a Tanh node feeds a TanhGrad node.
// Graph: A(INPUT) feeds B(Tanh); C(TanhGrad) consumes B and A; D(Zeta)
// consumes A and C.
// Expected: B becomes _MklTanh and C becomes _MklTanhGrad. B's second output
// (B:1) is wired to C's input slot 2, so only C's slot 3 gets a fresh Const
// (DMT/_1, control-dependent on B); B itself gets DMT/_0 on slot 1
// (control-dependent on A).
// NOTE(review): B:1 is presumably the Mkl metadata output of _MklTanh --
// confirm against the pass.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Tanh'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'TanhGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['B', 'A'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklTanh);C(_MklTanhGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->B;A->C:1;A->D;A:control->DMT/_0:control;" \
"B->C;B:1->C:2;B:control->DMT/_1:control;C->D:1;DMT/_0->B:1;" \
"DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
#undef REGISTER_TEST
// clang-format on
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"

View File

@ -8084,6 +8084,27 @@ tf_cc_test_mkl(
],
)
# Test target for mkl_relu_op_test.cc, the activation-op benchmark source
# (Tanh/Relu/Elu/Relu6/LeakyRelu and their gradients).
tf_cc_test_mkl(
name = "mkl_relu_op_test",
size = "small",
srcs = ["mkl_relu_op_test.cc"],
linkstatic = 1, # Fixes dyld error on MacOS.
deps = [
":ops_testutil",
":ops_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_mkl_kernel_library(
name = "mkl_tfconv_op",
prefix = "mkl_tfconv",

View File

@ -19,7 +19,6 @@ limitations under the License.
#include <unordered_map>
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@ -27,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
using mkldnn::algorithm;
using mkldnn::eltwise_forward;
@ -266,15 +266,19 @@ class MklEltwiseBwdParams {
algorithm alg_kind;
float alpha;
float beta;
// Whether the input that grad op gets from forward op is SRC
// of forward op or DST of forward op.
int forward_input_type;
MklEltwiseBwdParams(const memory::dims& src_dims,
const memory::desc& common_md, algorithm alg_kind,
float alpha, float beta)
float alpha, float beta, int forward_input_type)
: src_dims(src_dims),
common_md(common_md),
alg_kind(alg_kind),
alpha(alpha),
beta(beta) {}
beta(beta),
forward_input_type(forward_input_type) {}
};
template <typename T>
@ -430,7 +434,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// Create eltwise primitive and add it to net.
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(*context_.bwd_pd));
context_.bwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{{bwdParams.forward_input_type, *context_.src_mem},
{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
{ MKLDNN_ARG_DIFF_SRC,
*context_.diff_src_mem }});
@ -631,14 +635,30 @@ class MklReluGradOpBase : public OpKernel {
virtual void Compute_Scalar(OpKernelContext* context) = 0;
// All activation functions that are part of NN ops, such as Relu, Elu,
// LeakyRelu, Relu6, etc have dy at index 0 and x at index 1.
//
// if forward op is defined as: y = f(x),
// {Relu,Elu,Relu6,LeakyRelu}Grad is: z = f_grad(dy,x)
// TanhGrad is: z = tanh_grad(y,dy)
//
// Src below refers to a tensor that gradient op receives from forward
// operator. For Relu-family ops, it is 'x'; while for TanhGrad, it is 'y'.
virtual int GetDiffDstIndex() const { return 0; }
virtual int GetSrcIndex() const { return 1; }
virtual int GetDiffSrcIndex() const { return 0; }
// What is the type of input tensor that grad op receives from forward op --
// is it 'x' (SRC) or 'y' (DST). For Relu-family, it is 'x', so fwd op SRC.
virtual int GetTypeOfInputTensorFromFwdOp() const { return MKLDNN_ARG_SRC; }
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> src(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
const size_t diff_dst_index = 0; // index of diff_dst input tensor
const size_t src_index = 1; // index of src input tensor
const size_t diff_src_index = 0; // index of diff_src output tensor
size_t diff_dst_index = GetDiffDstIndex();
size_t src_index = GetSrcIndex();
const size_t diff_src_index = GetDiffSrcIndex();
const Tensor& src_tensor = MklGetInput(context, src_index);
const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
@ -722,7 +742,7 @@ class MklReluGradOpBase : public OpKernel {
}
MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha_,
beta_);
beta_, GetTypeOfInputTensorFromFwdOp());
MklEltwiseBwdPrimitive<T>* eltwise_bwd =
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
@ -976,18 +996,28 @@ class MklTanhOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh> {
template <typename Device, typename T>
class MklTanhGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh> {
: public MklReluGradOpBase<Device, T,
ALGORITHM::eltwise_tanh_use_dst_for_bwd> {
public:
~MklTanhGradOp() {}
explicit MklTanhGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f,
0.0f) {}
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh_use_dst_for_bwd>(
context, 0.0f, 0.0f) {}
virtual int GetDiffDstIndex() const { return 1; }
virtual int GetSrcIndex() const { return 0; }
virtual int GetDiffSrcIndex() const { return 0; }
// TanhGrad gets 'y' from Tanh, where 'y' is output of Tanh(x).
virtual int GetTypeOfInputTensorFromFwdOp() const { return MKLDNN_ARG_DST; }
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
const size_t src_index = 1; // index of src input tensor
const size_t diff_src_index = 0; // index of diff_src output tensor
// NOTE: Order of y and dy for Tanh is reverse of that for Relu/Elu/other
// element-wise ops. Tanh is math op in Tensorflow; others are NN ops.
const size_t diff_dst_index = GetDiffDstIndex();
const size_t src_index = GetSrcIndex();
const size_t diff_src_index = GetDiffSrcIndex();
const Tensor& src_tensor = MklGetInput(context, src_index);
const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
Tensor* diff_src_tensor = nullptr;
@ -1003,10 +1033,9 @@ class MklTanhGradOp
void* user_i =
static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
// gradient of tanh(x) = 1 - tanh(x)^2
T feature = (static_cast<T*>(user_i))[0];
T e1 = std::exp(feature);
T e2 = std::exp(-feature);
T tanh = (e1 - e2) / (e1 + e2);
// Input to TanhGrad is output of Tanh. So we do not need to compute
// Tanh again.
T tanh = (static_cast<T*>(user_i))[0];
void* user_g =
static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
(static_cast<T*>(out_o))[0] =

View File

@ -0,0 +1,193 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Compile the benchmarks below only in MKL-enabled builds. INTEL_MKL must not
// be undefined ahead of this guard: doing so makes the condition always false
// and silently drops every benchmark in this file.
#ifdef INTEL_MKL
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "mkldnn.hpp"
#include "tensorflow/core/util/mkl_util.h"
// Compare performance of default Tensorflow activation kernels (Eigen) with
// MKL kernels on CPU.
// Before running these benchmarks configure OpenMP environment variables:
// export KMP_BLOCKTIME=0
// export OMP_NUM_THREADS=${num_threads}
namespace tensorflow {
// Returns a DT_UINT8 tensor holding the serialized form of an MklDnnShape
// that is flagged as "not an MKL tensor".
static Tensor NonMklTensor() {
  MklDnnShape plain_shape;
  plain_shape.SetMklTensor(false);

  // The tensor is sized to exactly fit the serialized shape.
  const int64 num_bytes =
      static_cast<int64>(plain_shape.GetSerializeBufferSize());
  Tensor meta(DT_UINT8, {num_bytes});

  uint8* const buffer = meta.flat<uint8>().data();
  plain_shape.SerializeMklDnnShape(buffer, num_bytes * sizeof(uint8));
  return meta;
}
// Returns a DT_FLOAT tensor of the requested shape filled with random values.
static Tensor GetRandomTensor(const TensorShape& shape) {
  Tensor tensor(DT_FLOAT, shape);
  // setRandom() fills the flat view in place; assigning its result back onto
  // the same view would only copy the buffer over itself.
  tensor.flat<float>().setRandom();
  return tensor;
}
// Defines `static Graph* NODE_NAME(const TensorShape&)`: the generated
// function builds a graph in which a random float constant of the given shape
// is the single input of one OP_NAME node, and returns that graph.
#define CREATE_DEFAULT_FWD_OP(NODE_NAME, OP_NAME)                  \
  static Graph* NODE_NAME(const TensorShape& shape) {              \
    Graph* const g = new Graph(OpRegistry::Global());              \
    Node* const data =                                             \
        test::graph::Constant(g, GetRandomTensor(shape), "input"); \
    NodeBuilder builder(g->NewName(#NODE_NAME), #OP_NAME);         \
    builder.Input(data).Attr("T", DT_FLOAT);                       \
    Node* result = nullptr;                                        \
    TF_CHECK_OK(builder.Finalize(g, &result));                     \
    return g;                                                      \
  }
CREATE_DEFAULT_FWD_OP(Default_Tanh, Tanh)
CREATE_DEFAULT_FWD_OP(Default_Elu, Elu)
CREATE_DEFAULT_FWD_OP(Default_Relu, Relu)
CREATE_DEFAULT_FWD_OP(Default_Relu6, Relu6)
CREATE_DEFAULT_FWD_OP(Default_LeakyRelu, LeakyRelu)
// Defines `static Graph* NODE_NAME(const TensorShape&)`: the generated
// function builds a graph with two random float constants ("input" and
// "grad") of the given shape and one OP_NAME node that takes "grad" as its
// first input and "input" as its second, then returns that graph.
#define CREATE_DEFAULT_BWD_OP(NODE_NAME, OP_NAME)                  \
  static Graph* NODE_NAME(const TensorShape& shape) {              \
    Graph* const g = new Graph(OpRegistry::Global());              \
    Node* const data =                                             \
        test::graph::Constant(g, GetRandomTensor(shape), "input"); \
    Node* const upstream =                                         \
        test::graph::Constant(g, GetRandomTensor(shape), "grad");  \
    NodeBuilder builder(g->NewName(#NODE_NAME), #OP_NAME);         \
    builder.Input(upstream).Input(data).Attr("T", DT_FLOAT);       \
    Node* result = nullptr;                                        \
    TF_CHECK_OK(builder.Finalize(g, &result));                     \
    return g;                                                      \
  }
CREATE_DEFAULT_BWD_OP(Default_TanhGrad, TanhGrad)
CREATE_DEFAULT_BWD_OP(Default_EluGrad, EluGrad)
CREATE_DEFAULT_BWD_OP(Default_ReluGrad, ReluGrad)
CREATE_DEFAULT_BWD_OP(Default_Relu6Grad, Relu6Grad)
CREATE_DEFAULT_BWD_OP(Default_LeakyReluGrad, LeakyReluGrad)
// Defines `static Graph* NODE_NAME(const TensorShape&)`: the generated
// function wires a random float constant and a NonMklTensor() constant into
// one OP_NAME node carrying the "MklLayoutDependentOp" kernel label, and
// returns the graph.
#define CREATE_MKL_FWD_OP(NODE_NAME, OP_NAME)                               \
  static Graph* NODE_NAME(const TensorShape& shape) {                       \
    Graph* const g = new Graph(OpRegistry::Global());                       \
    Node* const data =                                                      \
        test::graph::Constant(g, GetRandomTensor(shape), "input");          \
    Node* const meta = test::graph::Constant(g, NonMklTensor(), "not_mkl"); \
    NodeBuilder builder(g->NewName(#NODE_NAME), #OP_NAME);                  \
    builder.Input(data)                                                     \
        .Input(meta)                                                        \
        .Attr("T", DT_FLOAT)                                                \
        .Attr("_kernel", "MklLayoutDependentOp");                           \
    Node* result = nullptr;                                                 \
    TF_CHECK_OK(builder.Finalize(g, &result));                              \
    return g;                                                               \
  }
CREATE_MKL_FWD_OP(Mkl_Tanh, _MklTanh)
CREATE_MKL_FWD_OP(Mkl_Elu, _MklElu)
CREATE_MKL_FWD_OP(Mkl_Relu, _MklRelu)
CREATE_MKL_FWD_OP(Mkl_Relu6, _MklRelu6)
CREATE_MKL_FWD_OP(Mkl_LeakyRelu, _MklLeakyRelu)
// Defines `static Graph* NODE_NAME(const TensorShape&)`: the generated
// function builds two random float constants ("input" and "grad") plus one
// NonMklTensor() constant, and feeds one OP_NAME node with grad, input, and
// the NonMklTensor() constant twice. The node carries the
// "MklLayoutDependentOp" kernel label. Returns the graph.
#define CREATE_MKL_BWD_OP(NODE_NAME, OP_NAME)                               \
  static Graph* NODE_NAME(const TensorShape& shape) {                       \
    Graph* const g = new Graph(OpRegistry::Global());                       \
    Node* const data =                                                      \
        test::graph::Constant(g, GetRandomTensor(shape), "input");          \
    Node* const upstream =                                                  \
        test::graph::Constant(g, GetRandomTensor(shape), "grad");           \
    Node* const meta = test::graph::Constant(g, NonMklTensor(), "not_mkl"); \
    NodeBuilder builder(g->NewName(#NODE_NAME), #OP_NAME);                  \
    builder.Input(upstream)                                                 \
        .Input(data)                                                        \
        .Input(meta)                                                        \
        .Input(meta)                                                        \
        .Attr("T", DT_FLOAT)                                                \
        .Attr("_kernel", "MklLayoutDependentOp");                           \
    Node* result = nullptr;                                                 \
    TF_CHECK_OK(builder.Finalize(g, &result));                              \
    return g;                                                               \
  }
CREATE_MKL_BWD_OP(Mkl_TanhGrad, _MklTanhGrad)
CREATE_MKL_BWD_OP(Mkl_EluGrad, _MklEluGrad)
CREATE_MKL_BWD_OP(Mkl_ReluGrad, _MklReluGrad)
CREATE_MKL_BWD_OP(Mkl_Relu6Grad, _MklRelu6Grad)
CREATE_MKL_BWD_OP(Mkl_LeakyReluGrad, _MklLeakyReluGrad)
// BM_Activation defines and registers one benchmark. The benchmark builds its
// graph with the `kind##_##op` factory (Default_* or Mkl_*) on an
// {A, B, C, D} shaped tensor and reports one processed item per tensor
// element per iteration.
#define BM_Activation(op, kind, A, B, C, D, type)                            \
  static void BM_##op##_##kind##_##type##_##A##_##B##_##C##_##D(int iters) { \
    /* Widen before multiplying so the element count cannot overflow int. */ \
    const int64 num_computed_elements =                                      \
        static_cast<int64>(A) * (B) * (C) * (D);                             \
    const int64 flops_per_iter = num_computed_elements;                      \
    testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter);     \
                                                                             \
    test::Benchmark(#type, kind##_##op({A, B, C, D})).Run(iters);            \
  }                                                                          \
  BENCHMARK(BM_##op##_##kind##_##type##_##A##_##B##_##C##_##D)
// BM registers the Default and the Mkl variant of one op for one shape.
#define BM(op, A, B, C, D, type) \
BM_Activation(op, Default, A, B, C, D, type); \
BM_Activation(op, Mkl, A, B, C, D, type);
// TEST_ALL_SIZES registers both variants of one op for a fixed set of shapes
// (power-of-two sizes and their off-by-one neighbours).
#define TEST_ALL_SIZES(OP) \
BM(OP, 2, 4, 8, 16, cpu); \
BM(OP, 3, 5, 9, 17, cpu); \
BM(OP, 32, 64, 128, 256, cpu); \
BM(OP, 33, 65, 129, 257, cpu);
TEST_ALL_SIZES(Tanh)
TEST_ALL_SIZES(TanhGrad)
TEST_ALL_SIZES(Relu)
TEST_ALL_SIZES(ReluGrad)
TEST_ALL_SIZES(Elu)
TEST_ALL_SIZES(EluGrad)
TEST_ALL_SIZES(Relu6)
TEST_ALL_SIZES(Relu6Grad)
TEST_ALL_SIZES(LeakyRelu)
TEST_ALL_SIZES(LeakyReluGrad)
} // namespace tensorflow
#endif // INTEL_MKL