diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 2a301ad3b6e..ea4e72b6328 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -50,10 +50,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase): def testSort(self): supported_types = set( - [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64]) + [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64]) for dtype in supported_types.intersection(self.numeric_types): # TPU implementation is not supported for double precision - if dtype == np.float64 and self.device == "TPU": + if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU": continue x = np.arange(101, dtype=dtype) np.random.shuffle(x) @@ -62,16 +62,18 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase): def testKeyValueSort(self): supported_key_types = set([ - dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32, - np.uint32 + dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64, + np.int32, np.uint32 ]) supported_value_types = set([ - dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32, - np.uint32, dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype + dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64, + np.int32, np.uint32, dtypes.int64.as_numpy_dtype, + dtypes.uint64.as_numpy_dtype ]) for key_type in supported_key_types.intersection(self.numeric_types): for value_type in supported_value_types.intersection(self.numeric_types): - if key_type == np.float64 or value_type == np.float64: + if key_type == np.float64 or value_type == np.float64 or \ + key_type == np.float16 or value_type == np.float16: # TPU implementation is not supported for double precision if self.device == "TPU": continue @@ -87,8 +89,8 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase): def testTopK(self): supported_types = set([ - dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32, - np.uint32 + dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64, + np.int32, np.uint32 ]) for dtype in supported_types.intersection(self.numeric_types): if dtype == np.float64 and self.device == "TPU": @@ -96,7 +98,7 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase): # Use small input size for bfloat16. Otherwise, we'll get duplicate values # after conversion to bfloat16, so the possible resulting index array is # no longer unique. - if dtype == dtypes.bfloat16.as_numpy_dtype: + if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16): array_size = 20 k_options = [0, 1, 2, 10, 20] else: @@ -116,6 +118,7 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase): @parameterized.named_parameters( ("HalfPrecision", dtypes.bfloat16.as_numpy_dtype), + ("HalfFloatPrecision", np.float16), ("SinglePrecision", np.float32), ("DoublePrecision", np.float64), ("Int", np.int32), @@ -124,12 +127,12 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase): def testTopK2D(self, dtype): if dtype in self.numeric_types: # TPU implementation is not supported for double precision - if dtype == np.float64 and self.device == "TPU": + if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU": return # Use small input size for bfloat16. Otherwise, we'll get duplicate values # after conversion to bfloat16, so the possible resulting index array is # no longer unique. - if dtype == dtypes.bfloat16.as_numpy_dtype: + if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16): array_size = 10 k_options = [0, 1, 2, 10] else: @@ -153,10 +156,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" supported_types = set( - [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64]) + [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64]) for dtype in supported_types.intersection(self.numeric_types): # TPU implementation is not supported for double precision - if dtype == np.float64 and self.device == "TPU": + if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU": continue with self.session() as sess: p = array_ops.placeholder(dtype) @@ -171,10 +174,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase): def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" supported_types = set( - [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64]) + [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64]) for dtype in supported_types.intersection(self.numeric_types): # TPU implementation is not supported for double precision - if dtype == np.float64 and self.device == "TPU": + if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU": continue with self.session() as sess: p = array_ops.placeholder(dtype) diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 488407eaa12..06c9038a145 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -58,8 +58,8 @@ class TopKOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstantInput("k").TypeConstraint( - "T", - {DT_UINT32, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}), + "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_HALF, DT_DOUBLE, + DT_BFLOAT16}), TopKOp); } // namespace