Improve GPU dtype coverage of concat
PiperOrigin-RevId: 333461512
Change-Id: Ic0dd1e6f30ef321c4da63bd5da4b92a525a5f237
Parent: 1e8cb572f2
Commit: 55b546de0e

Changed paths: tensorflow/core/framework, tensorflow/core/kernels, tensorflow/python/kernel_tests
@@ -153,24 +153,25 @@ limitations under the License.
 #endif  // defined(IS_MOBILE_PLATFORM)  - end of TF_CALL_type defines
 
 // Defines for sets of types.
-#define TF_CALL_INTEGRAL_TYPES(m)                                        \
-  TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_int32(m) \
-      TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
+#define TF_CALL_INTEGRAL_TYPES_NO_INT32(m)                                \
+  TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \
+      TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
+
+#define TF_CALL_INTEGRAL_TYPES(m) \
+  TF_CALL_INTEGRAL_TYPES_NO_INT32(m) TF_CALL_int32(m)
 
 #define TF_CALL_FLOAT_TYPES(m) \
   TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
 
 #define TF_CALL_REAL_NUMBER_TYPES(m) \
-  TF_CALL_INTEGRAL_TYPES(m)          \
-  TF_CALL_FLOAT_TYPES(m)
+  TF_CALL_INTEGRAL_TYPES(m) TF_CALL_FLOAT_TYPES(m)
 
 #define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
   TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
 
-#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)                                 \
-  TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)      \
-      TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m)  \
-          TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)                            \
+  TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \
+      TF_CALL_INTEGRAL_TYPES_NO_INT32(m)
 
 #define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
 
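Note on the hunk above: each TF_CALL_<type>(m) macro is an X-macro that invokes the supplied macro m with one concrete dtype, so applying a list macro such as TF_CALL_INTEGRAL_TYPES stamps out one registration or instantiation per type. The standalone sketch below is not TensorFlow code; the simplified TF_CALL_* bodies and the DEFINE_PRINT macro are made-up stand-ins that only illustrate the mechanism and why defining the full integral list in terms of the NO_INT32 list keeps the two from drifting apart.

// Minimal standalone sketch (not TensorFlow code). The simplified TF_CALL_*
// bodies and DEFINE_PRINT are hypothetical; they only illustrate how the real
// lists in register_types.h are consumed.
#include <cstdio>

#define TF_CALL_int32(m) m(int)
#define TF_CALL_int64(m) m(long long)
#define TF_CALL_uint32(m) m(unsigned)
#define TF_CALL_uint64(m) m(unsigned long long)

// The NO_INT32 list is the shared core; the full list appends int32, as in
// the hunk above.
#define TF_CALL_INTEGRAL_TYPES_NO_INT32(m) \
  TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m)
#define TF_CALL_INTEGRAL_TYPES(m) \
  TF_CALL_INTEGRAL_TYPES_NO_INT32(m) TF_CALL_int32(m)

// A caller-supplied macro standing in for a kernel registration: it defines
// one print_size overload per dtype in the list it is applied to.
#define DEFINE_PRINT(T) \
  void print_size(T) { std::printf("%zu bytes\n", sizeof(T)); }

TF_CALL_INTEGRAL_TYPES(DEFINE_PRINT)  // expands to four print_size overloads

int main() {
  print_size(0);     // picks the int overload
  print_size(0ull);  // picks the unsigned long long overload
  return 0;
}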
@@ -64,11 +64,8 @@ void ConcatGPU(
       inputs_flat,                                                            \
       Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
 
-TF_CALL_int32(REGISTER);  // Needed for TensorLists.
-TF_CALL_int64(REGISTER);
-TF_CALL_int16(REGISTER);
+TF_CALL_INTEGRAL_TYPES(REGISTER);  // int32 Needed for TensorLists.
 TF_CALL_bfloat16(REGISTER);
-TF_CALL_uint8(REGISTER);
 TF_CALL_GPU_ALL_TYPES(REGISTER);
 #undef REGISTER
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -98,11 +98,8 @@ void ConcatGPU(
       inputs_flat,                                                            \
       Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
 
-TF_CALL_int32(REGISTER);  // Needed for TensorLists.
-TF_CALL_int64(REGISTER);
-TF_CALL_int16(REGISTER);
+TF_CALL_INTEGRAL_TYPES(REGISTER);  // int32 Needed for TensorLists.
 TF_CALL_bfloat16(REGISTER);
-TF_CALL_uint8(REGISTER);
 TF_CALL_GPU_ALL_TYPES(REGISTER);
 
 #undef REGISTER
@@ -66,11 +66,8 @@ void ConcatGPUImpl(const Eigen::GpuDevice& d,
     const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size,          \
     int split_size, typename TTypes<T, 2>::Matrix* output);
 
-TF_CALL_int32(REGISTER);  // Needed for TensorLists.
-TF_CALL_int64(REGISTER);
-TF_CALL_int16(REGISTER);
+TF_CALL_INTEGRAL_TYPES(REGISTER);  // int32 Needed for TensorLists.
 TF_CALL_bfloat16(REGISTER);
-TF_CALL_uint8(REGISTER);
 TF_CALL_GPU_ALL_TYPES(REGISTER);
 #undef REGISTER
 
@@ -201,31 +201,19 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
     const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size,          \
     int split_size, typename TTypes<T, 2>::Matrix* output);
 
-TF_CALL_int32(REGISTER_GPUCONCAT32);  // Needed for TensorLists.
-TF_CALL_int64(REGISTER_GPUCONCAT32);
-TF_CALL_int16(REGISTER_GPUCONCAT32);
-TF_CALL_uint8(REGISTER_GPUCONCAT32);
+TF_CALL_INTEGRAL_TYPES(REGISTER_GPUCONCAT32);  // int32 Needed for TensorLists.
 TF_CALL_bfloat16(REGISTER_GPUCONCAT32);
 TF_CALL_GPU_ALL_TYPES(REGISTER_GPUCONCAT32);
 
-TF_CALL_int32(REGISTER_GPUCONCAT64);  // Needed for TensorLists.
-TF_CALL_int64(REGISTER_GPUCONCAT64);
-TF_CALL_int16(REGISTER_GPUCONCAT64);
-TF_CALL_uint8(REGISTER_GPUCONCAT64);
+TF_CALL_INTEGRAL_TYPES(REGISTER_GPUCONCAT64);  // int32 Needed for TensorLists.
 TF_CALL_bfloat16(REGISTER_GPUCONCAT64);
 TF_CALL_GPU_ALL_TYPES(REGISTER_GPUCONCAT64);
 
-TF_CALL_int32(REGISTER_GPU32);  // Needed for TensorLists.
-TF_CALL_int64(REGISTER_GPU32);
-TF_CALL_int16(REGISTER_GPU32);
-TF_CALL_uint8(REGISTER_GPU32);
+TF_CALL_INTEGRAL_TYPES(REGISTER_GPU32);  // int32 Needed for TensorLists.
 TF_CALL_bfloat16(REGISTER_GPU32);
 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU32);
 
-TF_CALL_int32(REGISTER_GPU64);  // Needed for TensorLists.
-TF_CALL_int64(REGISTER_GPU64);
-TF_CALL_int16(REGISTER_GPU64);
-TF_CALL_uint8(REGISTER_GPU64);
+TF_CALL_INTEGRAL_TYPES(REGISTER_GPU64);  // int32 Needed for TensorLists.
 TF_CALL_bfloat16(REGISTER_GPU64);
 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU64);
 
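For context, the REGISTER* macros in the two ConcatGPUImpl hunks stamp out one explicit template instantiation of the GPU concat implementation per dtype, which is why a single TF_CALL_INTEGRAL_TYPES(...) line can replace the hand-written int32/int64/int16/uint8 lines. The sketch below shows that pattern in isolation; ConcatLikeImpl, the simplified TF_CALL_* bodies, and this REGISTER macro are illustrative stand-ins, not the actual TensorFlow definitions.

// Standalone sketch (not TensorFlow code) of "type list applied to an
// explicit-instantiation macro". ConcatLikeImpl is a toy stand-in for the
// real GPU concat implementation template.
#include <cstdio>
#include <vector>

template <typename T>
void ConcatLikeImpl(const std::vector<std::vector<T>>& inputs,
                    std::vector<T>* output) {
  output->clear();
  for (const auto& in : inputs)
    output->insert(output->end(), in.begin(), in.end());
}

#define TF_CALL_int32(m) m(int)
#define TF_CALL_int64(m) m(long long)
#define TF_CALL_INTEGRAL_TYPES(m) TF_CALL_int64(m) TF_CALL_int32(m)

// Analogue of the per-file REGISTER macro: an explicit template instantiation.
#define REGISTER(T)                                                   \
  template void ConcatLikeImpl<T>(const std::vector<std::vector<T>>&, \
                                  std::vector<T>*);

TF_CALL_INTEGRAL_TYPES(REGISTER)  // one instantiation per dtype in the list
#undef REGISTER

int main() {
  std::vector<int> out;
  ConcatLikeImpl<int>({{1, 2, 3}, {4, 5, 6}}, &out);
  std::printf("%zu elements\n", out.size());  // prints "6 elements"
  return 0;
}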
@@ -216,9 +216,8 @@ REGISTER_CONCAT(qint32);
                             .HostMemory("axis"),     \
                         ConcatV2Op<GPUDevice, type>)
 
+TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU);
 TF_CALL_bfloat16(REGISTER_GPU);
-TF_CALL_uint8(REGISTER_GPU);
-TF_CALL_int64(REGISTER_GPU);
 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU
 
@@ -645,6 +645,17 @@ class ConcatOpTest(test.TestCase):
         inp_tensors_placeholders, -2, output_shape=[2, 3],
         gather_indexes=[2, 0], feed_dict=feed_dict)
 
+  def testConcatDtype(self):
+    for dtype in [dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64]:
+      with test_util.use_gpu():
+        t1 = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtype)
+        t2 = constant_op.constant([[7, 8, 9], [10, 11, 12]], dtype=dtype)
+
+        c = gen_array_ops.concat_v2([t1, t2], 1)
+        self.assertEqual([2, 6], c.get_shape().as_list())
+        output = self.evaluate(c)
+        self.assertAllEqual([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output)
+
   def testConcatAxisType(self):
     for dtype in [dtypes.int32, dtypes.int64]:
       with test_util.use_gpu():