[Intel MKL] Adding MklTanh op
parent 13ce8851cb
commit 489926629d
@@ -675,18 +675,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back(
         {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
-    // Disable these two MKL operators for now due to some test failures caused
-    // by these two ops
-    /*
-    rinfo_.push_back({csinfo_.tanh,
-                      mkl_op_registry::GetMklOpName(csinfo_.tanh),
-                      CopyAttrsAll, AlwaysRewrite,
-                      kRewriteForLayoutPropagation});
-    rinfo_.push_back({csinfo_.tanh_grad,
-                      mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
-                      CopyAttrsAll, AlwaysRewrite,
-                      kRewriteForLayoutPropagation});
-    */
+    rinfo_.push_back(
+        {csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
+    rinfo_.push_back(
+        {csinfo_.tanh_grad, mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
+         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -2949,6 +2949,66 @@ TEST_F(MklLayoutPassTest, NodeRewrite_LeakyReluLeakyReluGrad_Positive) {
             "DMT/_1->C:2");
 }
 
+// clang-format off
+#define REGISTER_TEST(NAME, T, INPUT) \
+  TEST_F(MklLayoutPassTest, NAME##_##T) { \
+    DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
+    InitGraph( \
+        "node { name: 'A' op: '" #INPUT "'}" \
+        "node { name: 'B' op: 'Tanh'" \
+        " attr { key: 'T' value { type: " #T " } }" \
+        " input: ['A'] }" \
+        "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
+        " input: ['A', 'B'] }"); \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+              "A(" #INPUT ");B(_MklTanh);C(Zeta);DMT/_0(Const)|A->B;A->C;" \
+              "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_Tanh_Positive);
+#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT) \
+  TEST_F(MklLayoutPassTest, NAME##_##T) { \
+    DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
+    InitGraph( \
+        "node { name: 'A' op: '" #INPUT "'}" \
+        "node { name: 'B' op: '" #INPUT "'}" \
+        "node { name: 'C' op: 'TanhGrad'" \
+        " attr { key: 'T' value { type: " #T " } }" \
+        " input: ['A', 'B'] }" \
+        "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
+        " input: ['A', 'C'] }"); \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+              "A(" #INPUT ");B(" #INPUT ");C(_MklTanhGrad);D(Zeta);DMT/_0(Const);" \
+              "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" \
+              "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
+#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT) \
+  TEST_F(MklLayoutPassTest, NAME##_##T) { \
+    DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
+    InitGraph( \
+        "node { name: 'A' op: '" #INPUT "'}" \
+        "node { name: 'B' op: 'Tanh'" \
+        " attr { key: 'T' value { type: " #T " } }" \
+        " input: ['A'] }" \
+        "node { name: 'C' op: 'TanhGrad'" \
+        " attr { key: 'T' value { type: " #T " } }" \
+        " input: ['B', 'A'] }" \
+        "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
+        " input: ['A', 'C'] }"); \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+              "A(" #INPUT ");B(_MklTanh);C(_MklTanhGrad);D(Zeta);DMT/_0(Const);" \
+              "DMT/_1(Const)|A->B;A->C:1;A->D;A:control->DMT/_0:control;" \
+              "B->C;B:1->C:2;B:control->DMT/_1:control;C->D:1;DMT/_0->B:1;" \
+              "DMT/_1->C:3"); \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
+#undef REGISTER_TEST
+// clang-format on
+
 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
@@ -8084,6 +8084,27 @@ tf_cc_test_mkl(
     ],
 )
 
+tf_cc_test_mkl(
+    name = "mkl_relu_op_test",
+    size = "small",
+    srcs = ["mkl_relu_op_test.cc"],
+    linkstatic = 1,  # Fixes dyld error on MacOS.
+    deps = [
+        ":ops_testutil",
+        ":ops_util",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_mkl_kernel_library(
     name = "mkl_tfconv_op",
     prefix = "mkl_tfconv",
@@ -19,7 +19,6 @@ limitations under the License.
 #include <unordered_map>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -27,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::algorithm;
 using mkldnn::eltwise_forward;
@@ -266,15 +266,19 @@ class MklEltwiseBwdParams {
   algorithm alg_kind;
   float alpha;
   float beta;
+  // Whether the input that grad op gets from forward op is SRC
+  // of forward op or DST of forward op.
+  int forward_input_type;
 
   MklEltwiseBwdParams(const memory::dims& src_dims,
                       const memory::desc& common_md, algorithm alg_kind,
-                      float alpha, float beta)
+                      float alpha, float beta, int forward_input_type)
       : src_dims(src_dims),
         common_md(common_md),
         alg_kind(alg_kind),
         alpha(alpha),
-        beta(beta) {}
+        beta(beta),
+        forward_input_type(forward_input_type) {}
 };
 
 template <typename T>
@@ -430,7 +434,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
     // Create eltwise primitive and add it to net.
     context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(*context_.bwd_pd));
     context_.bwd_primitives_args.push_back(
-        {{MKLDNN_ARG_SRC, *context_.src_mem},
+        {{bwdParams.forward_input_type, *context_.src_mem},
          {MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
          { MKLDNN_ARG_DIFF_SRC,
            *context_.diff_src_mem }});
@@ -631,14 +635,30 @@ class MklReluGradOpBase : public OpKernel {
 
   virtual void Compute_Scalar(OpKernelContext* context) = 0;
 
+  // All activation functions that are part of NN ops, such as Relu, Elu,
+  // LeakyRelu, Relu6, etc have dy at index 0 and y at index 1.
+  //
+  // if forward op is defined as: y = f(x),
+  // {Relu,Elu,Relu6,LeakyRelu}Grad is: z = f_grad(dy,x)
+  // TanhGrad is: z = tanh_grad(y,dy)
+  //
+  // Src below refers to a tensor that gradient op receives from forward
+  // operator. From Relu-family ops, it is 'x'; while for TanhGrad, it is 'y'.
+  virtual int GetDiffDstIndex() const { return 0; }
+  virtual int GetSrcIndex() const { return 1; }
+  virtual int GetDiffSrcIndex() const { return 0; }
+  // What is the type of input tensor that grad op receives from forward op --
+  // is it 'x' (SRC) or 'y' (DST). For Relu-family, it is 'x', so fwd op SRC.
+  virtual int GetTypeOfInputTensorFromFwdOp() const { return MKLDNN_ARG_SRC; }
+
   void Compute(OpKernelContext* context) {
     try {
       MklDnnData<T> src(&cpu_engine);
       MklDnnData<T> diff_dst(&cpu_engine);
 
-      const size_t diff_dst_index = 0;  // index of diff_dst input tensor
-      const size_t src_index = 1;       // index of src input tensor
-      const size_t diff_src_index = 0;  // index of diff_src output tensor
+      size_t diff_dst_index = GetDiffDstIndex();
+      size_t src_index = GetSrcIndex();
+      const size_t diff_src_index = GetDiffSrcIndex();
 
       const Tensor& src_tensor = MklGetInput(context, src_index);
       const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
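The comments added above draw the distinction that drives this patch: Relu-family gradient kernels consume the forward op's input 'x' (hence MKLDNN_ARG_SRC), while TanhGrad consumes the forward op's output 'y' (hence MKLDNN_ARG_DST), because tanh'(x) = 1 - tanh(x)^2 = 1 - y^2. A minimal standalone sketch of that dependence, independent of MKL-DNN; the helper names here are illustrative and not part of the patch:

#include <cmath>
#include <cstdio>

// dy: incoming gradient, x: forward input (SRC), y = f(x): forward output (DST).
static float ReluGradRef(float dy, float x) {
  // relu'(x) is 1 for x > 0 and 0 otherwise, so the forward input is needed.
  return x > 0.0f ? dy : 0.0f;
}

static float TanhGradRef(float dy, float y) {
  // tanh'(x) = 1 - tanh(x)^2 = 1 - y^2, so the forward output is enough.
  return dy * (1.0f - y * y);
}

int main() {
  const float x = 0.7f;
  const float y = std::tanh(x);
  std::printf("relu grad: %f\n", ReluGradRef(1.0f, x));  // needs x (SRC)
  std::printf("tanh grad: %f\n", TanhGradRef(1.0f, y));  // needs y (DST)
  return 0;
}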
@@ -722,7 +742,7 @@ class MklReluGradOpBase : public OpKernel {
       }
 
       MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha_,
-                                       beta_);
+                                       beta_, GetTypeOfInputTensorFromFwdOp());
       MklEltwiseBwdPrimitive<T>* eltwise_bwd =
           MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
 
@@ -976,18 +996,28 @@ class MklTanhOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh> {
 
 template <typename Device, typename T>
 class MklTanhGradOp
-    : public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh> {
+    : public MklReluGradOpBase<Device, T,
+                               ALGORITHM::eltwise_tanh_use_dst_for_bwd> {
  public:
   ~MklTanhGradOp() {}
 
   explicit MklTanhGradOp(OpKernelConstruction* context)
-      : MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f,
-                                                              0.0f) {}
+      : MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh_use_dst_for_bwd>(
+            context, 0.0f, 0.0f) {}
+
+  virtual int GetDiffDstIndex() const { return 1; }
+  virtual int GetSrcIndex() const { return 0; }
+  virtual int GetDiffSrcIndex() const { return 0; }
+
+  // TanhGrad gets 'y' from Tanh, where 'y' is output of Tanh(x).
+  virtual int GetTypeOfInputTensorFromFwdOp() const { return MKLDNN_ARG_DST; }
 
   virtual void Compute_Scalar(OpKernelContext* context) {
-    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
-    const size_t src_index = 1;       // index of src input tensor
-    const size_t diff_src_index = 0;  // index of diff_src output tensor
+    // NOTE: Order of y and dy for Tanh is reverse of that for Relu/Elu/other
+    // element-wise ops. Tanh is math op in Tensorflow; others are NN ops.
+    const size_t diff_dst_index = GetDiffDstIndex();
+    const size_t src_index = GetSrcIndex();
+    const size_t diff_src_index = GetDiffSrcIndex();
     const Tensor& src_tensor = MklGetInput(context, src_index);
     const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
     Tensor* diff_src_tensor = nullptr;
@@ -1003,10 +1033,9 @@ class MklTanhGradOp
     void* user_i =
         static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
     // gradient of tanh(x) = 1 - tanh(x)^2
-    T feature = (static_cast<T*>(user_i))[0];
-    T e1 = std::exp(feature);
-    T e2 = std::exp(-feature);
-    T tanh = (e1 - e2) / (e1 + e2);
+    // Input to TanhGrad is output of Tanh. So we do not need to compute
+    // Tanh again.
+    T tanh = (static_cast<T*>(user_i))[0];
     void* user_g =
         static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
     (static_cast<T*>(out_o))[0] =
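TanhGrad's inputs are ordered (y, dy), the reverse of the Relu-family gradients, and y is already tanh(x), which is why the scalar fallback above can drop the exp-based recomputation. A small self-contained check, not taken from the patch, that the two formulations of the derivative agree:

#include <cassert>
#include <cmath>

int main() {
  const float x = 0.3f;          // forward input
  const float y = std::tanh(x);  // forward output; what TanhGrad receives
  const float dy = 2.0f;         // incoming gradient

  // Derivative obtained by rebuilding tanh from x via exponentials.
  const float e1 = std::exp(x);
  const float e2 = std::exp(-x);
  const float t = (e1 - e2) / (e1 + e2);
  const float dx_exp = dy * (1.0f - t * t);

  // Derivative obtained directly from y = tanh(x).
  const float dx_dst = dy * (1.0f - y * y);

  assert(std::fabs(dx_exp - dx_dst) < 1e-6f);
  return 0;
}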
tensorflow/core/kernels/mkl_relu_op_test.cc (new file, 193 lines)
@@ -0,0 +1,193 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"

#include "mkldnn.hpp"
#include "tensorflow/core/util/mkl_util.h"

// Compare performance of default Tensorflow activation kernels (Eigen) with
// MKL kernels on CPU.

// Before running these benchmarks configure OpenMP environment variables:
//   export KMP_BLOCKTIME=0
//   export OMP_NUM_THREADS=${num_threads}

namespace tensorflow {

static Tensor NonMklTensor() {
  MklDnnShape non_mkl_shape;
  non_mkl_shape.SetMklTensor(false);

  auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
  Tensor tensor(DT_UINT8, {size});

  non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
                                     size * sizeof(uint8));
  return tensor;
}

static Tensor GetRandomTensor(const TensorShape& shape) {
  Tensor tensor(DT_FLOAT, TensorShape(shape));
  tensor.flat<float>() = tensor.flat<float>().setRandom();
  return tensor;
}

#define CREATE_DEFAULT_FWD_OP(NODE_NAME, OP_NAME) \
  static Graph* NODE_NAME(const TensorShape& shape) { \
    auto* graph = new Graph(OpRegistry::Global()); \
    Tensor input_t = GetRandomTensor(shape); \
    Node* input = test::graph::Constant(graph, input_t, "input"); \
    Node* op; \
    TF_CHECK_OK(NodeBuilder(graph->NewName(#NODE_NAME), #OP_NAME) \
                    .Input(input) \
                    .Attr("T", DT_FLOAT) \
                    .Finalize(graph, &op)); \
    return graph; \
  }
CREATE_DEFAULT_FWD_OP(Default_Tanh, Tanh)
CREATE_DEFAULT_FWD_OP(Default_Elu, Elu)
CREATE_DEFAULT_FWD_OP(Default_Relu, Relu)
CREATE_DEFAULT_FWD_OP(Default_Relu6, Relu6)
CREATE_DEFAULT_FWD_OP(Default_LeakyRelu, LeakyRelu)

#define CREATE_DEFAULT_BWD_OP(NODE_NAME, OP_NAME) \
  static Graph* NODE_NAME(const TensorShape& shape) { \
    auto* graph = new Graph(OpRegistry::Global()); \
    Tensor input_t = GetRandomTensor(shape); \
    Node* input = test::graph::Constant(graph, input_t, "input"); \
    Tensor grad_t = GetRandomTensor(shape); \
    Node* grad = test::graph::Constant(graph, grad_t, "grad"); \
    Node* op; \
    TF_CHECK_OK(NodeBuilder(graph->NewName(#NODE_NAME), #OP_NAME) \
                    .Input(grad) \
                    .Input(input) \
                    .Attr("T", DT_FLOAT) \
                    .Finalize(graph, &op)); \
    return graph; \
  }
CREATE_DEFAULT_BWD_OP(Default_TanhGrad, TanhGrad)
CREATE_DEFAULT_BWD_OP(Default_EluGrad, EluGrad)
CREATE_DEFAULT_BWD_OP(Default_ReluGrad, ReluGrad)
CREATE_DEFAULT_BWD_OP(Default_Relu6Grad, Relu6Grad)
CREATE_DEFAULT_BWD_OP(Default_LeakyReluGrad, LeakyReluGrad)

#define CREATE_MKL_FWD_OP(NODE_NAME, OP_NAME) \
  static Graph* NODE_NAME(const TensorShape& shape) { \
    auto* graph = new Graph(OpRegistry::Global()); \
    \
    Tensor input_t = GetRandomTensor(shape); \
    Node* input = test::graph::Constant(graph, input_t, "input"); \
    \
    Node* not_mkl_shape = \
        test::graph::Constant(graph, NonMklTensor(), "not_mkl"); \
    \
    Node* op; \
    TF_CHECK_OK(NodeBuilder(graph->NewName(#NODE_NAME), #OP_NAME) \
                    .Input(input) \
                    .Input(not_mkl_shape) \
                    .Attr("T", DT_FLOAT) \
                    .Attr("_kernel", "MklLayoutDependentOp") \
                    .Finalize(graph, &op)); \
    \
    return graph; \
  }

CREATE_MKL_FWD_OP(Mkl_Tanh, _MklTanh)
CREATE_MKL_FWD_OP(Mkl_Elu, _MklElu)
CREATE_MKL_FWD_OP(Mkl_Relu, _MklRelu)
CREATE_MKL_FWD_OP(Mkl_Relu6, _MklRelu6)
CREATE_MKL_FWD_OP(Mkl_LeakyRelu, _MklLeakyRelu)

#define CREATE_MKL_BWD_OP(NODE_NAME, OP_NAME) \
  static Graph* NODE_NAME(const TensorShape& shape) { \
    auto* graph = new Graph(OpRegistry::Global()); \
    \
    Tensor input_t = GetRandomTensor(shape); \
    Node* input = test::graph::Constant(graph, input_t, "input"); \
    Tensor grad_t = GetRandomTensor(shape); \
    Node* grad = test::graph::Constant(graph, grad_t, "grad"); \
    \
    Node* not_mkl_shape = \
        test::graph::Constant(graph, NonMklTensor(), "not_mkl"); \
    \
    Node* op; \
    TF_CHECK_OK(NodeBuilder(graph->NewName(#NODE_NAME), #OP_NAME) \
                    .Input(grad) \
                    .Input(input) \
                    .Input(not_mkl_shape) \
                    .Input(not_mkl_shape) \
                    .Attr("T", DT_FLOAT) \
                    .Attr("_kernel", "MklLayoutDependentOp") \
                    .Finalize(graph, &op)); \
    \
    return graph; \
  }

CREATE_MKL_BWD_OP(Mkl_TanhGrad, _MklTanhGrad)
CREATE_MKL_BWD_OP(Mkl_EluGrad, _MklEluGrad)
CREATE_MKL_BWD_OP(Mkl_ReluGrad, _MklReluGrad)
CREATE_MKL_BWD_OP(Mkl_Relu6Grad, _MklRelu6Grad)
CREATE_MKL_BWD_OP(Mkl_LeakyReluGrad, _MklLeakyReluGrad)

#define BM_Activation(op, kind, A, B, C, D, type) \
  static void BM_##op##_##kind##_##type##_##A##_##B##_##C##_##D(int iters) { \
    int64 num_computed_elements = (A) * (B) * (C) * (D); \
    int64 flops_per_iter = num_computed_elements; \
    testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
    \
    test::Benchmark(#type, kind##_##op({A, B, C, D})).Run(iters); \
  } \
  BENCHMARK(BM_##op##_##kind##_##type##_##A##_##B##_##C##_##D)

#define BM(op, A, B, C, D, type) \
  BM_Activation(op, Default, A, B, C, D, type); \
  BM_Activation(op, Mkl, A, B, C, D, type);

#define TEST_ALL_SIZES(OP) \
  BM(OP, 2, 4, 8, 16, cpu); \
  BM(OP, 3, 5, 9, 17, cpu); \
  BM(OP, 32, 64, 128, 256, cpu); \
  BM(OP, 33, 65, 129, 257, cpu);

TEST_ALL_SIZES(Tanh)
TEST_ALL_SIZES(TanhGrad)
TEST_ALL_SIZES(Relu)
TEST_ALL_SIZES(ReluGrad)
TEST_ALL_SIZES(Elu)
TEST_ALL_SIZES(EluGrad)
TEST_ALL_SIZES(Relu6)
TEST_ALL_SIZES(Relu6Grad)
TEST_ALL_SIZES(LeakyRelu)
TEST_ALL_SIZES(LeakyReluGrad)

}  // namespace tensorflow

#endif  // INTEL_MKL