Add uint32 support to these ops so that uint32 can be used for indices in sparse matrices, since int32 has issues on GPU.

uint32 support is added for:

1. VariableV2
2. Assign
3. Identity
4. TopKV2
5. Roll
6. Sub
7. RefSwitch/Switch

PiperOrigin-RevId: 298692513
Change-Id: I1209fb79ee9f8fca450b7a086a9dbc11945030ea
parent a7ede3b86c
commit ddabed4285
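For context, a minimal sketch (not part of the diff, assuming the usual definitions in tensorflow/core/framework/register_types.h) of the macro pattern the hunks below rely on: each TF_CALL_<type> helper expands a registration macro once per dtype, so each one-line addition stamps out a full kernel registration.

// Roughly what the one-line additions below boil down to, using the Identity
// GPU kernel as the example. TF_CALL_uint32(m) simply expands to m(uint32).
#define TF_CALL_uint32(m) m(uint32)

#define REGISTER_GPU_KERNEL(type)                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Identity").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      IdentityOp)

TF_CALL_uint32(REGISTER_GPU_KERNEL);  // == REGISTER_GPU_KERNEL(uint32)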
@@ -109,6 +109,8 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
 REGISTER_GPU_SWITCH(uint64);
 TF_CALL_variant(REGISTER_GPU_SWITCH);
+TF_CALL_uint32(REGISTER_GPU_SWITCH);
+TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);

 #undef REGISTER_CPU_SWITCH
 #undef REGISTER_CPU_REF_SWITCH
@@ -19,7 +19,8 @@ limitations under the License.

 namespace tensorflow {
 namespace functor {
-DEFINE_BINARY6(sub, Eigen::half, float, double, int64, complex64, complex128);
+DEFINE_BINARY7(sub, Eigen::half, float, double, int64, complex64, complex128,
+               uint32);
 }  // namespace functor
 }  // namespace tensorflow

@@ -20,7 +20,8 @@ REGISTER8(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32,
           int64, bfloat16, complex64, complex128);
 #if !defined(__ANDROID_TYPES_SLIM__)
 // Sub op for int8, uint8, int16, uint16
-REGISTER4(BinaryOp, CPU, "Sub", functor::sub, int8, uint8, int16, uint16);
+REGISTER5(BinaryOp, CPU, "Sub", functor::sub, int8, uint8, int16, uint16,
+          uint32);
 #else
 // We only register the first type when we have multi-argument calls in the
 // case where we're trying to reduce executable size, but it turns out that the
@@ -29,8 +30,8 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
 #endif  // __ANDROID_TYPES_SLIM__

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-REGISTER6(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
-          complex64, complex128);
+REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
+          complex64, complex128, uint32);

 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
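The two Sub hunks above work as a pair: the .cu.cc change instantiates the GPU functor for uint32, and the .cc change registers the kernel that calls it. A simplified sketch of the helpers in cwise_ops_common.h (not part of the diff):

// REGISTERn(...) is REGISTER(...) applied once per listed dtype; for the new
// uint32 slot it amounts to:
#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);
REGISTER(BinaryOp, GPU, "Sub", functor::sub, uint32)

// ...while DEFINE_BINARY7 in the .cu.cc file explicitly instantiates the
// matching device-side functor (BinaryFunctor<GPUDevice, functor::sub<uint32>, NDIMS>)
// so the registration above has compiled GPU code to dispatch to.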
@@ -68,6 +68,7 @@ TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_KERNELS);
 TF_CALL_int32(DEFINE_GPU_KERNELS);
 TF_CALL_int64(DEFINE_GPU_KERNELS);
 TF_CALL_int8(DEFINE_GPU_KERNELS);
+TF_CALL_uint32(DEFINE_GPU_KERNELS);
 #undef DEFINE_GPU_KERNELS

 }  // end namespace tensorflow
@@ -97,6 +97,8 @@ typedef Eigen::SyclDevice SYCLDevice;
                           AssignOpT<CPUDevice, type>);

 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
+// uint32 not included in ALL_TYPES
+TF_CALL_uint32(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 // quint16 not included in QUANTIZIED_TYPES
 TF_CALL_quint16(REGISTER_KERNELS);
@@ -112,6 +114,7 @@ TF_CALL_quint16(REGISTER_KERNELS);

 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_int64(REGISTER_GPU_KERNELS);
+TF_CALL_uint32(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNELS
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -122,6 +122,7 @@ REGISTER_SYCL_HOST_KERNEL(bool);

 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
 REGISTER_GPU_KERNEL(Variant);
+TF_CALL_uint32(REGISTER_GPU_KERNEL);

 #undef REGISTER_GPU_KERNEL

@@ -400,6 +400,7 @@ TF_CALL_int64(REGISTER_KERNEL);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
 TF_CALL_complex64(REGISTER_KERNEL);
 TF_CALL_complex128(REGISTER_KERNEL);
+TF_CALL_uint32(REGISTER_KERNEL);

 #undef REGISTER_KERNEL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -96,6 +96,7 @@ TF_CALL_int64(DEFINE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
 TF_CALL_complex64(DEFINE_GPU_SPECS);
 TF_CALL_complex128(DEFINE_GPU_SPECS);
+TF_CALL_uint32(DEFINE_GPU_SPECS)

 #undef DEFINE_GPU_SPECS
 }  // namespace functor
@@ -258,6 +258,7 @@ namespace functor {

 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_uint32(DECLARE_GPU_SPEC);

 #undef DECLARE_GPU_SPEC

@@ -275,7 +276,7 @@ TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
 TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
-
+TF_CALL_uint32(REGISTER_KERNELS)
 #undef REGISTER_KERNELS

 #endif  // end GOOGLE_CUDA
@@ -23,6 +23,8 @@ namespace tensorflow {
 using Eigen::GpuDevice;

+template struct functor::TopKFunctor<GPUDevice, uint16>;
+template struct functor::TopKFunctor<GPUDevice, uint32>;

 }  // namespace tensorflow

 #endif  // GOOGLE_CUDA
@@ -236,6 +236,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_int64(REGISTER_GPU_KERNELS);
+TF_CALL_uint32(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNELS
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -435,8 +435,14 @@ Returns `x` + `y` element-wise.
 )doc");
 #endif  // INTEL_MKL

-REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
-    shape_inference::BroadcastBinaryOpShapeFn);
+REGISTER_OP("Sub")
+    .Input("x: T")
+    .Input("y: T")
+    .Output("z: T")
+    .Attr(
+        "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, "
+        "int64, complex64, complex128, uint32}")
+    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

 REGISTER_OP("_MklSub")
     .BINARY_FEWER()
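For completeness, a sketch of why both this op-definition change and the kernel registrations above are needed (hypothetical op name, not part of the diff): the "T" attr gates which dtypes the graph accepts, while a kernel registered for the same dtype/device supplies the implementation.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// 1) The op definition must list uint32 in "T", or graph construction rejects
//    uint32 inputs before kernel lookup ever happens.
REGISTER_OP("ExampleSub")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, uint32}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// 2) A kernel registered for the same dtype/device (see the cwise hunks above)
//    is what actually runs, e.g.:
//    REGISTER_KERNEL_BUILDER(
//        Name("ExampleSub").Device(DEVICE_GPU).TypeConstraint<uint32>("T"),
//        BinaryOp<GPUDevice, functor::sub<uint32>>);

}  // namespace tensorflow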