PR #36526: Removed NDIMS template arg from BinaryElementWiseOp::Operate
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/36526 Fixes #36525 Copybara import of the project: --198cba7bca
by Artem Mavrin <artemvmavrin@gmail.com>: Removed NDIMS template arg from BinaryElementWiseOp::Operate --108d1e9202
by Artem Mavrin <artemvmavrin@gmail.com>: Fixed undeclared identifier 'alpha' PiperOrigin-RevId: 297502705 Change-Id: Id13f6935f36ee212488ceb5ad3274d1649f49fd7
This commit is contained in:
parent
e90814a8f4
commit
07d7624b1a
tensorflow/core
@ -82,7 +82,29 @@ class BinaryElementWiseOp : public BinaryOp<T> {
|
||||
{0, 1}, 0, a.shape(), &output));
|
||||
|
||||
// Dispatch to the descendant's Operate() function.
|
||||
static_cast<CHILD*>(this)->Operate(context, a, b, output);
|
||||
switch (a.dims()) {
|
||||
#define NDIM_CASE(NDIMS) \
|
||||
case NDIMS: { \
|
||||
static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \
|
||||
break; \
|
||||
}
|
||||
|
||||
NDIM_CASE(0);
|
||||
NDIM_CASE(1);
|
||||
NDIM_CASE(2);
|
||||
NDIM_CASE(3);
|
||||
NDIM_CASE(4);
|
||||
NDIM_CASE(5);
|
||||
NDIM_CASE(6);
|
||||
NDIM_CASE(7);
|
||||
NDIM_CASE(8);
|
||||
#undef NDIM_CASE
|
||||
|
||||
default:
|
||||
context->SetStatus(errors::InvalidArgument(
|
||||
"We only handle up to Tensor::dims() up to 8, not ", a.dims()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -124,8 +124,14 @@ class FakeQuantWithMinMaxArgsGradientOp
|
||||
quant_max_ = (1 << num_bits) - 1;
|
||||
}
|
||||
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& gradient,
|
||||
const Tensor& input, Tensor* output) {
|
||||
OperateNoTemplate(context, gradient, input, output);
|
||||
}
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
|
||||
const Tensor& input, Tensor* output) {
|
||||
OP_REQUIRES(context, input.IsSameSize(gradient),
|
||||
InvalidArgument("gradient and input must be the same size"));
|
||||
FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
|
||||
|
@ -63,21 +63,32 @@ class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
|
||||
public:
|
||||
using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a, Tensor* output);
|
||||
|
||||
// INPUTS:
|
||||
// g (gradients): backpropagated gradients
|
||||
// a (inputs): either the inputs that were passed to ReluOp(), or its
|
||||
// outputs (using either one yields the same result here).
|
||||
// OUTPUT:
|
||||
// gradients to backprop
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::ReluGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
OperateNoTemplate(context, g, a, output);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
|
||||
const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::ReluGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
|
||||
public:
|
||||
@ -95,20 +106,31 @@ class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
|
||||
public:
|
||||
using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a, Tensor* output);
|
||||
|
||||
// INPUTS:
|
||||
// g (gradients): backpropagated gradients
|
||||
// a (inputs): inputs that were passed to Relu6Op()
|
||||
// OUTPUT:
|
||||
// gradients to backprop
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::Relu6Grad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
OperateNoTemplate(context, g, a, output);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
|
||||
const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::Relu6Grad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
|
||||
public:
|
||||
@ -140,24 +162,36 @@ class LeakyReluGradOp
|
||||
alpha_ = T(alpha_tmp);
|
||||
}
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a, T alpha, Tensor* output);
|
||||
|
||||
// INPUTS:
|
||||
// g (gradients): backpropagated gradients
|
||||
// a (inputs): either the inputs that were passed to LeakyReluOp(), or its
|
||||
// outputs (using either one yields the same result here).
|
||||
// OUTPUT:
|
||||
// gradients to backprop
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::LeakyReluGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha_,
|
||||
output->flat<T>());
|
||||
OperateNoTemplate(context, g, a, alpha_, output);
|
||||
}
|
||||
|
||||
private:
|
||||
T alpha_;
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
|
||||
const Tensor& g,
|
||||
const Tensor& a, T alpha,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::LeakyReluGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
|
||||
output->flat<T>());
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
|
||||
public:
|
||||
@ -175,20 +209,31 @@ class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
|
||||
public:
|
||||
using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a, Tensor* output);
|
||||
|
||||
// INPUTS:
|
||||
// g (gradients): backpropagated gradients
|
||||
// a (outputs): outputs of the EluOp()
|
||||
// OUTPUT:
|
||||
// gradients to backprop
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::EluGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
OperateNoTemplate(context, g, a, output);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
|
||||
const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::EluGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
|
||||
public:
|
||||
@ -206,20 +251,31 @@ class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
|
||||
public:
|
||||
using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a, Tensor* output);
|
||||
|
||||
// INPUTS:
|
||||
// g (gradients): backpropagated gradients
|
||||
// a (outputs): outputs of the SeluOp()
|
||||
// OUTPUT:
|
||||
// gradients to backprop
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::SeluGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
OperateNoTemplate(context, g, a, output);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
|
||||
const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
|
||||
functor::SeluGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#undef EIGEN_USE_THREADS
|
||||
|
@ -50,20 +50,31 @@ class SoftplusGradOp
|
||||
explicit SoftplusGradOp(OpKernelConstruction* context)
|
||||
: BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {}
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a, Tensor* output);
|
||||
|
||||
// INPUTS:
|
||||
// g (gradients): backpropagated gradients
|
||||
// a (inputs): inputs that were passed to SoftplusOp()
|
||||
// OUTPUT:
|
||||
// gradients to backprop
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
OP_REQUIRES(context, a.IsSameSize(g),
|
||||
errors::InvalidArgument("g and a must be the same size"));
|
||||
functor::SoftplusGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
OperateNoTemplate(context, g, a, output);
|
||||
}
|
||||
};
|
||||
template <typename Device, typename T>
|
||||
void SoftplusGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
|
||||
const Tensor& g,
|
||||
const Tensor& a,
|
||||
Tensor* output) {
|
||||
OP_REQUIRES(context, a.IsSameSize(g),
|
||||
errors::InvalidArgument("g and a must be the same size"));
|
||||
functor::SoftplusGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
}
|
||||
|
||||
#define REGISTER_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
|
@ -50,21 +50,33 @@ class SoftsignGradOp
|
||||
explicit SoftsignGradOp(OpKernelConstruction* context)
|
||||
: BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {}
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a, Tensor* output);
|
||||
|
||||
// INPUTS:
|
||||
// g (gradients): backpropagated gradients
|
||||
// a (inputs): inputs that were passed to SoftsignOp()
|
||||
// OUTPUT:
|
||||
// gradients to backprop
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
|
||||
Tensor* output) {
|
||||
OP_REQUIRES(context, a.IsSameSize(g),
|
||||
errors::InvalidArgument("g and a must be the same size"));
|
||||
functor::SoftsignGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
OperateNoTemplate(context, g, a, output);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
void SoftsignGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
|
||||
const Tensor& g,
|
||||
const Tensor& a,
|
||||
Tensor* output) {
|
||||
OP_REQUIRES(context, a.IsSameSize(g),
|
||||
errors::InvalidArgument("g and a must be the same size"));
|
||||
functor::SoftsignGrad<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
|
||||
output->flat<T>());
|
||||
}
|
||||
|
||||
#define REGISTER_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Softsign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
|
Loading…
Reference in New Issue
Block a user