[TF:XLA] Whitelist uint16/int16 for CPU/GPU
These are rather neglected types in TF, so most ops don't support them (including add). XLA mostly supports them just fine; this change adds a few missing cases when handling constants. PiperOrigin-RevId: 265926474
This commit is contained in:
parent
8a21a2236f
commit
b1ec3b3eae
tensorflow/compiler
jit
tf2xla
xla
@ -98,10 +98,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
|
||||
|
||||
// Kernel registrations
|
||||
|
||||
constexpr std::array<DataType, 14> kAllXlaCpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
|
||||
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
|
||||
DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
|
||||
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
|
||||
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
|
||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
|
||||
|
@ -147,10 +147,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
|
||||
|
||||
// Kernel registrations
|
||||
|
||||
constexpr std::array<DataType, 14> kAllXlaGpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
|
||||
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
|
||||
DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
|
||||
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
|
||||
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
|
||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
|
||||
|
@ -69,6 +69,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
|
||||
case xla::U8:
|
||||
literal = xla::LiteralUtil::CreateR0<uint8>(value);
|
||||
break;
|
||||
case xla::U16:
|
||||
literal = xla::LiteralUtil::CreateR0<uint16>(value);
|
||||
break;
|
||||
case xla::U32:
|
||||
literal = xla::LiteralUtil::CreateR0<uint32>(value);
|
||||
break;
|
||||
@ -78,6 +81,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
|
||||
case xla::S8:
|
||||
literal = xla::LiteralUtil::CreateR0<int8>(value);
|
||||
break;
|
||||
case xla::S16:
|
||||
literal = xla::LiteralUtil::CreateR0<int16>(value);
|
||||
break;
|
||||
case xla::S32:
|
||||
literal = xla::LiteralUtil::CreateR0<int32>(value);
|
||||
break;
|
||||
@ -98,9 +104,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
|
||||
break;
|
||||
case xla::PRED:
|
||||
LOG(FATAL) << "pred element type is not integral";
|
||||
case xla::S16:
|
||||
case xla::U16:
|
||||
LOG(FATAL) << "u16/s16 literals not yet implemented";
|
||||
case xla::BF16:
|
||||
literal =
|
||||
xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
|
||||
|
@ -47,19 +47,20 @@ extern const char* const DEVICE_XLA_GPU;
|
||||
|
||||
constexpr std::array<DataType, 4> kFloatTypes = {
|
||||
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 12> kNumericTypes = {
|
||||
{DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
|
||||
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 14> kNumericTypes = {
|
||||
{DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32,
|
||||
DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
|
||||
DT_BFLOAT16}};
|
||||
|
||||
constexpr std::array<DataType, 16> kCpuAllTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
|
||||
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
|
||||
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 18> kCpuAllTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
|
||||
DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
|
||||
DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
|
||||
constexpr std::array<DataType, 16> kGpuAllTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
|
||||
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
|
||||
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 18> kGpuAllTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
|
||||
DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
|
||||
DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
|
||||
// Class that manages registrations of operators and devices for the XLA JIT.
|
||||
// Not thread-safe.
|
||||
|
@ -62,12 +62,16 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
|
||||
return ConstantR0<complex128>(builder, static_cast<complex128>(value));
|
||||
case U8:
|
||||
return ConstantR0<uint8>(builder, static_cast<uint8>(value));
|
||||
case U16:
|
||||
return ConstantR0<uint16>(builder, static_cast<uint16>(value));
|
||||
case U32:
|
||||
return ConstantR0<uint32>(builder, static_cast<uint32>(value));
|
||||
case U64:
|
||||
return ConstantR0<uint64>(builder, static_cast<uint64>(value));
|
||||
case S8:
|
||||
return ConstantR0<int8>(builder, static_cast<int8>(value));
|
||||
case S16:
|
||||
return ConstantR0<int16>(builder, static_cast<int16>(value));
|
||||
case S32:
|
||||
return ConstantR0<int32>(builder, static_cast<int32>(value));
|
||||
case S64:
|
||||
|
@ -147,12 +147,16 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
switch (primitive_type) {
|
||||
case U8:
|
||||
return LiteralUtil::CreateR0<uint8>(1);
|
||||
case U16:
|
||||
return LiteralUtil::CreateR0<uint16>(1);
|
||||
case U32:
|
||||
return LiteralUtil::CreateR0<uint32>(1);
|
||||
case U64:
|
||||
return LiteralUtil::CreateR0<uint64>(1);
|
||||
case S8:
|
||||
return LiteralUtil::CreateR0<int8>(1);
|
||||
case S16:
|
||||
return LiteralUtil::CreateR0<int16>(1);
|
||||
case S32:
|
||||
return LiteralUtil::CreateR0<int32>(1);
|
||||
case S64:
|
||||
@ -171,9 +175,6 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
return LiteralUtil::CreateR0<complex128>(1);
|
||||
case PRED:
|
||||
return LiteralUtil::CreateR0<bool>(true);
|
||||
case S16:
|
||||
case U16:
|
||||
LOG(FATAL) << "u16/s16 literals not yet implemented";
|
||||
case TUPLE:
|
||||
LOG(FATAL) << "tuple element type cannot take on value of 1";
|
||||
case OPAQUE_TYPE:
|
||||
@ -187,12 +188,16 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
switch (primitive_type) {
|
||||
case U8:
|
||||
return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
|
||||
case U16:
|
||||
return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::min());
|
||||
case U32:
|
||||
return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
|
||||
case U64:
|
||||
return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
|
||||
case S8:
|
||||
return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
|
||||
case S16:
|
||||
return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::min());
|
||||
case S32:
|
||||
return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
|
||||
case S64:
|
||||
@ -209,9 +214,6 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
LOG(FATAL) << "C128 element type has no minimum value";
|
||||
case PRED:
|
||||
return LiteralUtil::CreateR0<bool>(false);
|
||||
case S16:
|
||||
case U16:
|
||||
LOG(FATAL) << "u16/s16 literals not yet implemented";
|
||||
case F16:
|
||||
return LiteralUtil::CreateR0<half>(
|
||||
static_cast<half>(-std::numeric_limits<float>::infinity()));
|
||||
@ -231,12 +233,16 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
switch (primitive_type) {
|
||||
case U8:
|
||||
return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
|
||||
case U16:
|
||||
return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::max());
|
||||
case U32:
|
||||
return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
|
||||
case U64:
|
||||
return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
|
||||
case S8:
|
||||
return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
|
||||
case S16:
|
||||
return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::max());
|
||||
case S32:
|
||||
return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
|
||||
case S64:
|
||||
@ -249,9 +255,6 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
std::numeric_limits<double>::infinity());
|
||||
case PRED:
|
||||
return LiteralUtil::CreateR0<bool>(true);
|
||||
case S16:
|
||||
case U16:
|
||||
LOG(FATAL) << "u16/s16 literals not yet implemented";
|
||||
case F16:
|
||||
return LiteralUtil::CreateR0<half>(
|
||||
static_cast<half>(std::numeric_limits<float>::infinity()));
|
||||
|
Loading…
Reference in New Issue
Block a user