Merge pull request #35965 from ROCmSoftwarePlatform:google-upstream-relu-half

PiperOrigin-RevId: 305203074
Change-Id: Ia0f55a926a8c34225392f5d299dfd21949b051d7
commit 38797a1c8b
@@ -142,13 +142,12 @@ namespace functor {
  extern template struct SeluGrad<GPUDevice, T>;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// TODO(rocm) : qint8 datatype currently not supported on the ROCm platform
template <>
void Relu<GPUDevice, qint8>::operator()(
    const GPUDevice& d, typename TTypes<qint8>::ConstTensor features,
    typename TTypes<qint8>::Tensor activations);
extern template struct Relu<GPUDevice, qint8>;
#endif  // GOOGLE_CUDA
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
}  // namespace functor
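For readers following along: the block above is the C++ extern-template idiom. relu_op.cc only declares each GPU functor specialization and links against the one set of definitions instantiated in the companion .cu.cc file, so the host-side translation unit never tries to instantiate device code. Below is a minimal standalone sketch of the idiom; the names (CPUDeviceTag, the raw-pointer operator()) are illustrative stand-ins, not the TensorFlow types.

// extern_template_sketch.cc -- hypothetical stand-in, not TensorFlow code.
#include <cstdio>

template <typename Device, typename T>
struct Relu {
  // Real TF functors take a device object and Eigen tensors; raw pointers
  // keep this sketch self-contained.
  void operator()(const T* in, T* out, int n) const {
    for (int i = 0; i < n; ++i) out[i] = in[i] > T(0) ? in[i] : T(0);
  }
};

struct CPUDeviceTag {};

// Declaration only: promises this translation unit that an instantiation
// exists elsewhere (in TF, emitted by the .cu.cc file).
extern template struct Relu<CPUDeviceTag, float>;

// Definition: the single point where the template is actually instantiated.
template struct Relu<CPUDeviceTag, float>;

int main() {
  float in[4] = {-1.f, 0.f, 2.f, -3.f}, out[4];
  Relu<CPUDeviceTag, float>()(in, out, 4);
  for (float v : out) std::printf("%g ", v);  // prints: 0 0 2 0
  std::printf("\n");
}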
@@ -189,7 +188,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device>
class ReluOp<Device, qint8>
    : public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> {
@@ -211,7 +209,6 @@ REGISTER_KERNEL_BUILDER(
    Name("Relu").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
    ReluOp<GPUDevice, qint8>);

#endif  // GOOGLE_CUDA
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
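ReluOp<Device, qint8> above hooks into UnaryElementWiseOp via CRTP: the base class owns the shape and allocation plumbing and calls back into the derived class for the elementwise work, which REGISTER_KERNEL_BUILDER then binds to the GPU Relu op under the qint8 type constraint. A compilable sketch of that pattern under hypothetical names (UnaryBase, DerivedRelu); the real base class's interface is more involved.

// crtp_sketch.cc -- illustrative only, not the TensorFlow class hierarchy.
#include <cstdio>
#include <vector>

// CRTP base: drives the generic logic and delegates the elementwise work
// to the derived class, mirroring UnaryElementWiseOp's division of labor.
template <typename T, typename Derived>
struct UnaryBase {
  void Compute(const std::vector<T>& in, std::vector<T>& out) {
    out.resize(in.size());                           // "shape inference"
    static_cast<Derived*>(this)->Operate(in, out);   // dispatch to derived
  }
};

struct DerivedRelu : UnaryBase<float, DerivedRelu> {
  void Operate(const std::vector<float>& in, std::vector<float>& out) {
    for (size_t i = 0; i < in.size(); ++i) out[i] = in[i] > 0.f ? in[i] : 0.f;
  }
};

int main() {
  std::vector<float> in = {-2.f, 0.5f, 3.f}, out;
  DerivedRelu{}.Compute(in, out);
  for (float v : out) std::printf("%g ", v);  // prints: 0 0.5 3
  std::printf("\n");
}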
@@ -26,16 +26,17 @@ limitations under the License.
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_fp16.h"
typedef __half2 half2;
#endif

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

#if GOOGLE_CUDA
// TODO(rocm): disabling this code on the ROCm platform since the references
// to `half2` are leading to compile errors.

// This kernel computes ReluGrad by processing one half2, two fp16, at a time.
// It effectively does: backprops = (feature > 0) ? gradient : 0
// It also tries to use native half2 primitives as much as possible.
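The typedef this hunk adds papers over a naming difference: CUDA's cuda_fp16.h exposes the paired type as half2, while hip_fp16.h on ROCm provides __half2, so the alias lets the shared kernel body compile on both platforms. Per the kernel comment, the grad kernel consumes two fp16 values per iteration. A hedged sketch of that core idea, written against CUDA's cuda_fp16.h; it is simplified (assumes an even element count and half2-aligned buffers, while the real ReluGradHalfKernel must also handle an odd tail element) and is not the TensorFlow kernel itself.

#include <cuda_fp16.h>

__global__ void ReluGradHalf2Sketch(const half2* __restrict__ gradient,
                                    const half2* __restrict__ feature,
                                    half2* __restrict__ backprop, int pairs) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= pairs) return;
#if __CUDA_ARCH__ >= 530
  // Native half2 path: __hgt2 writes 1.0 into each lane where feature > 0
  // and 0.0 elsewhere; multiplying masks the gradient without a branch.
  const half2 zero = __float2half2_rn(0.f);
  backprop[i] = __hmul2(__hgt2(feature[i], zero), gradient[i]);
#else
  // Pre-sm_53 fallback, same shape as the hunk below: widen to float2,
  // select per lane, narrow back to half2.
  float2 f = __half22float2(feature[i]);
  float2 g = __half22float2(gradient[i]);
  backprop[i] = __float22half2_rn(
      make_float2(f.x > 0.f ? g.x : 0.f, f.y > 0.f ? g.y : 0.f));
#endif
}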
@@ -65,8 +66,9 @@ __global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
     // Fall back: convert half2 to float2 for processing.
     float2 feature_f2 = __half22float2(feature_h2);
     float2 gradient_f2 = __half22float2(gradient_h2);
-    float2 backprop_f2 = make_float2((feature_f2.x > 0) ? gradient_f2.x : 0,
-                                     (feature_f2.y > 0) ? gradient_f2.y : 0);
+    float2 backprop_f2 =
+        make_float2((feature_f2.x > 0.0f) ? float(gradient_f2.x) : 0.0f,
+                    (feature_f2.y > 0.0f) ? float(gradient_f2.y) : 0.0f);
     // Convert back to half2.
     half2 backprop_h2 = __float22half2_rn(backprop_f2);
 #endif
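A plausible reading of the rewritten assignment (an inference; the commit message doesn't say) is that hipcc objected to the mixed-type conditional operator, an int literal 0 against a float gradient_f2.x, so the replacement spells both arms as unambiguous floats; the explicit float() casts are redundant under nvcc but harmless.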
@@ -117,9 +119,7 @@ struct ReluGrad<Device, Eigen::half> {
        d.stream(), gradient.data(), feature.data(), backprop.data(), count));
  }
};
#endif  // GOOGLE_CUDA

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
__global__ void Relu_int8x4_kernel(int vect_count,
                                   const int32* __restrict__ input,
                                   int32* __restrict__ output) {
@@ -160,7 +160,6 @@ struct Relu<Device, qint8> {
        reinterpret_cast<int32*>(output.data())));
  }
};
#endif  // GOOGLE_CUDA

}  // namespace functor
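Relu_int8x4_kernel above exploits the packing of four signed 8-bit values into each int32 word (vect_count counts int32 words, as its signature shows), applying ReLU per byte. A hedged sketch of the technique: the CUDA branch uses the __vmaxs4 SIMD-within-a-word intrinsic, the byte-shifting branch is an illustrative portable fallback, and neither is the exact TensorFlow kernel.

#include <cstdint>

__global__ void ReluInt8x4Sketch(int vect_count,
                                 const int32_t* __restrict__ input,
                                 int32_t* __restrict__ output) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= vect_count) return;
#if defined(__CUDA_ARCH__)
  // One SIMD instruction: per-byte signed max against zero, i.e. four
  // ReLUs on the packed qint8 lanes at once.
  output[i] = __vmaxs4(input[i], 0);
#else
  // Portable per-byte fallback: extract each signed lane, keep it only
  // if positive, and repack into the output word.
  int32_t word = input[i];
  int32_t result = 0;
  for (int b = 0; b < 4; ++b) {
    int8_t lane = static_cast<int8_t>((word >> (8 * b)) & 0xff);
    if (lane > 0) result |= (static_cast<int32_t>(lane) & 0xff) << (8 * b);
  }
  output[i] = result;
#endif
}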
@@ -178,9 +177,7 @@ struct Relu<Device, qint8> {
template struct functor::SeluGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template struct functor::Relu<GPUDevice, qint8>;
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow