Add uint32 support to these ops, enabling the use of uint32 indices in sparse matrices (int32 indices have issues on GPU). Ops updated:

1. VariableV2
2. Assign
3. Identity
4. TopKV2
5. Roll
6. Sub
7. RefSwitch/Switch

PiperOrigin-RevId: 298692513
Change-Id: I1209fb79ee9f8fca450b7a086a9dbc11945030ea
This commit is contained in:
A. Unique TensorFlower 2020-03-03 14:17:52 -08:00 committed by TensorFlower Gardener
parent a7ede3b86c
commit ddabed4285
12 changed files with 28 additions and 7 deletions

View File

@ -109,6 +109,8 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_SWITCH(uint64);
TF_CALL_variant(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH

View File

@ -19,7 +19,8 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_BINARY6(sub, Eigen::half, float, double, int64, complex64, complex128);
DEFINE_BINARY7(sub, Eigen::half, float, double, int64, complex64, complex128,
uint32);
} // namespace functor
} // namespace tensorflow

View File

@ -20,7 +20,8 @@ REGISTER8(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32,
int64, bfloat16, complex64, complex128);
#if !defined(__ANDROID_TYPES_SLIM__)
// Sub op for int8, uint8, int16, uint16, uint32
REGISTER4(BinaryOp, CPU, "Sub", functor::sub, int8, uint8, int16, uint16);
REGISTER5(BinaryOp, CPU, "Sub", functor::sub, int8, uint8, int16, uint16,
uint32);
#else
// We only register the first type when we have multi-argument calls in the
// case where we're trying to reduce executable size, but it turns out that the
@ -29,8 +30,8 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
#endif // __ANDROID_TYPES_SLIM__
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER6(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
complex64, complex128);
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
complex64, complex128, uint32);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel

View File

@ -68,6 +68,7 @@ TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_int32(DEFINE_GPU_KERNELS);
TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_int8(DEFINE_GPU_KERNELS);
TF_CALL_uint32(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
} // end namespace tensorflow

View File

@ -97,6 +97,8 @@ typedef Eigen::SyclDevice SYCLDevice;
AssignOpT<CPUDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
// uint32 not included in ALL_TYPES
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
// quint16 not included in QUANTIZED_TYPES
TF_CALL_quint16(REGISTER_KERNELS);
@ -112,6 +114,7 @@ TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -122,6 +122,7 @@ REGISTER_SYCL_HOST_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(Variant);
TF_CALL_uint32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

View File

@ -400,6 +400,7 @@ TF_CALL_int64(REGISTER_KERNEL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
TF_CALL_complex64(REGISTER_KERNEL);
TF_CALL_complex128(REGISTER_KERNEL);
TF_CALL_uint32(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -96,6 +96,7 @@ TF_CALL_int64(DEFINE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
TF_CALL_complex64(DEFINE_GPU_SPECS);
TF_CALL_complex128(DEFINE_GPU_SPECS);
TF_CALL_uint32(DEFINE_GPU_SPECS)
#undef DEFINE_GPU_SPECS
} // namespace functor

View File

@ -258,6 +258,7 @@ namespace functor {
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
TF_CALL_uint32(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
@ -275,7 +276,7 @@ TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS)
#undef REGISTER_KERNELS
#endif // end GOOGLE_CUDA

View File

@ -23,6 +23,8 @@ namespace tensorflow {
using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, uint16>;
template struct functor::TopKFunctor<GPUDevice, uint32>;
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -236,6 +236,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -435,8 +435,14 @@ Returns `x` + `y` element-wise.
)doc");
#endif // INTEL_MKL
REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("Sub")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr(
"T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, "
"int64, complex64, complex128, uint32}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("_MklSub")
.BINARY_FEWER()