[tf2xla] Add support for PopulationCount
PiperOrigin-RevId: 324706175 Change-Id: I567c1fec90022573acf7a6dbe18a06191a088c90
This commit is contained in:
parent
474d3df724
commit
7e10ef560d
@ -1952,6 +1952,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"ParallelDynamicStitch",
|
||||
"ParameterizedTruncatedNormal",
|
||||
"PartitionedCall",
|
||||
"PopulationCount",
|
||||
"Qr",
|
||||
"QuantizeAndDequantizeV2",
|
||||
"QuantizeAndDequantizeV3",
|
||||
|
||||
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
@ -90,6 +91,10 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
self.assertAllClose(result, expected, rtol, atol)
|
||||
self.assertAllEqual(np.sort(result), result)
|
||||
|
||||
def AssertAllEqual(self, result, expected, rtol, atol):
|
||||
"""Tests that result and expeted are exactly equal."""
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"MlirHloBuilder::Iota missing required for xla::Diag")
|
||||
def testAllTypeOps(self):
|
||||
@ -779,6 +784,10 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||
expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TF_PopulationCount is missing and is required to translate to "
|
||||
"xla::PopulationCount."
|
||||
)
|
||||
def testIntOps(self):
|
||||
for dtype in self.int_types:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
@ -786,6 +795,34 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([0, -1, 1, 16, 42], dtype=dtype),
|
||||
expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
|
||||
|
||||
# Test population_count for array inputs.
|
||||
raw_inputs = [
|
||||
0, 1, -1, 3, -3, 5, -5, 14, -14, 127, 128, 255, 256, 65535, 65536,
|
||||
2**31 - 1, 2**31, 2**32 - 1, 2**32, -2**32 + 1, -2**32, -2**63 + 1,
|
||||
2**63 - 1
|
||||
]
|
||||
inputs = np.array(raw_inputs, dtype=dtype)
|
||||
|
||||
def count_bits(x):
|
||||
return sum(bin(z).count("1") for z in six.iterbytes(x.tobytes()))
|
||||
|
||||
truth = [count_bits(x) for x in inputs]
|
||||
self._assertOpOutputMatchesExpected(
|
||||
bitwise_ops.population_count,
|
||||
inputs,
|
||||
expected=np.array(truth, dtype=np.uint8),
|
||||
equality_test=self.AssertAllEqual)
|
||||
|
||||
# Test population_count for scalar inputs.
|
||||
for raw_inp in raw_inputs:
|
||||
inp = dtype(raw_inp)
|
||||
truth = count_bits(inp)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
bitwise_ops.population_count,
|
||||
inp,
|
||||
expected=np.uint8(truth),
|
||||
equality_test=self.AssertAllEqual)
|
||||
|
||||
def testNumericOps(self):
|
||||
for dtype in self.numeric_types - {np.int8, np.uint8}:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
|
||||
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -76,6 +77,8 @@ XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x));
|
||||
|
||||
XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
|
||||
XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
|
||||
XLAJIT_MAKE_UNARY(PopulationCount,
|
||||
xla::ConvertElementType(xla::PopulationCount(x), xla::U8));
|
||||
XLAJIT_MAKE_UNARY(Neg, -x);
|
||||
|
||||
XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user