Merge pull request #35965 from ROCmSoftwarePlatform:google-upstream-relu-half

PiperOrigin-RevId: 305203074
Change-Id: Ia0f55a926a8c34225392f5d299dfd21949b051d7
This commit is contained in:
TensorFlower Gardener 2020-04-07 00:54:59 -07:00
commit 38797a1c8b
2 changed files with 9 additions and 15 deletions

View File

@ -142,13 +142,12 @@ namespace functor {
extern template struct SeluGrad<GPUDevice, T>; extern template struct SeluGrad<GPUDevice, T>;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// TODO(rocm) : qint8 datatype currently not supported on the ROCm platform
template <> template <>
void Relu<GPUDevice, qint8>::operator()( void Relu<GPUDevice, qint8>::operator()(
const GPUDevice& d, typename TTypes<qint8>::ConstTensor features, const GPUDevice& d, typename TTypes<qint8>::ConstTensor features,
typename TTypes<qint8>::Tensor activations); typename TTypes<qint8>::Tensor activations);
extern template struct Relu<GPUDevice, qint8>; extern template struct Relu<GPUDevice, qint8>;
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor } // namespace functor
@ -189,7 +188,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS #undef REGISTER_GPU_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device> template <typename Device>
class ReluOp<Device, qint8> class ReluOp<Device, qint8>
: public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> { : public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> {
@ -211,7 +209,6 @@ REGISTER_KERNEL_BUILDER(
Name("Relu").Device(DEVICE_GPU).TypeConstraint<qint8>("T"), Name("Relu").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
ReluOp<GPUDevice, qint8>); ReluOp<GPUDevice, qint8>);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL

View File

@ -26,16 +26,17 @@ limitations under the License.
#include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h" #include "tensorflow/core/util/gpu_launch_config.h"
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_fp16.h"
typedef __half2 half2;
#endif
namespace tensorflow { namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;
namespace functor { namespace functor {
#if GOOGLE_CUDA
// TODO(rocm): disabling this code on the ROCm platform since the references
// to `half2` are leading to compile errors.
// This kernel computes ReluGrad by processing one half2, two fp16, at a time. // This kernel computes ReluGrad by processing one half2, two fp16, at a time.
// It effectively does: backdrops = (feature > 0) ? gradient : 0 // It effectively does: backdrops = (feature > 0) ? gradient : 0
// It also tries to use native half2 primitives as much as possible. // It also tries to use native half2 primitives as much as possible.
@ -65,8 +66,9 @@ __global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
// Fall back: convert half2 to float2 for processing. // Fall back: convert half2 to float2 for processing.
float2 feature_f2 = __half22float2(feature_h2); float2 feature_f2 = __half22float2(feature_h2);
float2 gradient_f2 = __half22float2(gradient_h2); float2 gradient_f2 = __half22float2(gradient_h2);
float2 backprop_f2 = make_float2((feature_f2.x > 0) ? gradient_f2.x : 0, float2 backprop_f2 =
(feature_f2.y > 0) ? gradient_f2.y : 0); make_float2((feature_f2.x > 0.0f) ? float(gradient_f2.x) : 0.0f,
(feature_f2.y > 0.0f) ? float(gradient_f2.y) : 0.0f);
// Convert back to half2. // Convert back to half2.
half2 backprop_h2 = __float22half2_rn(backprop_f2); half2 backprop_h2 = __float22half2_rn(backprop_f2);
#endif #endif
@ -117,9 +119,7 @@ struct ReluGrad<Device, Eigen::half> {
d.stream(), gradient.data(), feature.data(), backprop.data(), count)); d.stream(), gradient.data(), feature.data(), backprop.data(), count));
} }
}; };
#endif // GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
__global__ void Relu_int8x4_kernel(int vect_count, __global__ void Relu_int8x4_kernel(int vect_count,
const int32* __restrict__ input, const int32* __restrict__ input,
int32* __restrict__ output) { int32* __restrict__ output) {
@ -160,7 +160,6 @@ struct Relu<Device, qint8> {
reinterpret_cast<int32*>(output.data()))); reinterpret_cast<int32*>(output.data())));
} }
}; };
#endif // GOOGLE_CUDA
} // namespace functor } // namespace functor
@ -178,9 +177,7 @@ struct Relu<Device, qint8> {
template struct functor::SeluGrad<GPUDevice, T>; template struct functor::SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template struct functor::Relu<GPUDevice, qint8>; template struct functor::Relu<GPUDevice, qint8>;
#endif // GOOGLE_CUDA
} // end namespace tensorflow } // end namespace tensorflow