Enable FP64 support for tf.sort for XLA_CPU and XLA_GPU.

PiperOrigin-RevId: 338838084
Change-Id: Iacab47e06f83b5b75d007d80a725063b92c0544d
A. Unique TensorFlower 2020-10-24 09:23:51 -07:00 committed by TensorFlower Gardener
parent 9bf094291f
commit 7e71a2732f
3 changed files with 56 additions and 20 deletions
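As a quick illustration of what this change enables, here is a minimal sketch that sorts a float64 tensor through XLA on CPU or GPU. It assumes a TensorFlow build where tf.function's jit_compile flag is available; the function name xla_sort is just for the example and is not part of the commit.

import numpy as np
import tensorflow as tf

# Minimal sketch: compile tf.sort with XLA and feed it a float64 input.
# Before this change the XLA CPU/GPU bridge rejected DT_DOUBLE for sort.
@tf.function(jit_compile=True)
def xla_sort(x):
  return tf.sort(x)

x = tf.constant(np.random.randn(101), dtype=tf.float64)
print(xla_sort(x))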


@@ -1549,7 +1549,7 @@ tf_xla_py_test(
     srcs = ["sort_ops_test.py"],
     enable_mlir_bridge = True,
     python_version = "PY3",
-    shard_count = 1,
+    shard_count = 2,
     # Times out in fastbuild mode.
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -1560,6 +1560,7 @@ tf_xla_py_test(
         "//tensorflow/compiler/tf2xla/python:xla",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
+        "@absl_py//absl/testing:parameterized",
     ],
 )


@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np

 from tensorflow.compiler.tests import xla_test
@@ -30,7 +31,7 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test


-class XlaSortOpTest(xla_test.XLATestCase):
+class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):

   def _assertOpOutputMatchesExpected(self, op, args, expected):
     with self.session() as session:
@@ -50,21 +51,31 @@ class XlaSortOpTest(xla_test.XLATestCase):
   def testSort(self):
     supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
+      # TPU implementation is not supported for double precision
+      if dtype == np.float64 and self.device == "TPU":
+        continue
       x = np.arange(101, dtype=dtype)
       np.random.shuffle(x)
       self._assertOpOutputMatchesExpected(
           xla.sort, [x], expected=[np.arange(101, dtype=dtype)])

   def testKeyValueSort(self):
-    supported_key_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
-    supported_value_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32,
-         dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype])
+    supported_key_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
+        np.uint32
+    ])
+    supported_value_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
+        np.uint32, dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype
+    ])
     for key_type in supported_key_types.intersection(self.numeric_types):
       for value_type in supported_value_types.intersection(self.numeric_types):
+        if key_type == np.float64 or value_type == np.float64:
+          # TPU implementation is not supported for double precision
+          if self.device == "TPU":
+            continue
         x = np.arange(101, dtype=key_type)
         np.random.shuffle(x)
         y = (-x).astype(value_type)
@@ -76,9 +87,13 @@ class XlaSortOpTest(xla_test.XLATestCase):
         ])

   def testTopK(self):
-    supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+    supported_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
+        np.uint32
+    ])
     for dtype in supported_types.intersection(self.numeric_types):
+      if dtype == np.float64 and self.device == "TPU":
+        continue
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
@@ -100,10 +115,18 @@ class XlaSortOpTest(xla_test.XLATestCase):
           topk, [x.astype(dtype)],
           expected=[x[indices].astype(dtype), indices])

-  def testTopK2D(self):
-    supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
-    for dtype in supported_types.intersection(self.numeric_types):
+  @parameterized.named_parameters(
+      ("HalfPrecision", dtypes.bfloat16.as_numpy_dtype),
+      ("SinglePrecision", np.float32),
+      ("DoublePrecision", np.float64),
+      ("Int", np.int32),
+      ("UnsignedInt", np.uint32),
+  )
+  def testTopK2D(self, dtype):
+    if dtype in self.numeric_types:
+      # TPU implementation is not supported for double precision
+      if dtype == np.float64 and self.device == "TPU":
+        return
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
@@ -131,8 +154,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
   @test_util.disable_mlir_bridge("Support compare type in HLO Compare Op")
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
-    supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
+      # TPU implementation is not supported for double precision
+      if dtype == np.float64 and self.device == "TPU":
+        continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
         with self.test_scope():
@@ -145,8 +172,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
   def testTopKInfinities(self):
     """Tests that positive and negative infinity sort correctly."""
-    supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
+      # TPU implementation is not supported for double precision
+      if dtype == np.float64 and self.device == "TPU":
+        continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
         with self.test_scope():
@@ -161,9 +192,12 @@ class XlaSortOpTest(xla_test.XLATestCase):
                 dtype=dtype), results[0])
         self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))

-  def testInTopK(self):
-    supported_types = set([np.int32, np.int64])
-    for dtype in supported_types.intersection(self.numeric_types):
+  @parameterized.named_parameters(
+      ("Int32", np.int32),
+      ("Int64", np.uint64),
+  )
+  def testInTopK(self, dtype):
+    if dtype in self.numeric_types:
       array_size = 200 * 1000
       k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
       batch = 16
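The test refactoring above relies on absl's parameterized test support, which turns one test method into a named test case per dtype. As a rough, self-contained sketch of the pattern (the class name, test name, and dtype list here are illustrative, not taken from the commit):

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import tensorflow as tf


class SortDtypeTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ("SinglePrecision", np.float32),
      ("DoublePrecision", np.float64),
  )
  def test_sort_is_ascending(self, dtype):
    # Each named tuple above becomes its own test case, e.g.
    # test_sort_is_ascending_DoublePrecision.
    x = np.random.randn(16).astype(dtype)
    result = tf.sort(tf.constant(x)).numpy()
    self.assertTrue(np.all(np.diff(result) >= 0))


if __name__ == "__main__":
  absltest.main()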


@@ -58,7 +58,8 @@ class TopKOp : public XlaOpKernel {
 };

 REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstantInput("k").TypeConstraint(
-                    "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}),
+                    "T",
+                    {DT_UINT32, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}),
                 TopKOp);

 }  // namespace
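With DT_DOUBLE added to the TopKV2 type constraint, top_k on float64 inputs can also go through the XLA bridge on CPU and GPU. A minimal sketch, again assuming a TensorFlow build where tf.function's jit_compile flag is available; the function name xla_top_k is illustrative only.

import tensorflow as tf

# Sketch: compile top_k with XLA; float64 inputs are accepted on CPU/GPU
# now that DT_DOUBLE is in the TopKV2 type constraint. k stays a Python
# constant, so it is a compile-time constant as XLA requires.
@tf.function(jit_compile=True)
def xla_top_k(x):
  return tf.math.top_k(x, k=3)

values, indices = xla_top_k(
    tf.constant([1.5, -2.0, 7.25, 0.0, 3.5], dtype=tf.float64))
print(values.numpy(), indices.numpy())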