Enable FP64 support for tf.sort for XLA_CPU and XLA_GPU.
PiperOrigin-RevId: 338838084 Change-Id: Iacab47e06f83b5b75d007d80a725063b92c0544d
This commit is contained in:
parent
9bf094291f
commit
7e71a2732f
@ -1549,7 +1549,7 @@ tf_xla_py_test(
|
||||
srcs = ["sort_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 1,
|
||||
shard_count = 2,
|
||||
# Times out in fastbuild mode.
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -1560,6 +1560,7 @@ tf_xla_py_test(
|
||||
"//tensorflow/compiler/tf2xla/python:xla",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
@ -30,7 +31,7 @@ from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class XlaSortOpTest(xla_test.XLATestCase):
|
||||
class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def _assertOpOutputMatchesExpected(self, op, args, expected):
|
||||
with self.session() as session:
|
||||
@ -50,21 +51,31 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
||||
|
||||
def testSort(self):
|
||||
supported_types = set(
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
# TPU implementation is not supported for double precision
|
||||
if dtype == np.float64 and self.device == "TPU":
|
||||
continue
|
||||
x = np.arange(101, dtype=dtype)
|
||||
np.random.shuffle(x)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
|
||||
|
||||
def testKeyValueSort(self):
|
||||
supported_key_types = set(
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
|
||||
supported_value_types = set(
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32,
|
||||
dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype])
|
||||
supported_key_types = set([
|
||||
dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
|
||||
np.uint32
|
||||
])
|
||||
supported_value_types = set([
|
||||
dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
|
||||
np.uint32, dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype
|
||||
])
|
||||
for key_type in supported_key_types.intersection(self.numeric_types):
|
||||
for value_type in supported_value_types.intersection(self.numeric_types):
|
||||
if key_type == np.float64 or value_type == np.float64:
|
||||
# TPU implementation is not supported for double precision
|
||||
if self.device == "TPU":
|
||||
continue
|
||||
x = np.arange(101, dtype=key_type)
|
||||
np.random.shuffle(x)
|
||||
y = (-x).astype(value_type)
|
||||
@ -76,9 +87,13 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
||||
])
|
||||
|
||||
def testTopK(self):
|
||||
supported_types = set(
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
|
||||
supported_types = set([
|
||||
dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
|
||||
np.uint32
|
||||
])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
if dtype == np.float64 and self.device == "TPU":
|
||||
continue
|
||||
# Use small input size for bfloat16. Otherwise, we'll get duplicate values
|
||||
# after conversion to bfloat16, so the possible resulting index array is
|
||||
# no longer unique.
|
||||
@ -100,10 +115,18 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
||||
topk, [x.astype(dtype)],
|
||||
expected=[x[indices].astype(dtype), indices])
|
||||
|
||||
def testTopK2D(self):
|
||||
supported_types = set(
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
@parameterized.named_parameters(
|
||||
("HalfPrecision", dtypes.bfloat16.as_numpy_dtype),
|
||||
("SinglePrecision", np.float32),
|
||||
("DoublePrecision", np.float64),
|
||||
("Int", np.int32),
|
||||
("UnsignedInt", np.uint32),
|
||||
)
|
||||
def testTopK2D(self, dtype):
|
||||
if dtype in self.numeric_types:
|
||||
# TPU implementation is not supported for double precision
|
||||
if dtype == np.float64 and self.device == "TPU":
|
||||
return
|
||||
# Use small input size for bfloat16. Otherwise, we'll get duplicate values
|
||||
# after conversion to bfloat16, so the possible resulting index array is
|
||||
# no longer unique.
|
||||
@ -131,8 +154,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
||||
@test_util.disable_mlir_bridge("Support compare type in HLO Compare Op")
|
||||
def testTopKZeros(self):
|
||||
"""Tests that positive and negative zeros sort correctly."""
|
||||
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
|
||||
supported_types = set(
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
# TPU implementation is not supported for double precision
|
||||
if dtype == np.float64 and self.device == "TPU":
|
||||
continue
|
||||
with self.session() as sess:
|
||||
p = array_ops.placeholder(dtype)
|
||||
with self.test_scope():
|
||||
@ -145,8 +172,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
||||
|
||||
def testTopKInfinities(self):
|
||||
"""Tests that positive and negative infinity sort correctly."""
|
||||
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
|
||||
supported_types = set(
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
# TPU implementation is not supported for double precision
|
||||
if dtype == np.float64 and self.device == "TPU":
|
||||
continue
|
||||
with self.session() as sess:
|
||||
p = array_ops.placeholder(dtype)
|
||||
with self.test_scope():
|
||||
@ -161,9 +192,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
||||
dtype=dtype), results[0])
|
||||
self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
|
||||
|
||||
def testInTopK(self):
|
||||
supported_types = set([np.int32, np.int64])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
@parameterized.named_parameters(
|
||||
("Int32", np.int32),
|
||||
("Int64", np.uint64),
|
||||
)
|
||||
def testInTopK(self, dtype):
|
||||
if dtype in self.numeric_types:
|
||||
array_size = 200 * 1000
|
||||
k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
|
||||
batch = 16
|
||||
|
@ -58,7 +58,8 @@ class TopKOp : public XlaOpKernel {
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstantInput("k").TypeConstraint(
|
||||
"T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}),
|
||||
"T",
|
||||
{DT_UINT32, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}),
|
||||
TopKOp);
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user