working around a known gcc/hcc interface kernel args passing bug
This commit is contained in:
parent
98ab95f7d9
commit
4a34c4e3b0
@ -105,9 +105,7 @@ namespace functor {
|
||||
extern template struct Relu6Grad<GPUDevice, T>; \
|
||||
\
|
||||
template <> \
|
||||
void LeakyRelu<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \
|
||||
typename TTypes<T>::Tensor activations); \
|
||||
void LeakyRelu<GPUDevice, T>::operator()(LeakyReluArgs args); \
|
||||
extern template struct LeakyRelu<GPUDevice, T>; \
|
||||
\
|
||||
template <> \
|
||||
|
@ -143,8 +143,8 @@ class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
|
||||
|
||||
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
|
||||
functor::LeakyRelu<Device, T> functor;
|
||||
functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
|
||||
output->flat<T>());
|
||||
functor({context->eigen_device<Device>(), input.flat<T>(), alpha_,
|
||||
output->flat<T>()});
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -98,11 +98,21 @@ struct LeakyRelu {
|
||||
//
|
||||
// features: any shape.
|
||||
// activations: same shape as "features".
|
||||
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
|
||||
T alpha, typename TTypes<T>::Tensor activations) {
|
||||
|
||||
// Need to bundle the args (to the LeakyRelu functor) within a struct
|
||||
// Not doing so leads to Eigen kernel args not getting populated
|
||||
// corretly for Eigen::half type (when building on the ROCM platform)
|
||||
struct LeakyReluArgs {
|
||||
const Device& d;
|
||||
typename TTypes<T>::ConstTensor features;
|
||||
T alpha;
|
||||
typename TTypes<T>::Tensor activations;
|
||||
};
|
||||
void operator()(LeakyReluArgs args) {
|
||||
// Note that alpha might be > 1 or < 0, so we don't use cwiseMax here.
|
||||
activations.device(d) =
|
||||
(features > static_cast<T>(0)).select(features, features * alpha);
|
||||
args.activations.device(args.d) =
|
||||
(args.features > static_cast<T>(0))
|
||||
.select(args.features, args.features * args.alpha);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user