[TF:XLA] Whitelist uint16/int16 for CPU/GPU

This is a rather neglected type in TF, so most ops don't support it (including
add). XLA mostly supports it just fine, this change adds a few missing cases
when handling constants.

PiperOrigin-RevId: 265926474
This commit is contained in:
Benjamin Kramer 2019-08-28 09:28:54 -07:00 committed by TensorFlower Gardener
parent 8a21a2236f
commit b1ec3b3eae
6 changed files with 42 additions and 31 deletions

View File

@ -98,10 +98,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
constexpr std::array<DataType, 14> kAllXlaCpuTypes = {
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
DT_BFLOAT16}};
constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);

View File

@ -147,10 +147,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
constexpr std::array<DataType, 14> kAllXlaGpuTypes = {
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
DT_BFLOAT16}};
constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);

View File

@ -69,6 +69,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::U8:
literal = xla::LiteralUtil::CreateR0<uint8>(value);
break;
case xla::U16:
literal = xla::LiteralUtil::CreateR0<uint16>(value);
break;
case xla::U32:
literal = xla::LiteralUtil::CreateR0<uint32>(value);
break;
@ -78,6 +81,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::S8:
literal = xla::LiteralUtil::CreateR0<int8>(value);
break;
case xla::S16:
literal = xla::LiteralUtil::CreateR0<int16>(value);
break;
case xla::S32:
literal = xla::LiteralUtil::CreateR0<int32>(value);
break;
@ -98,9 +104,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
case xla::S16:
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
literal =
xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));

View File

@ -47,19 +47,20 @@ extern const char* const DEVICE_XLA_GPU;
constexpr std::array<DataType, 4> kFloatTypes = {
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
constexpr std::array<DataType, 12> kNumericTypes = {
{DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}};
constexpr std::array<DataType, 14> kNumericTypes = {
{DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32,
DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
DT_BFLOAT16}};
constexpr std::array<DataType, 16> kCpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
constexpr std::array<DataType, 18> kCpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
constexpr std::array<DataType, 16> kGpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
constexpr std::array<DataType, 18> kGpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.

View File

@ -62,12 +62,16 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
return ConstantR0<complex128>(builder, static_cast<complex128>(value));
case U8:
return ConstantR0<uint8>(builder, static_cast<uint8>(value));
case U16:
return ConstantR0<uint16>(builder, static_cast<uint16>(value));
case U32:
return ConstantR0<uint32>(builder, static_cast<uint32>(value));
case U64:
return ConstantR0<uint64>(builder, static_cast<uint64>(value));
case S8:
return ConstantR0<int8>(builder, static_cast<int8>(value));
case S16:
return ConstantR0<int16>(builder, static_cast<int16>(value));
case S32:
return ConstantR0<int32>(builder, static_cast<int32>(value));
case S64:

View File

@ -147,12 +147,16 @@ Literal ConvertType(LiteralSlice literal) {
switch (primitive_type) {
case U8:
return LiteralUtil::CreateR0<uint8>(1);
case U16:
return LiteralUtil::CreateR0<uint16>(1);
case U32:
return LiteralUtil::CreateR0<uint32>(1);
case U64:
return LiteralUtil::CreateR0<uint64>(1);
case S8:
return LiteralUtil::CreateR0<int8>(1);
case S16:
return LiteralUtil::CreateR0<int16>(1);
case S32:
return LiteralUtil::CreateR0<int32>(1);
case S64:
@ -171,9 +175,6 @@ Literal ConvertType(LiteralSlice literal) {
return LiteralUtil::CreateR0<complex128>(1);
case PRED:
return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case TUPLE:
LOG(FATAL) << "tuple element type cannot take on value of 1";
case OPAQUE_TYPE:
@ -187,12 +188,16 @@ Literal ConvertType(LiteralSlice literal) {
switch (primitive_type) {
case U8:
return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
case U16:
return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::min());
case U32:
return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
case U64:
return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
case S8:
return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
case S16:
return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::min());
case S32:
return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
case S64:
@ -209,9 +214,6 @@ Literal ConvertType(LiteralSlice literal) {
LOG(FATAL) << "C128 element type has no minimum value";
case PRED:
return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
return LiteralUtil::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity()));
@ -231,12 +233,16 @@ Literal ConvertType(LiteralSlice literal) {
switch (primitive_type) {
case U8:
return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
case U16:
return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::max());
case U32:
return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
case U64:
return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
case S8:
return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
case S16:
return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::max());
case S32:
return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
case S64:
@ -249,9 +255,6 @@ Literal ConvertType(LiteralSlice literal) {
std::numeric_limits<double>::infinity());
case PRED:
return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
return LiteralUtil::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity()));