Enable FP64 support for tf.sort for XLA_CPU and XLA_GPU.

PiperOrigin-RevId: 338838084
Change-Id: Iacab47e06f83b5b75d007d80a725063b92c0544d
Author:    A. Unique TensorFlower
Date:      2020-10-24 09:23:51 -07:00
Committer: TensorFlower Gardener
Parent:    9bf094291f
Commit:    7e71a2732f
3 changed files with 56 additions and 20 deletions
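
For context, a minimal usage sketch (not part of this change) of the capability being enabled: sorting a float64 tensor through XLA on CPU or GPU. jit_compile=True is the current spelling for forcing XLA compilation (older releases used experimental_compile=True); tf.sort lowers to the XLA sort that this change extends to FP64.

import numpy as np
import tensorflow as tf

# Sketch only: force XLA compilation of a float64 sort on CPU/GPU.
# jit_compile=True is the current flag name; older TF releases used
# experimental_compile=True.
@tf.function(jit_compile=True)
def xla_sort(x):
  return tf.sort(x, direction="ASCENDING")

x = tf.constant(np.random.permutation(101).astype(np.float64))
print(xla_sort(x)[:5])  # [0. 1. 2. 3. 4.]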


@@ -1549,7 +1549,7 @@ tf_xla_py_test(
     srcs = ["sort_ops_test.py"],
     enable_mlir_bridge = True,
     python_version = "PY3",
-    shard_count = 1,
+    shard_count = 2,
     # Times out in fastbuild mode.
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -1560,6 +1560,7 @@ tf_xla_py_test(
         "//tensorflow/compiler/tf2xla/python:xla",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
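
The new @absl_py//absl/testing:parameterized dependency (together with the shard_count bump from 1 to 2) supports converting some of the tests below into absl parameterized test cases. A rough, illustrative sketch of that pattern (class and method names here are made up, not taken from the test file):

from absl.testing import absltest
from absl.testing import parameterized


class ExampleTest(parameterized.TestCase):

  # Each tuple becomes a separate test case whose name includes the label.
  @parameterized.named_parameters(
      ("SinglePrecision", "float32"),
      ("DoublePrecision", "float64"),
  )
  def testDtypeName(self, dtype_name):
    self.assertIn("float", dtype_name)


if __name__ == "__main__":
  absltest.main()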


@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
@@ -30,7 +31,7 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 
 
-class XlaSortOpTest(xla_test.XLATestCase):
+class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def _assertOpOutputMatchesExpected(self, op, args, expected):
     with self.session() as session:
@@ -50,21 +51,31 @@ class XlaSortOpTest(xla_test.XLATestCase):
   def testSort(self):
     supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
+      # TPU implementation is not supported for double precision
+      if dtype == np.float64 and self.device == "TPU":
+        continue
       x = np.arange(101, dtype=dtype)
       np.random.shuffle(x)
       self._assertOpOutputMatchesExpected(
           xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
 
   def testKeyValueSort(self):
-    supported_key_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
-    supported_value_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32,
-         dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype])
+    supported_key_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
+        np.uint32
+    ])
+    supported_value_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
+        np.uint32, dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype
+    ])
     for key_type in supported_key_types.intersection(self.numeric_types):
       for value_type in supported_value_types.intersection(self.numeric_types):
+        if key_type == np.float64 or value_type == np.float64:
+          # TPU implementation is not supported for double precision
+          if self.device == "TPU":
+            continue
         x = np.arange(101, dtype=key_type)
         np.random.shuffle(x)
         y = (-x).astype(value_type)
@@ -76,9 +87,13 @@ class XlaSortOpTest(xla_test.XLATestCase):
             ])
 
   def testTopK(self):
-    supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+    supported_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
+        np.uint32
+    ])
     for dtype in supported_types.intersection(self.numeric_types):
+      if dtype == np.float64 and self.device == "TPU":
+        continue
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
@@ -100,10 +115,18 @@ class XlaSortOpTest(xla_test.XLATestCase):
           topk, [x.astype(dtype)],
           expected=[x[indices].astype(dtype), indices])
 
-  def testTopK2D(self):
-    supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
-    for dtype in supported_types.intersection(self.numeric_types):
+  @parameterized.named_parameters(
+      ("HalfPrecision", dtypes.bfloat16.as_numpy_dtype),
+      ("SinglePrecision", np.float32),
+      ("DoublePrecision", np.float64),
+      ("Int", np.int32),
+      ("UnsignedInt", np.uint32),
+  )
+  def testTopK2D(self, dtype):
+    if dtype in self.numeric_types:
+      # TPU implementation is not supported for double precision
+      if dtype == np.float64 and self.device == "TPU":
+        return
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
@@ -131,8 +154,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
   @test_util.disable_mlir_bridge("Support compare type in HLO Compare Op")
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
-    supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
+      # TPU implementation is not supported for double precision
+      if dtype == np.float64 and self.device == "TPU":
+        continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
         with self.test_scope():
@@ -145,8 +172,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
   def testTopKInfinities(self):
     """Tests that positive and negative infinity sort correctly."""
-    supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
+      # TPU implementation is not supported for double precision
+      if dtype == np.float64 and self.device == "TPU":
+        continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
         with self.test_scope():
@@ -161,9 +192,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
              dtype=dtype), results[0])
         self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
 
-  def testInTopK(self):
-    supported_types = set([np.int32, np.int64])
-    for dtype in supported_types.intersection(self.numeric_types):
+  @parameterized.named_parameters(
+      ("Int32", np.int32),
+      ("Int64", np.int64),
+  )
+  def testInTopK(self, dtype):
+    if dtype in self.numeric_types:
       array_size = 200 * 1000
       k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
       batch = 16
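
The key/value sort cases above check that sorting float64 keys also reorders an associated values tensor with the same permutation. A runnable sketch of the same idea using only public ops (key_value_sort below is a hypothetical helper built from tf.argsort and tf.gather, not the xla.key_value_sort op the test exercises):

import numpy as np
import tensorflow as tf

# Hypothetical public-API analogue of the key/value sort behavior under test:
# sort float64 keys ascending and permute the values with the same order.
@tf.function(jit_compile=True)
def key_value_sort(keys, values):
  order = tf.argsort(keys)
  return tf.gather(keys, order), tf.gather(values, order)

keys = tf.constant(np.random.permutation(101).astype(np.float64))
values = -keys
sorted_keys, sorted_values = key_value_sort(keys, values)
# sorted_keys is 0..100 ascending; sorted_values is the negated keys in that order.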


@@ -58,7 +58,8 @@ class TopKOp : public XlaOpKernel {
 };
 
 REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstantInput("k").TypeConstraint(
-                    "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}),
+                    "T",
+                    {DT_UINT32, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}),
                 TopKOp);
 
 }  // namespace
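
With DT_DOUBLE added to the TopKV2 registration, float64 inputs to tf.math.top_k can likewise be compiled for XLA_CPU and XLA_GPU. A minimal sketch (not from the commit), analogous to the earlier tf.sort example:

import tensorflow as tf

# Sketch only: top-k over a float64 tensor, compiled with XLA.
@tf.function(jit_compile=True)
def top3(x):
  return tf.math.top_k(x, k=3)

values, indices = top3(tf.constant([0.5, 2.5, 1.5, 3.5], dtype=tf.float64))
# values  -> [3.5 2.5 1.5], indices -> [3 1 2]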