working around a known gcc/hcc interface kernel args passing bug
This commit is contained in:
parent
98ab95f7d9
commit
4a34c4e3b0
@ -105,9 +105,7 @@ namespace functor {
|
|||||||
extern template struct Relu6Grad<GPUDevice, T>; \
|
extern template struct Relu6Grad<GPUDevice, T>; \
|
||||||
\
|
\
|
||||||
template <> \
|
template <> \
|
||||||
void LeakyRelu<GPUDevice, T>::operator()( \
|
void LeakyRelu<GPUDevice, T>::operator()(LeakyReluArgs args); \
|
||||||
const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \
|
|
||||||
typename TTypes<T>::Tensor activations); \
|
|
||||||
extern template struct LeakyRelu<GPUDevice, T>; \
|
extern template struct LeakyRelu<GPUDevice, T>; \
|
||||||
\
|
\
|
||||||
template <> \
|
template <> \
|
||||||
|
@ -143,8 +143,8 @@ class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
|
|||||||
|
|
||||||
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
|
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
|
||||||
functor::LeakyRelu<Device, T> functor;
|
functor::LeakyRelu<Device, T> functor;
|
||||||
functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
|
functor({context->eigen_device<Device>(), input.flat<T>(), alpha_,
|
||||||
output->flat<T>());
|
output->flat<T>()});
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -98,11 +98,21 @@ struct LeakyRelu {
|
|||||||
//
|
//
|
||||||
// features: any shape.
|
// features: any shape.
|
||||||
// activations: same shape as "features".
|
// activations: same shape as "features".
|
||||||
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
|
|
||||||
T alpha, typename TTypes<T>::Tensor activations) {
|
// Need to bundle the args (to the LeakyRelu functor) within a struct
|
||||||
|
// Not doing so leads to Eigen kernel args not getting populated
|
||||||
|
// correctly for Eigen::half type (when building on the ROCM platform)
|
||||||
|
struct LeakyReluArgs {
|
||||||
|
const Device& d;
|
||||||
|
typename TTypes<T>::ConstTensor features;
|
||||||
|
T alpha;
|
||||||
|
typename TTypes<T>::Tensor activations;
|
||||||
|
};
|
||||||
|
void operator()(LeakyReluArgs args) {
|
||||||
// Note that alpha might be > 1 or < 0, so we don't use cwiseMax here.
|
// Note that alpha might be > 1 or < 0, so we don't use cwiseMax here.
|
||||||
activations.device(d) =
|
args.activations.device(args.d) =
|
||||||
(features > static_cast<T>(0)).select(features, features * alpha);
|
(args.features > static_cast<T>(0))
|
||||||
|
.select(args.features, args.features * args.alpha);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user