working around a known gcc/hcc interface kernel args passing bug

This commit is contained in:
Deven Desai 2020-02-05 21:22:38 +00:00
parent 98ab95f7d9
commit 4a34c4e3b0
3 changed files with 17 additions and 9 deletions

View File

@ -105,9 +105,7 @@ namespace functor {
extern template struct Relu6Grad<GPUDevice, T>; \
\
template <> \
void LeakyRelu<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \
typename TTypes<T>::Tensor activations); \
void LeakyRelu<GPUDevice, T>::operator()(LeakyReluArgs args); \
extern template struct LeakyRelu<GPUDevice, T>; \
\
template <> \

View File

@ -143,8 +143,8 @@ class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::LeakyRelu<Device, T> functor;
functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
output->flat<T>());
functor({context->eigen_device<Device>(), input.flat<T>(), alpha_,
output->flat<T>()});
}
private:

View File

@ -98,11 +98,21 @@ struct LeakyRelu {
//
// features: any shape.
// activations: same shape as "features".
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
T alpha, typename TTypes<T>::Tensor activations) {
// Need to bundle the args (to the LeakyRelu functor) within a struct
// Not doing so leads to Eigen kernel args not getting populated
// corretly for Eigen::half type (when building on the ROCM platform)
struct LeakyReluArgs {
const Device& d;
typename TTypes<T>::ConstTensor features;
T alpha;
typename TTypes<T>::Tensor activations;
};
void operator()(LeakyReluArgs args) {
// Note that alpha might be > 1 or < 0, so we don't use cwiseMax here.
activations.device(d) =
(features > static_cast<T>(0)).select(features, features * alpha);
args.activations.device(args.d) =
(args.features > static_cast<T>(0))
.select(args.features, args.features * args.alpha);
}
};