Add the ROCm GPU kernel for RELU int8x4
commit d1f1d78b86
parent 057cf24986
@@ -143,7 +143,7 @@ namespace functor {
       typename TTypes<T>::Tensor backprops);                \
   extern template struct SeluGrad<GPUDevice, T>;
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // TODO(rocm) : qint8 datatype currently not supported on the ROCm platform
 template <>
 void Relu<GPUDevice, qint8>::operator()(
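A note on the data layout this specialization assumes: Relu<GPUDevice, qint8> runs on tensors vectorized as int8x4, four signed 8-bit values packed into one 32-bit word, which is why the GPU kernel further down reads and writes int32 buffers. A minimal host-side sketch of that packing (illustrative only, not code from this commit):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Four qint8 lanes; the values match the GoodShape test below.
  int8_t lanes[4] = {-50, 7, 23, 0};
  uint32_t word = 0;
  static_assert(sizeof(word) == sizeof(lanes), "four int8 lanes per word");
  std::memcpy(&word, lanes, sizeof(word));  // pack (little-endian byte order)
  std::printf("packed int8x4 word: 0x%08x\n", word);
  return 0;
}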
@@ -191,7 +191,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 template <typename Device>
 class ReluOp<Device, qint8>
     : public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> {
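The base class here follows the curiously recurring template pattern: UnaryElementWiseOp takes the derived op as a template parameter and dispatches to it statically, with no virtual calls. A simplified sketch of that shape (the names below are placeholders, not TensorFlow's real classes):

#include <cstdint>

template <typename T, typename Derived>
struct UnaryElementWiseOpSketch {
  // The base calls into the derived class without virtual dispatch.
  void Compute(const T* in, T* out, int n) {
    static_cast<Derived*>(this)->Operate(in, out, n);
  }
};

struct ReluOpSketch : UnaryElementWiseOpSketch<int8_t, ReluOpSketch> {
  void Operate(const int8_t* in, int8_t* out, int n) {
    for (int i = 0; i < n; ++i) out[i] = in[i] < 0 ? 0 : in[i];
  }
};

int main() {
  int8_t in[4] = {-50, 7, 23, 0}, out[4];
  ReluOpSketch().Compute(in, out, 4);  // out = {0, 7, 23, 0}
  return 0;
}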
@@ -119,12 +119,22 @@ struct ReluGrad<Device, Eigen::half> {
 };
 #endif  // GOOGLE_CUDA
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 __global__ void Relu_int8x4_kernel(int vect_count,
                                    const int32* __restrict__ input,
                                    int32* __restrict__ output) {
   CUDA_1D_KERNEL_LOOP(index, vect_count) {
+#if GOOGLE_CUDA
     output[index] = __vmaxs4(input[index], 0);
+#else
+    uint32 signs = (~input[index]) & 0x80808080;
+    signs = signs>>7;
+    signs |= signs<<1;
+    signs |= signs<<2;
+    signs |= signs<<4;
+    signs &= 0x7f7f7f7f;
+    output[index] = input[index] & signs;
+#endif
   }
 }
 
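The #else branch is the new ROCm fallback for CUDA's __vmaxs4 intrinsic, which computes a per-byte signed max against zero. It is branch-free: the inverted sign bit of each byte (set exactly when that byte is non-negative) is shifted down to bit 0 of its byte and then smeared across all eight bits, producing a mask of 0xff for non-negative lanes and 0x00 for negative ones; ANDing the input with that mask zeroes precisely the negative bytes. The final & 0x7f7f7f7f is redundant but harmless, since non-negative bytes already have bit 7 clear. A host-side check of the trick against a byte-wise reference (my own sketch, not part of the commit):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Same bit trick as the ROCm path: build a 0x7f mask for every
// non-negative byte, 0x00 for every negative byte, then AND.
static uint32_t relu_int8x4(uint32_t x) {
  uint32_t signs = (~x) & 0x80808080u;  // bit 7 set where the byte is >= 0
  signs >>= 7;                          // move each flag to bit 0 of its byte
  signs |= signs << 1;                  // smear the flag across the byte...
  signs |= signs << 2;
  signs |= signs << 4;                  // ...now 0xff per non-negative byte
  signs &= 0x7f7f7f7fu;                 // as in the kernel; bit 7 is 0 anyway
  return x & signs;
}

// Byte-wise reference: max(b, 0) for each signed 8-bit lane.
static uint32_t relu_reference(uint32_t x) {
  int8_t b[4];
  std::memcpy(b, &x, 4);
  for (int i = 0; i < 4; ++i) b[i] = b[i] < 0 ? 0 : b[i];
  uint32_t out;
  std::memcpy(&out, b, 4);
  return out;
}

int main() {
  // Exhaustively check every byte value, replicated into all four lanes.
  for (uint32_t v = 0; v < 256; ++v) {
    uint32_t word = v | (v << 8) | (v << 16) | (v << 24);
    if (relu_int8x4(word) != relu_reference(word)) {
      std::printf("mismatch at 0x%08x\n", word);
      return 1;
    }
  }
  std::printf("sign-mask trick matches byte-wise ReLU\n");
  return 0;
}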
@@ -168,7 +178,7 @@ struct Relu<Device, qint8> {
 template struct functor::SeluGrad<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 template struct functor::Relu<GPUDevice, qint8>;
 #endif  // GOOGLE_CUDA
 
@@ -79,8 +79,6 @@ class ReluTest(test.TestCase):
   def testReluInt8x4GoodShape(self):
     if not test.is_gpu_available(cuda_only=True):
       self.skipTest("No GPU available")
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support int8x4 type")
     inputs = np.array([[-50, 7, 23, 0], [-1, -5, 6, 11]])
     np_relu = self._npRelu(inputs)
     tf_relu = nn_ops.relu(constant_op.constant(inputs, dtypes.qint8))
@@ -91,8 +89,6 @@ class ReluTest(test.TestCase):
   def testReluInt8x4BadShape(self):
     if not test.is_gpu_available(cuda_only=True):
       self.skipTest("No GPU available")
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support int8x4 type")
     inputs = constant_op.constant(
         np.array([[-50, 7, 23], [0, 1, -5], [6, -2, 11]]), dtypes.qint8)
     with self.assertRaisesRegexp(
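Together these two tests pin down the shape contract of the int8x4 path: GoodShape has 2x4 = 8 elements, a multiple of four that packs cleanly into int32 words, while BadShape has 3x3 = 9 elements and is expected to raise. A sketch of the implied divisibility check (my inference from the tests; the op's actual validation code is not shown in this diff):

#include <cstdio>

// Hypothetical helper: the int8x4 kernel processes whole 32-bit words,
// so the flattened element count must be a multiple of 4.
static bool int8x4_shape_ok(long long num_elements) {
  return num_elements % 4 == 0;
}

int main() {
  std::printf("2x4 (8 elements) -> %s\n", int8x4_shape_ok(8) ? "ok" : "rejected");
  std::printf("3x3 (9 elements) -> %s\n", int8x4_shape_ok(9) ? "ok" : "rejected");
  return 0;
}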