Enable FP16 support for TopKOp for XLA_CPU and XLA_GPU.

PiperOrigin-RevId: 341129873
Change-Id: I5ceb5d4de41dbfaa6466c1a6ab63fd76f70e86c6
Author: A. Unique TensorFlower
Date: 2020-11-06 15:28:50 -08:00
Committed by: TensorFlower Gardener
Parent: 1fed8bfa46
Commit: fcb5476957
2 changed files with 21 additions and 18 deletions
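
The change lets a TopK computation compiled with XLA on CPU or GPU consume float16 inputs directly instead of rejecting the dtype at kernel-registration time. A minimal usage sketch from the Python side (an illustration, not part of the commit; it assumes a TF 2.x build where tf.function accepts jit_compile, which older releases spelled experimental_compile):

import numpy as np
import tensorflow as tf

# Force XLA compilation of the TopK computation.
@tf.function(jit_compile=True)
def top3(x):
  return tf.math.top_k(x, k=3)

x = tf.constant(np.random.permutation(101), dtype=tf.float16)
values, indices = top3(x)
print(values.dtype)     # float16
print(indices.numpy())  # indices of the three largest elements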
Changed paths: tensorflow/compiler/tests, tensorflow/compiler/tf2xla/kernels

tensorflow/compiler/tests:

@@ -50,10 +50,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testSort(self):
     supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
+        [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
       # TPU implementation is not supported for double precision
-      if dtype == np.float64 and self.device == "TPU":
+      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
         continue
       x = np.arange(101, dtype=dtype)
       np.random.shuffle(x)
@@ -62,16 +62,18 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testKeyValueSort(self):
     supported_key_types = set([
-        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
-        np.uint32
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32
     ])
     supported_value_types = set([
-        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
-        np.uint32, dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32, dtypes.int64.as_numpy_dtype,
+        dtypes.uint64.as_numpy_dtype
     ])
     for key_type in supported_key_types.intersection(self.numeric_types):
       for value_type in supported_value_types.intersection(self.numeric_types):
-        if key_type == np.float64 or value_type == np.float64:
+        if key_type == np.float64 or value_type == np.float64 or \
+            key_type == np.float16 or value_type == np.float16:
           # TPU implementation is not supported for double precision
           if self.device == "TPU":
             continue
@@ -87,8 +89,8 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testTopK(self):
     supported_types = set([
-        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
-        np.uint32
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32
     ])
     for dtype in supported_types.intersection(self.numeric_types):
       if dtype == np.float64 and self.device == "TPU":
@@ -96,7 +98,7 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
-      if dtype == dtypes.bfloat16.as_numpy_dtype:
+      if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16):
         array_size = 20
         k_options = [0, 1, 2, 10, 20]
       else:
@@ -116,6 +118,7 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   @parameterized.named_parameters(
       ("HalfPrecision", dtypes.bfloat16.as_numpy_dtype),
+      ("HalfFloatPrecision", np.float16),
       ("SinglePrecision", np.float32),
       ("DoublePrecision", np.float64),
       ("Int", np.int32),
@@ -124,12 +127,12 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testTopK2D(self, dtype):
     if dtype in self.numeric_types:
       # TPU implementation is not supported for double precision
-      if dtype == np.float64 and self.device == "TPU":
+      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
         return
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
-      if dtype == dtypes.bfloat16.as_numpy_dtype:
+      if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16):
         array_size = 10
         k_options = [0, 1, 2, 10]
       else:
@@ -153,10 +156,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
     supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
+        [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
       # TPU implementation is not supported for double precision
-      if dtype == np.float64 and self.device == "TPU":
+      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
         continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
@@ -171,10 +174,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testTopKInfinities(self):
     """Tests that positive and negative infinity sort correctly."""
     supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
+        [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
       # TPU implementation is not supported for double precision
-      if dtype == np.float64 and self.device == "TPU":
+      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
         continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
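
The smaller array sizes chosen for bfloat16 and float16 in the hunks above follow the comment in the test: low-precision floats cannot represent every integer in a large range, so distinct inputs can collapse to the same value after conversion and the expected index array stops being unique. A quick NumPy illustration of the effect (illustrative only, not part of the test):

import numpy as np

# float16 carries an 11-bit significand, so consecutive integers above
# 2048 round to the same value; bfloat16, with its 8-bit significand,
# collapses even sooner. Converting a large integer range therefore
# produces duplicates.
x = np.arange(4096, dtype=np.float32)
as_f16 = x.astype(np.float16)
print(np.unique(as_f16).size)                # < 4096: duplicates appeared
print(np.float16(2049) == np.float16(2048))  # True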

tensorflow/compiler/tf2xla/kernels:

@@ -58,8 +58,8 @@ class TopKOp : public XlaOpKernel {
 };
 REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstantInput("k").TypeConstraint(
-                    "T",
-                    {DT_UINT32, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}),
+                    "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_HALF, DT_DOUBLE,
+                          DT_BFLOAT16}),
                 TopKOp);
 }  // namespace
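
With DT_HALF added to the TypeConstraint, the XLA bridge can lower TopKV2 for half-precision inputs on the CPU and GPU backends. One way to eyeball the lowering from Python is to dump the HLO of a jitted function, reusing the pattern from the sketch near the top of this page (a sketch under the assumption that the build exposes experimental_get_compiler_ir; stage names vary across TF versions):

import tensorflow as tf

@tf.function(jit_compile=True)
def top5(x):
  return tf.math.top_k(x, k=5)

x = tf.random.uniform([64], dtype=tf.float16)
# Dumps the HLO module for the f16 input; the top-k/sort computation
# appears with f16 operands once the dtype is accepted by the TopKV2
# kernel registration.
print(top5.experimental_get_compiler_ir(x)(stage="hlo"))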