Add bool type supports for GPU kernels (#11927)
* Add bool type supports for GPU kernels * Add bool type test codes for GPU kernels
This commit is contained in:
parent
de01be952d
commit
881de45c2d
@ -117,6 +117,7 @@ TF_CALL_complex64(REGISTER);
|
||||
TF_CALL_complex128(REGISTER);
|
||||
TF_CALL_int64(REGISTER);
|
||||
REGISTER(bfloat16);
|
||||
REGISTER(bool);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
|
@ -203,24 +203,28 @@ TF_CALL_complex64(REGISTER_GPUCONCAT32);
|
||||
TF_CALL_complex128(REGISTER_GPUCONCAT32);
|
||||
TF_CALL_int64(REGISTER_GPUCONCAT32);
|
||||
REGISTER_GPUCONCAT32(bfloat16);
|
||||
REGISTER_GPUCONCAT32(bool);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
|
||||
TF_CALL_complex64(REGISTER_GPUCONCAT64);
|
||||
TF_CALL_complex128(REGISTER_GPUCONCAT64);
|
||||
TF_CALL_int64(REGISTER_GPUCONCAT64);
|
||||
REGISTER_GPUCONCAT64(bfloat16);
|
||||
REGISTER_GPUCONCAT64(bool);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
|
||||
TF_CALL_complex64(REGISTER_GPU32);
|
||||
TF_CALL_complex128(REGISTER_GPU32);
|
||||
TF_CALL_int64(REGISTER_GPU32);
|
||||
REGISTER_GPU32(bfloat16);
|
||||
REGISTER_GPU32(bool);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
|
||||
TF_CALL_complex64(REGISTER_GPU64);
|
||||
TF_CALL_complex128(REGISTER_GPU64);
|
||||
TF_CALL_int64(REGISTER_GPU64);
|
||||
REGISTER_GPU64(bfloat16);
|
||||
REGISTER_GPU64(bool);
|
||||
|
||||
#undef REGISTER_GPUCONCAT32
|
||||
#undef REGISTER_GPUCONCAT64
|
||||
|
@ -196,6 +196,7 @@ REGISTER_GPU(bfloat16);
|
||||
TF_CALL_complex64(REGISTER_GPU);
|
||||
TF_CALL_complex128(REGISTER_GPU);
|
||||
TF_CALL_int64(REGISTER_GPU);
|
||||
REGISTER_GPU(bool);
|
||||
#undef REGISTER_GPU
|
||||
|
||||
// A special GPU kernel for int32.
|
||||
|
@ -158,6 +158,7 @@ REGISTER_PACK(string);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
|
||||
TF_CALL_int64(REGISTER_GPU);
|
||||
REGISTER_GPU(bool);
|
||||
#undef REGISTER_GPU
|
||||
|
||||
// A special GPU kernel for int32.
|
||||
|
@ -32,6 +32,7 @@ REGISTER_KERNEL_BUILDER(Name("Reshape")
|
||||
.TypeConstraint<int32>("Tshape"), \
|
||||
ReshapeOp);
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
|
||||
REGISTER_GPU_KERNEL(bool);
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
|
@ -138,6 +138,7 @@ class ConcatOpTest(test.TestCase):
|
||||
self.assertAllClose(result[ind], params[p[i]], 0.01)
|
||||
|
||||
def testRandom(self):
|
||||
self._testRandom(dtypes.bool)
|
||||
self._testRandom(dtypes.float32)
|
||||
self._testRandom(dtypes.int16)
|
||||
self._testRandom(dtypes.int32)
|
||||
|
@ -41,6 +41,10 @@ class ReshapeTest(test.TestCase):
|
||||
self._testReshape(x, y, False)
|
||||
self._testReshape(x, y, True)
|
||||
|
||||
def testBoolBasic(self):
|
||||
x = np.arange(1., 7.).reshape([1, 6]) > 3
|
||||
self._testBothReshape(x, [2, 3])
|
||||
|
||||
def testFloatBasic(self):
|
||||
x = np.arange(1., 7.).reshape([1, 6]).astype(np.float32)
|
||||
self._testBothReshape(x, [2, 3])
|
||||
|
@ -45,7 +45,7 @@ class StackOpTest(test.TestCase):
|
||||
np.random.seed(7)
|
||||
with self.test_session(use_gpu=True):
|
||||
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
|
||||
for dtype in [np.float32, np.int32, np.int64]:
|
||||
for dtype in [np.bool, np.float32, np.int32, np.int64]:
|
||||
data = np.random.randn(*shape).astype(dtype)
|
||||
# Convert [data[0], data[1], ...] separately to tensorflow
|
||||
# TODO(irving): Remove list() once we handle maps correctly
|
||||
@ -67,7 +67,7 @@ class StackOpTest(test.TestCase):
|
||||
np.random.seed(7)
|
||||
with self.test_session(use_gpu=True):
|
||||
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
|
||||
for dtype in [np.float32, np.int32, np.int64]:
|
||||
for dtype in [np.bool, np.float32, np.int32, np.int64]:
|
||||
data = np.random.randn(*shape).astype(dtype)
|
||||
# Pack back into a single tensorflow tensor directly using np array
|
||||
c = array_ops.stack(data)
|
||||
|
Loading…
Reference in New Issue
Block a user