Merge pull request #40598 from Intel-tensorflow:nhasabni/mkl_tanh_buildfix
PiperOrigin-RevId: 317484508 Change-Id: I8db4dea4b2d86e8f354c94833ea12f7cb700c0cd
Commit 7e0b098690
tensorflow/core
@@ -682,12 +682,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    rinfo_.push_back(
        {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#ifdef ENABLE_MKLDNN_V1
    // Optimized TanhGrad support exists only in DNNL 1.x.
    rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
                      CopyAttrsAll, AlwaysRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.tanh_grad, mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#endif  // ENABLE_MKLDNN_V1
    rinfo_.push_back(
        {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
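Note on the guard above: the Tanh and TanhGrad rewrite registrations are compiled only when ENABLE_MKLDNN_V1 is defined, so DNNL 0.x builds never rewrite those nodes and leave them to the default TensorFlow kernels. Below is a minimal, self-contained sketch of that build-flag gating pattern; RewriteInfo and the op-name strings are hypothetical stand-ins, not TensorFlow's actual types.

// build_flag_gating_sketch.cc -- illustrative only.
// Compile with -DENABLE_MKLDNN_V1 to include the optional entries.
#include <iostream>
#include <string>
#include <vector>

struct RewriteInfo {
  std::string tf_op;   // original graph op
  std::string mkl_op;  // op it is rewritten to
};

int main() {
  std::vector<RewriteInfo> rinfo;
  rinfo.push_back({"Requantize", "_MklRequantize"});
#ifdef ENABLE_MKLDNN_V1
  // Only the DNNL 1.x build has the optimized TanhGrad primitive, so these
  // rewrites are compiled in only for that configuration.
  rinfo.push_back({"Tanh", "_MklTanh"});
  rinfo.push_back({"TanhGrad", "_MklTanhGrad"});
#endif
  rinfo.push_back({"Reshape", "_MklReshape"});
  for (const auto& r : rinfo)
    std::cout << r.tf_op << " -> " << r.mkl_op << "\n";
  return 0;
}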
@@ -3024,6 +3024,8 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluLeakyReluGrad_Positive);
// clang-format on

// clang-format off
#ifdef ENABLE_MKLDNN_V1

#define REGISTER_TEST(NAME, T, INPUT)                                    \
  TEST_F(MklLayoutPassTest, NAME##_##T) {                                \
    DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
@@ -3081,6 +3083,7 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
#undef REGISTER_TEST
#endif  // ENABLE_MKLDNN_V1
// clang-format on

TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
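The test hunks above gate the Tanh/TanhGrad layout-pass tests behind the same flag and instantiate them per element type through the REGISTER_TEST / REGISTER_TEST_ALL_TYPES macros. A minimal sketch of that macro-driven test pattern follows, assuming googletest is available and assuming REGISTER_TEST_ALL_TYPES expands the test for float and bfloat16; the test body and the DT_* names here are simplified stand-ins, not the real MklLayoutPassTest fixture.

// typed_test_macro_sketch.cc -- illustrative only; link against gtest_main.
#include <string>
#include "gtest/gtest.h"

// Expand one test per element type. The real test builds a graph and checks
// the rewritten result; here the body is a placeholder so the macro
// mechanics stay visible.
#define REGISTER_TEST(NAME, T, INPUT)     \
  TEST(MklLayoutPassSketch, NAME##_##T) { \
    std::string dtype = #T;               \
    std::string input_op = INPUT;         \
    EXPECT_FALSE(dtype.empty());          \
    EXPECT_FALSE(input_op.empty());       \
  }

// Assumed expansion: one instantiation per supported type.
#define REGISTER_TEST_ALL_TYPES(NAME)     \
  REGISTER_TEST(NAME, DT_FLOAT, "Input"); \
  REGISTER_TEST(NAME, DT_BFLOAT16, "Input")

REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
#undef REGISTER_TEST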
@@ -269,7 +269,7 @@ class MklEltwiseBwdParams {

  MklEltwiseBwdParams(const memory::dims& src_dims,
                      const memory::desc& common_md, algorithm alg_kind,
                      float alpha, float beta, int forward_input_type)
                      float alpha, float beta, int forward_input_type = -1)
      : src_dims(src_dims),
        common_md(common_md),
        alg_kind(alg_kind),
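This hunk is the build fix itself: giving forward_input_type a default of -1 lets the DNNL 0.x code paths, which construct MklEltwiseBwdParams without that argument, keep compiling. A minimal sketch of the defaulted-trailing-parameter idea, using a simplified stand-in struct rather than the real class:

// default_arg_compat_sketch.cc -- illustrative only.
#include <iostream>

struct EltwiseBwdParamsSketch {
  float alpha;
  float beta;
  int forward_input_type;  // which forward tensor the backward primitive consumes

  // The trailing default (-1 meaning "not specified") keeps older call sites
  // that pass only alpha/beta compiling unchanged.
  EltwiseBwdParamsSketch(float alpha, float beta, int forward_input_type = -1)
      : alpha(alpha), beta(beta), forward_input_type(forward_input_type) {}
};

int main() {
  EltwiseBwdParamsSketch old_style(0.f, 0.f);  // DNNL 0.x-style call site
  EltwiseBwdParamsSketch new_style(0.f, 0.f, /*forward_input_type=*/1);
  std::cout << old_style.forward_input_type << " "
            << new_style.forward_input_type << "\n";  // prints: -1 1
  return 0;
}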
@@ -644,7 +644,10 @@ class MklReluGradOpBase : public OpKernel {
  virtual int GetDiffSrcIndex() const { return 0; }
  // What is the type of input tensor that grad op receives from forward op --
  // is it 'x' (SRC) or 'y' (DST). For Relu-family, it is 'x', so fwd op SRC.

#ifdef ENABLE_MKLDNN_V1
  virtual int GetTypeOfInputTensorFromFwdOp() const { return MKLDNN_ARG_SRC; }
#endif

  void Compute(OpKernelContext* context) {
    try {
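The new GetTypeOfInputTensorFromFwdOp() hook encodes which forward tensor a backward op needs: ReLU-family gradients need the forward input x (SRC), while a tanh gradient can be computed from the forward output y (DST), since d tanh(x)/dx = 1 - y^2. A minimal sketch of that virtual-hook pattern; ARG_SRC/ARG_DST and the classes below are stand-ins for MKLDNN_ARG_SRC/MKLDNN_ARG_DST and the real op classes, not TensorFlow's API.

// fwd_input_selection_sketch.cc -- illustrative only.
#include <cmath>
#include <iostream>

enum FwdArg { ARG_SRC, ARG_DST };

class EltwiseGradBase {
 public:
  virtual ~EltwiseGradBase() = default;
  // ReLU-family gradients need the forward input x, so SRC is the default.
  virtual FwdArg FwdTensorKind() const { return ARG_SRC; }
  virtual float Grad(float fwd_tensor, float dy) const = 0;
};

class ReluGradSketch : public EltwiseGradBase {
 public:
  // d relu(x)/dx is 1 if x > 0 else 0, so we need x (SRC).
  float Grad(float x, float dy) const override { return x > 0.f ? dy : 0.f; }
};

class TanhGradSketch : public EltwiseGradBase {
 public:
  // d tanh(x)/dx is 1 - tanh(x)^2 = 1 - y^2, so the forward *output* (DST) is
  // enough; this is what lets DNNL 1.x reuse dst for the backward pass.
  FwdArg FwdTensorKind() const override { return ARG_DST; }
  float Grad(float y, float dy) const override { return dy * (1.f - y * y); }
};

int main() {
  ReluGradSketch relu;
  TanhGradSketch tanh_grad;
  float x = 0.5f, dy = 1.0f;
  std::cout << relu.Grad(x, dy) << " "
            << tanh_grad.Grad(std::tanh(x), dy) << "\n";
  return 0;
}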
@@ -736,8 +739,16 @@ class MklReluGradOpBase : public OpKernel {
        common_md = src_md;
      }

#ifdef ENABLE_MKLDNN_V1
      MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha_,
                                       beta_, GetTypeOfInputTensorFromFwdOp());
#else
      // MKLDNN V0 does not support reusing output of forward op in backward.
      // So this optimization works only in MKLDNN v1.
      MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha_,
                                       beta_);
#endif  // ENABLE_MKLDNN_V1

      MklEltwiseBwdPrimitive<T>* eltwise_bwd =
          MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
@@ -962,6 +973,11 @@ class MklEluGradOp
  }
};

#ifdef ENABLE_MKLDNN_V1
// Optimized TanhGrad support exists in DNNL1.x only
// (eltwise_tanh_use_dst_for_bwd). We can still support it with DNNL0.x, but
// it will not be optimized. So we disable it for DNNL0.x.

template <typename Device, typename T>
class MklTanhOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh> {
 public:
@@ -1043,6 +1059,7 @@ class MklTanhGradOp
        (static_cast<T*>(user_g))[0] * (static_cast<T>(1) - tanh * tanh);
  }
};
#endif  // ENABLE_MKLDNN_V1

#define RELU6_UPPER_BOUND 6.0f
template <typename Device, typename T>
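The fallback path above computes the tanh gradient as user_g * (1 - tanh * tanh), i.e. dL/dx = dL/dy * (1 - tanh(x)^2). A tiny self-contained check that this analytic form matches a finite-difference estimate:

// tanh_grad_check_sketch.cc -- illustrative only.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const double x = 0.7;
  const double eps = 1e-6;
  const double y = std::tanh(x);
  const double analytic = 1.0 - y * y;  // gradient computed from the forward output
  const double numeric = (std::tanh(x + eps) - std::tanh(x - eps)) / (2.0 * eps);
  std::printf("analytic=%.9f numeric=%.9f\n", analytic, numeric);
  assert(std::fabs(analytic - numeric) < 1e-6);
  return 0;
}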
@@ -1227,6 +1244,7 @@ TF_CALL_bfloat16(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);

#ifdef ENABLE_MKLDNN_V1
#define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type) \
  REGISTER_KERNEL_BUILDER(                              \
      Name("_MklTanh")                                  \
@@ -1242,6 +1260,7 @@ TF_CALL_bfloat16(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
      MklTanhGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
#endif

#define REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES(type) \
  REGISTER_KERNEL_BUILDER(                               \
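The registration hunks instantiate the _MklTanh and _MklTanhGrad kernels for float and bfloat16 by handing a registration macro to TF_CALL_float / TF_CALL_bfloat16. A minimal sketch of that apply-a-macro-per-type pattern; CALL_float, CALL_bfloat16, and bfloat16_stub only mimic the shape of TensorFlow's macros and types and are not its real API.

// per_type_registration_sketch.cc -- illustrative only.
#include <iostream>
#include <typeinfo>

struct bfloat16_stub { unsigned short bits; };  // stand-in for a real bfloat16 type

template <typename T>
void RegisterTanhKernelSketch() {
  // The real macro would invoke REGISTER_KERNEL_BUILDER for this type;
  // here we only report which instantiation would be registered.
  std::cout << "would register _MklTanh for " << typeid(T).name() << "\n";
}

#define REGISTER_TANH_SKETCH(type) RegisterTanhKernelSketch<type>()
#define CALL_float(m) m(float)
#define CALL_bfloat16(m) m(bfloat16_stub)

int main() {
  CALL_float(REGISTER_TANH_SKETCH);
  CALL_bfloat16(REGISTER_TANH_SKETCH);
  return 0;
}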
@@ -121,8 +121,11 @@ static Graph* Activation(const string& op_name, const string& kind,
  BM(OP, 32, 64, 128, 256, cpu); \
  BM(OP, 33, 65, 129, 257, cpu);

#ifdef ENABLE_MKLDNN_V1
// Optimized MKLDNN TanhGrad support exists in DNNL1.x only.
TEST_ALL_SIZES(Tanh)
TEST_ALL_SIZES(TanhGrad)
#endif  // ENABLE_MKLDNN_V1
TEST_ALL_SIZES(Relu)
TEST_ALL_SIZES(ReluGrad)
TEST_ALL_SIZES(Elu)