Improve GPU dtype coverage of concat

PiperOrigin-RevId: 333461512
Change-Id: Ic0dd1e6f30ef321c4da63bd5da4b92a525a5f237
This commit is contained in:
Gaurav Jain 2020-09-24 00:48:21 -07:00 committed by TensorFlower Gardener
parent 1e8cb572f2
commit 55b546de0e
7 changed files with 29 additions and 39 deletions

View File

@@ -153,24 +153,25 @@ limitations under the License.
#endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines
// Defines for sets of types.
#define TF_CALL_INTEGRAL_TYPES(m) \
TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_int32(m) \
TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_INTEGRAL_TYPES_NO_INT32(m) \
TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \
TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_INTEGRAL_TYPES(m) \
TF_CALL_INTEGRAL_TYPES_NO_INT32(m) TF_CALL_int32(m)
#define TF_CALL_FLOAT_TYPES(m) \
TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
#define TF_CALL_REAL_NUMBER_TYPES(m) \
TF_CALL_INTEGRAL_TYPES(m) \
TF_CALL_FLOAT_TYPES(m)
TF_CALL_INTEGRAL_TYPES(m) TF_CALL_FLOAT_TYPES(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \
TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \
TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \
TF_CALL_INTEGRAL_TYPES_NO_INT32(m)
#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)

View File

@@ -64,11 +64,8 @@ void ConcatGPU(
inputs_flat, \
Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
TF_CALL_int32(REGISTER); // Needed for TensorLists.
TF_CALL_int64(REGISTER);
TF_CALL_int16(REGISTER);
TF_CALL_INTEGRAL_TYPES(REGISTER); // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER);
TF_CALL_uint8(REGISTER);
TF_CALL_GPU_ALL_TYPES(REGISTER);
#undef REGISTER
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -98,11 +98,8 @@ void ConcatGPU(
inputs_flat, \
Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
TF_CALL_int32(REGISTER); // Needed for TensorLists.
TF_CALL_int64(REGISTER);
TF_CALL_int16(REGISTER);
TF_CALL_INTEGRAL_TYPES(REGISTER); // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER);
TF_CALL_uint8(REGISTER);
TF_CALL_GPU_ALL_TYPES(REGISTER);
#undef REGISTER

View File

@@ -66,11 +66,8 @@ void ConcatGPUImpl(const Eigen::GpuDevice& d,
const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
int split_size, typename TTypes<T, 2>::Matrix* output);
TF_CALL_int32(REGISTER); // Needed for TensorLists.
TF_CALL_int64(REGISTER);
TF_CALL_int16(REGISTER);
TF_CALL_INTEGRAL_TYPES(REGISTER); // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER);
TF_CALL_uint8(REGISTER);
TF_CALL_GPU_ALL_TYPES(REGISTER);
#undef REGISTER

View File

@@ -201,31 +201,19 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
int split_size, typename TTypes<T, 2>::Matrix* output);
TF_CALL_int32(REGISTER_GPUCONCAT32); // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPUCONCAT32);
TF_CALL_int16(REGISTER_GPUCONCAT32);
TF_CALL_uint8(REGISTER_GPUCONCAT32);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPUCONCAT32); // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER_GPUCONCAT32);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPUCONCAT32);
TF_CALL_int32(REGISTER_GPUCONCAT64); // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPUCONCAT64);
TF_CALL_int16(REGISTER_GPUCONCAT64);
TF_CALL_uint8(REGISTER_GPUCONCAT64);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPUCONCAT64); // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER_GPUCONCAT64);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPUCONCAT64);
TF_CALL_int32(REGISTER_GPU32); // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPU32);
TF_CALL_int16(REGISTER_GPU32);
TF_CALL_uint8(REGISTER_GPU32);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPU32); // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER_GPU32);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU32);
TF_CALL_int32(REGISTER_GPU64); // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPU64);
TF_CALL_int16(REGISTER_GPU64);
TF_CALL_uint8(REGISTER_GPU64);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPU64); // int32 Needed for TensorLists.
TF_CALL_bfloat16(REGISTER_GPU64);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU64);

View File

@@ -216,9 +216,8 @@ REGISTER_CONCAT(qint32);
.HostMemory("axis"), \
ConcatV2Op<GPUDevice, type>)
TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_uint8(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
#undef REGISTER_GPU

View File

@@ -645,6 +645,17 @@ class ConcatOpTest(test.TestCase):
inp_tensors_placeholders, -2, output_shape=[2, 3],
gather_indexes=[2, 0], feed_dict=feed_dict)
def testConcatDtype(self):
for dtype in [dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64]:
with test_util.use_gpu():
t1 = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtype)
t2 = constant_op.constant([[7, 8, 9], [10, 11, 12]], dtype=dtype)
c = gen_array_ops.concat_v2([t1, t2], 1)
self.assertEqual([2, 6], c.get_shape().as_list())
output = self.evaluate(c)
self.assertAllEqual([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output)
def testConcatAxisType(self):
for dtype in [dtypes.int32, dtypes.int64]:
with test_util.use_gpu():