PR #36526: Removed NDIMS template arg from BinaryElementWiseOp::Operate

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/36526

Fixes #36525
Copybara import of the project:

--
198cba7bca by Artem Mavrin <artemvmavrin@gmail.com>:

Removed NDIMS template arg from BinaryElementWiseOp::Operate

--
108d1e9202 by Artem Mavrin <artemvmavrin@gmail.com>:

Fixed undeclared identifier 'alpha'

PiperOrigin-RevId: 297502705
Change-Id: Id13f6935f36ee212488ceb5ad3274d1649f49fd7
commit 07d7624b1a (parent e90814a8f4)
A. Unique TensorFlower, 2020-02-26 20:34:20 -08:00, committed by TensorFlower Gardener
5 changed files with 138 additions and 31 deletions
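
For context on the title: with an NDIMS template argument on Operate(), the CRTP base class has to switch on the runtime rank to pick an instantiation; without it, the base can call Operate() directly. A minimal sketch of the two descendant shapes (simplified stand-ins, not the TensorFlow sources):

#include <iostream>

// Variant A: Operate() is templated on the tensor rank, so the calling base
// class must dispatch on the runtime rank to pick an instantiation (this is
// what the switch/NDIM_CASE block in the first hunk below does).
struct WithNdims {
  template <int NDIMS>
  void Operate() { std::cout << "Operate<" << NDIMS << ">\n"; }
};

// Variant B: Operate() carries no NDIMS argument, so the base calls it
// unconditionally and no per-rank dispatch or instantiation is needed.
struct WithoutNdims {
  void Operate() { std::cout << "Operate()\n"; }
};

int main() {
  WithNdims a;
  a.Operate<2>();  // rank fixed at compile time; one instantiation per rank
  WithoutNdims b;
  b.Operate();     // single body; rank, if needed, is handled at runtime
  return 0;
}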

@@ -82,7 +82,29 @@ class BinaryElementWiseOp : public BinaryOp<T> {
{0, 1}, 0, a.shape(), &output));
// Dispatch to the descendant's Operate() function.
static_cast<CHILD*>(this)->Operate(context, a, b, output);
switch (a.dims()) {
#define NDIM_CASE(NDIMS) \
case NDIMS: { \
static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \
break; \
}
NDIM_CASE(0);
NDIM_CASE(1);
NDIM_CASE(2);
NDIM_CASE(3);
NDIM_CASE(4);
NDIM_CASE(5);
NDIM_CASE(6);
NDIM_CASE(7);
NDIM_CASE(8);
#undef NDIM_CASE
default:
context->SetStatus(errors::InvalidArgument(
"We only handle up to Tensor::dims() up to 8, not ", a.dims()));
break;
}
}
};
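
Every kernel touched below follows the same descendant-side pattern: a thin templated Operate<NDIMS>() trampoline that forwards to a non-templated OperateNoTemplate() holding the real body, so only one copy of that body is compiled regardless of how many ranks the base dispatches over. A minimal sketch of how a descendant plugs into the dispatch above (class names and bodies are simplified stand-ins, not the TensorFlow sources):

#include <iostream>

// Simplified stand-in for the BinaryElementWiseOp dispatch shown above.
template <typename CHILD>
struct BinaryElementWiseSketch {
  void Compute(int rank) {
    switch (rank) {  // mirrors the switch/NDIM_CASE block, truncated to 2 ranks
      case 1: static_cast<CHILD*>(this)->template Operate<1>(); break;
      case 2: static_cast<CHILD*>(this)->template Operate<2>(); break;
      default: std::cout << "unsupported rank " << rank << "\n"; break;
    }
  }
};

// Descendant pattern used by the gradient kernels in the hunks below.
struct ReluGradSketch : BinaryElementWiseSketch<ReluGradSketch> {
  template <int NDIMS>
  void Operate() {            // instantiated per rank, but trivially small
    OperateNoTemplate();
  }
  void OperateNoTemplate() {  // the real work; compiled exactly once
    std::cout << "ReluGrad body\n";
  }
};

int main() {
  ReluGradSketch op;
  op.Compute(2);  // base switch -> Operate<2>() -> OperateNoTemplate()
  return 0;
}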

@@ -124,8 +124,14 @@ class FakeQuantWithMinMaxArgsGradientOp
quant_max_ = (1 << num_bits) - 1;
}
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& gradient,
const Tensor& input, Tensor* output) {
OperateNoTemplate(context, gradient, input, output);
}
void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
const Tensor& input, Tensor* output) {
OP_REQUIRES(context, input.IsSameSize(gradient),
InvalidArgument("gradient and input must be the same size"));
FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;

@@ -63,20 +63,31 @@ class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): either the inputs that were passed to ReluOp(), or its
// outputs (using either one yields the same result here).
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g, const Tensor& a,
Tensor* output) {
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
functor::ReluGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
template <typename Device, typename T>
class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
@@ -95,19 +106,30 @@ class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to Relu6Op()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g, const Tensor& a,
Tensor* output) {
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
functor::Relu6Grad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
template <typename Device, typename T>
class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
@@ -140,24 +162,36 @@ class LeakyReluGradOp
alpha_ = T(alpha_tmp);
}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, T alpha, Tensor* output);
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): either the inputs that were passed to LeakyReluOp(), or its
// outputs (using either one yields the same result here).
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
functor::LeakyReluGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha_,
output->flat<T>());
OperateNoTemplate(context, g, a, alpha_, output);
}
private:
T alpha_;
};
template <typename Device, typename T>
void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g,
const Tensor& a, T alpha,
Tensor* output) {
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
functor::LeakyReluGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
output->flat<T>());
};
template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
@@ -175,19 +209,30 @@ class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
// INPUTS:
// g (gradients): backpropagated gradients
// a (outputs): outputs of the EluOp()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g, const Tensor& a,
Tensor* output) {
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
functor::EluGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
template <typename Device, typename T>
class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
@@ -206,19 +251,30 @@ class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
// INPUTS:
// g (gradients): backpropagated gradients
// a (outputs): outputs of the SeluOp()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g, const Tensor& a,
Tensor* output) {
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
functor::SeluGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
} // namespace tensorflow

@@ -50,20 +50,31 @@ class SoftplusGradOp
explicit SoftplusGradOp(OpKernelConstruction* context)
: BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to SoftplusOp()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
void SoftplusGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g,
const Tensor& a,
Tensor* output) {
OP_REQUIRES(context, a.IsSameSize(g),
errors::InvalidArgument("g and a must be the same size"));
functor::SoftplusGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \

@@ -50,20 +50,32 @@ class SoftsignGradOp
explicit SoftsignGradOp(OpKernelConstruction* context)
: BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to SoftsignOp()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OperateNoTemplate(context, g, a, output);
}
};
template <typename Device, typename T>
void SoftsignGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g,
const Tensor& a,
Tensor* output) {
OP_REQUIRES(context, a.IsSameSize(g),
errors::InvalidArgument("g and a must be the same size"));
functor::SoftsignGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \