Add uint32 support to the ops listed below so that uint32 can be used for indices in sparse matrices, since int32 has issues on GPU. This adds uint32 support for:
1. VariableV2
2. Assign
3. Identity
4. TopKV2
5. Roll
6. Sub
7. RefSwitch/Switch

PiperOrigin-RevId: 298692513
Change-Id: I1209fb79ee9f8fca450b7a086a9dbc11945030ea
This commit is contained in:
parent
a7ede3b86c
commit
ddabed4285
@ -109,6 +109,8 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
|
|||||||
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
|
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
|
||||||
REGISTER_GPU_SWITCH(uint64);
|
REGISTER_GPU_SWITCH(uint64);
|
||||||
TF_CALL_variant(REGISTER_GPU_SWITCH);
|
TF_CALL_variant(REGISTER_GPU_SWITCH);
|
||||||
|
TF_CALL_uint32(REGISTER_GPU_SWITCH);
|
||||||
|
TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
|
||||||
|
|
||||||
#undef REGISTER_CPU_SWITCH
|
#undef REGISTER_CPU_SWITCH
|
||||||
#undef REGISTER_CPU_REF_SWITCH
|
#undef REGISTER_CPU_REF_SWITCH
|
||||||
|
@ -19,7 +19,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace functor {
|
namespace functor {
|
||||||
DEFINE_BINARY6(sub, Eigen::half, float, double, int64, complex64, complex128);
|
DEFINE_BINARY7(sub, Eigen::half, float, double, int64, complex64, complex128,
|
||||||
|
uint32);
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -20,7 +20,8 @@ REGISTER8(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32,
|
|||||||
int64, bfloat16, complex64, complex128);
|
int64, bfloat16, complex64, complex128);
|
||||||
#if !defined(__ANDROID_TYPES_SLIM__)
|
#if !defined(__ANDROID_TYPES_SLIM__)
|
||||||
// Sub op for int8, uint8, int16, uint16
|
// Sub op for int8, uint8, int16, uint16
|
||||||
REGISTER4(BinaryOp, CPU, "Sub", functor::sub, int8, uint8, int16, uint16);
|
REGISTER5(BinaryOp, CPU, "Sub", functor::sub, int8, uint8, int16, uint16,
|
||||||
|
uint32);
|
||||||
#else
|
#else
|
||||||
// We only register the first type when we have multi-argument calls in the
|
// We only register the first type when we have multi-argument calls in the
|
||||||
// case where we're trying to reduce executable size, but it turns out that the
|
// case where we're trying to reduce executable size, but it turns out that the
|
||||||
@ -29,8 +30,8 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
|
|||||||
#endif // __ANDROID_TYPES_SLIM__
|
#endif // __ANDROID_TYPES_SLIM__
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
REGISTER6(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
|
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
|
||||||
complex64, complex128);
|
complex64, complex128, uint32);
|
||||||
|
|
||||||
// A special GPU kernel for int32.
|
// A special GPU kernel for int32.
|
||||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||||
|
@ -68,6 +68,7 @@ TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_KERNELS);
|
|||||||
TF_CALL_int32(DEFINE_GPU_KERNELS);
|
TF_CALL_int32(DEFINE_GPU_KERNELS);
|
||||||
TF_CALL_int64(DEFINE_GPU_KERNELS);
|
TF_CALL_int64(DEFINE_GPU_KERNELS);
|
||||||
TF_CALL_int8(DEFINE_GPU_KERNELS);
|
TF_CALL_int8(DEFINE_GPU_KERNELS);
|
||||||
|
TF_CALL_uint32(DEFINE_GPU_KERNELS);
|
||||||
#undef DEFINE_GPU_KERNELS
|
#undef DEFINE_GPU_KERNELS
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -97,6 +97,8 @@ typedef Eigen::SyclDevice SYCLDevice;
|
|||||||
AssignOpT<CPUDevice, type>);
|
AssignOpT<CPUDevice, type>);
|
||||||
|
|
||||||
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
|
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
|
||||||
|
// uint32 not included in ALL_TYPES
|
||||||
|
TF_CALL_uint32(REGISTER_KERNELS);
|
||||||
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
|
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
|
||||||
// quint16 not included in QUANTIZED_TYPES
|
// quint16 not included in QUANTIZED_TYPES
|
||||||
TF_CALL_quint16(REGISTER_KERNELS);
|
TF_CALL_quint16(REGISTER_KERNELS);
|
||||||
@ -112,6 +114,7 @@ TF_CALL_quint16(REGISTER_KERNELS);
|
|||||||
|
|
||||||
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
|
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
|
||||||
TF_CALL_int64(REGISTER_GPU_KERNELS);
|
TF_CALL_int64(REGISTER_GPU_KERNELS);
|
||||||
|
TF_CALL_uint32(REGISTER_GPU_KERNELS);
|
||||||
#undef REGISTER_GPU_KERNELS
|
#undef REGISTER_GPU_KERNELS
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
@ -122,6 +122,7 @@ REGISTER_SYCL_HOST_KERNEL(bool);
|
|||||||
|
|
||||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
|
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
|
||||||
REGISTER_GPU_KERNEL(Variant);
|
REGISTER_GPU_KERNEL(Variant);
|
||||||
|
TF_CALL_uint32(REGISTER_GPU_KERNEL);
|
||||||
|
|
||||||
#undef REGISTER_GPU_KERNEL
|
#undef REGISTER_GPU_KERNEL
|
||||||
|
|
||||||
|
@ -400,6 +400,7 @@ TF_CALL_int64(REGISTER_KERNEL);
|
|||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
|
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
|
||||||
TF_CALL_complex64(REGISTER_KERNEL);
|
TF_CALL_complex64(REGISTER_KERNEL);
|
||||||
TF_CALL_complex128(REGISTER_KERNEL);
|
TF_CALL_complex128(REGISTER_KERNEL);
|
||||||
|
TF_CALL_uint32(REGISTER_KERNEL);
|
||||||
|
|
||||||
#undef REGISTER_KERNEL
|
#undef REGISTER_KERNEL
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
@ -96,6 +96,7 @@ TF_CALL_int64(DEFINE_GPU_SPECS);
|
|||||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
|
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
|
||||||
TF_CALL_complex64(DEFINE_GPU_SPECS);
|
TF_CALL_complex64(DEFINE_GPU_SPECS);
|
||||||
TF_CALL_complex128(DEFINE_GPU_SPECS);
|
TF_CALL_complex128(DEFINE_GPU_SPECS);
|
||||||
|
TF_CALL_uint32(DEFINE_GPU_SPECS)
|
||||||
|
|
||||||
#undef DEFINE_GPU_SPECS
|
#undef DEFINE_GPU_SPECS
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
@ -258,6 +258,7 @@ namespace functor {
|
|||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
||||||
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
|
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
|
||||||
|
TF_CALL_uint32(DECLARE_GPU_SPEC);
|
||||||
|
|
||||||
#undef DECLARE_GPU_SPEC
|
#undef DECLARE_GPU_SPEC
|
||||||
|
|
||||||
@ -275,7 +276,7 @@ TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
|
|||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
|
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
|
||||||
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
|
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
|
||||||
|
TF_CALL_uint32(REGISTER_KERNELS)
|
||||||
#undef REGISTER_KERNELS
|
#undef REGISTER_KERNELS
|
||||||
|
|
||||||
#endif // end GOOGLE_CUDA
|
#endif // end GOOGLE_CUDA
|
||||||
|
@ -23,6 +23,8 @@ namespace tensorflow {
|
|||||||
using Eigen::GpuDevice;
|
using Eigen::GpuDevice;
|
||||||
|
|
||||||
template struct functor::TopKFunctor<GPUDevice, uint16>;
|
template struct functor::TopKFunctor<GPUDevice, uint16>;
|
||||||
|
template struct functor::TopKFunctor<GPUDevice, uint32>;
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
@ -236,6 +236,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
|
|||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
||||||
TF_CALL_int64(REGISTER_GPU_KERNELS);
|
TF_CALL_int64(REGISTER_GPU_KERNELS);
|
||||||
|
TF_CALL_uint32(REGISTER_GPU_KERNELS);
|
||||||
#undef REGISTER_GPU_KERNELS
|
#undef REGISTER_GPU_KERNELS
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
@ -435,8 +435,14 @@ Returns `x` + `y` element-wise.
|
|||||||
)doc");
|
)doc");
|
||||||
#endif // INTEL_MKL
|
#endif // INTEL_MKL
|
||||||
|
|
||||||
REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
|
REGISTER_OP("Sub")
|
||||||
shape_inference::BroadcastBinaryOpShapeFn);
|
.Input("x: T")
|
||||||
|
.Input("y: T")
|
||||||
|
.Output("z: T")
|
||||||
|
.Attr(
|
||||||
|
"T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, "
|
||||||
|
"int64, complex64, complex128, uint32}")
|
||||||
|
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||||
|
|
||||||
REGISTER_OP("_MklSub")
|
REGISTER_OP("_MklSub")
|
||||||
.BINARY_FEWER()
|
.BINARY_FEWER()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user