Add the ROCm GPU kernel for RELU int8x4
This commit is contained in:
parent
057cf24986
commit
d1f1d78b86
@ -143,7 +143,7 @@ namespace functor {
|
|||||||
typename TTypes<T>::Tensor backprops); \
|
typename TTypes<T>::Tensor backprops); \
|
||||||
extern template struct SeluGrad<GPUDevice, T>;
|
extern template struct SeluGrad<GPUDevice, T>;
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
// TODO(rocm) : qint8 datatype currently not supported on the ROCm platform
|
// The qint8 (int8x4) Relu specialization is built for both CUDA and ROCm.
|
||||||
template <>
|
template <>
|
||||||
void Relu<GPUDevice, qint8>::operator()(
|
void Relu<GPUDevice, qint8>::operator()(
|
||||||
@ -191,7 +191,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
|||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
||||||
#undef REGISTER_GPU_KERNELS
|
#undef REGISTER_GPU_KERNELS
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
template <typename Device>
|
template <typename Device>
|
||||||
class ReluOp<Device, qint8>
|
class ReluOp<Device, qint8>
|
||||||
: public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> {
|
: public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> {
|
||||||
|
@ -119,12 +119,22 @@ struct ReluGrad<Device, Eigen::half> {
|
|||||||
};
|
};
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
__global__ void Relu_int8x4_kernel(int vect_count,
|
__global__ void Relu_int8x4_kernel(int vect_count,
|
||||||
const int32* __restrict__ input,
|
const int32* __restrict__ input,
|
||||||
int32* __restrict__ output) {
|
int32* __restrict__ output) {
|
||||||
CUDA_1D_KERNEL_LOOP(index, vect_count) {
|
CUDA_1D_KERNEL_LOOP(index, vect_count) {
|
||||||
|
#if GOOGLE_CUDA
|
||||||
output[index] = __vmaxs4(input[index], 0);
|
output[index] = __vmaxs4(input[index], 0);
|
||||||
|
#else
|
||||||
|
uint32 signs = (~input[index]) & 0x80808080;
|
||||||
|
signs = signs>>7;
|
||||||
|
signs |= signs<<1;
|
||||||
|
signs |= signs<<2;
|
||||||
|
signs |= signs<<4;
|
||||||
|
signs &= 0x7f7f7f7f;
|
||||||
|
output[index] = input[index] & signs;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,7 +178,7 @@ struct Relu<Device, qint8> {
|
|||||||
template struct functor::SeluGrad<GPUDevice, T>;
|
template struct functor::SeluGrad<GPUDevice, T>;
|
||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
|
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
template struct functor::Relu<GPUDevice, qint8>;
|
template struct functor::Relu<GPUDevice, qint8>;
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
@ -79,8 +79,6 @@ class ReluTest(test.TestCase):
|
|||||||
def testReluInt8x4GoodShape(self):
|
def testReluInt8x4GoodShape(self):
|
||||||
if not test.is_gpu_available(cuda_only=True):
|
if not test.is_gpu_available(cuda_only=True):
|
||||||
self.skipTest("No GPU available")
|
self.skipTest("No GPU available")
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support int8x4 type")
|
|
||||||
inputs = np.array([[-50, 7, 23, 0], [-1, -5, 6, 11]])
|
inputs = np.array([[-50, 7, 23, 0], [-1, -5, 6, 11]])
|
||||||
np_relu = self._npRelu(inputs)
|
np_relu = self._npRelu(inputs)
|
||||||
tf_relu = nn_ops.relu(constant_op.constant(inputs, dtypes.qint8))
|
tf_relu = nn_ops.relu(constant_op.constant(inputs, dtypes.qint8))
|
||||||
@ -91,8 +89,6 @@ class ReluTest(test.TestCase):
|
|||||||
def testReluInt8x4BadShape(self):
|
def testReluInt8x4BadShape(self):
|
||||||
if not test.is_gpu_available(cuda_only=True):
|
if not test.is_gpu_available(cuda_only=True):
|
||||||
self.skipTest("No GPU available")
|
self.skipTest("No GPU available")
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support int8x4 type")
|
|
||||||
inputs = constant_op.constant(
|
inputs = constant_op.constant(
|
||||||
np.array([[-50, 7, 23], [0, 1, -5], [6, -2, 11]]), dtypes.qint8)
|
np.array([[-50, 7, 23], [0, 1, -5], [6, -2, 11]]), dtypes.qint8)
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
|
Loading…
Reference in New Issue
Block a user