Enable FP16 support for TopKOp for XLA_CPU and XLA_GPU.
PiperOrigin-RevId: 341129873
Change-Id: I5ceb5d4de41dbfaa6466c1a6ab63fd76f70e86c6
parent 1fed8bfa46
commit fcb5476957

tensorflow/compiler
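As a hedged illustration of the user-facing effect of this change (not part of the commit itself): a float16 top_k should now compile for the XLA CPU/GPU backends inside a jit-compiled function. The snippet below is a minimal sketch assuming a TF 2.x release where tf.function accepts jit_compile=True (older releases spell it experimental_compile); the function and variable names are my own.

# Sketch only: exercises the FP16 TopK path under XLA compilation.
import numpy as np
import tensorflow as tf

@tf.function(jit_compile=True)  # force XLA compilation on CPU/GPU
def top3(x):
  return tf.math.top_k(x, k=3)

x = tf.constant(np.random.permutation(32).astype(np.float16))
values, indices = top3(x)
print(values.numpy(), indices.numpy())  # three largest values and their positions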
@@ -50,10 +50,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def testSort(self):
     supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
+        [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
       # TPU implementation is not supported for double precision
-      if dtype == np.float64 and self.device == "TPU":
+      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
         continue
       x = np.arange(101, dtype=dtype)
       np.random.shuffle(x)
@@ -62,16 +62,18 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def testKeyValueSort(self):
     supported_key_types = set([
-        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
-        np.uint32
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32
     ])
     supported_value_types = set([
-        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
-        np.uint32, dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32, dtypes.int64.as_numpy_dtype,
+        dtypes.uint64.as_numpy_dtype
     ])
     for key_type in supported_key_types.intersection(self.numeric_types):
       for value_type in supported_value_types.intersection(self.numeric_types):
-        if key_type == np.float64 or value_type == np.float64:
+        if key_type == np.float64 or value_type == np.float64 or \
+            key_type == np.float16 or value_type == np.float16:
           # TPU implementation is not supported for double precision
           if self.device == "TPU":
             continue
@@ -87,8 +89,8 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def testTopK(self):
     supported_types = set([
-        dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64, np.int32,
-        np.uint32
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32
     ])
     for dtype in supported_types.intersection(self.numeric_types):
       if dtype == np.float64 and self.device == "TPU":
@@ -96,7 +98,7 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
-      if dtype == dtypes.bfloat16.as_numpy_dtype:
+      if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16):
         array_size = 20
         k_options = [0, 1, 2, 10, 20]
       else:
@@ -116,6 +118,7 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
       ("HalfPrecision", dtypes.bfloat16.as_numpy_dtype),
+      ("HalfFloatPrecision", np.float16),
       ("SinglePrecision", np.float32),
       ("DoublePrecision", np.float64),
       ("Int", np.int32),
@@ -124,12 +127,12 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testTopK2D(self, dtype):
     if dtype in self.numeric_types:
       # TPU implementation is not supported for double precision
-      if dtype == np.float64 and self.device == "TPU":
+      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
         return
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
-      if dtype == dtypes.bfloat16.as_numpy_dtype:
+      if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16):
         array_size = 10
         k_options = [0, 1, 2, 10]
       else:
@@ -153,10 +156,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
     supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
+        [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
       # TPU implementation is not supported for double precision
-      if dtype == np.float64 and self.device == "TPU":
+      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
         continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
@@ -171,10 +174,10 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def testTopKInfinities(self):
     """Tests that positive and negative infinity sort correctly."""
     supported_types = set(
-        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.float64])
+        [dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
     for dtype in supported_types.intersection(self.numeric_types):
       # TPU implementation is not supported for double precision
-      if dtype == np.float64 and self.device == "TPU":
+      if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
         continue
       with self.session() as sess:
         p = array_ops.placeholder(dtype)
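The test hunks above also extend key/value sort coverage to float16 keys. A rough analogue outside the test harness, assuming tf.argsort lowers to an XLA key/value sort when jit-compiled (an assumption on my part, not something this diff states):

# Sketch: stable descending argsort over float16 keys under XLA compilation.
import numpy as np
import tensorflow as tf

@tf.function(jit_compile=True)
def argsort_desc(keys):
  # Returns the permutation that sorts `keys` from largest to smallest.
  return tf.argsort(keys, direction='DESCENDING', stable=True)

keys = tf.constant(np.random.permutation(16).astype(np.float16))
order = argsort_desc(keys)
print(tf.gather(keys, order).numpy())  # keys reordered from largest to smallest

The next hunk is the kernel-side change that makes the new dtype legal for the XLA bridge.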
@@ -58,8 +58,8 @@ class TopKOp : public XlaOpKernel {
 };
 
 REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstantInput("k").TypeConstraint(
-                    "T",
-                    {DT_UINT32, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}),
+                    "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_HALF, DT_DOUBLE,
+                          DT_BFLOAT16}),
                 TopKOp);
 
 }  // namespace
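The registration above is what actually admits half precision: the TypeConstraint for TopKV2 now lists DT_HALF (float16) alongside the previously supported element types. A hedged sanity sweep over that list from Python follows; the loop and the DT_*-to-tf-dtype mapping are my own illustration, not part of the commit.

# Sketch: one call per entry in the updated TypeConstraint
# (DT_UINT32, DT_INT32, DT_FLOAT, DT_HALF, DT_DOUBLE, DT_BFLOAT16).
import tensorflow as tf

@tf.function(jit_compile=True)
def top2(x):
  return tf.math.top_k(x, k=2)

for dtype in (tf.uint32, tf.int32, tf.float32, tf.float16, tf.float64,
              tf.bfloat16):
  values, _ = top2(tf.cast(tf.range(8), dtype))
  print(dtype.name, values.numpy())  # the two largest values per dtype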