[tf2xla] Add support for PopulationCount

PiperOrigin-RevId: 324706175
Change-Id: I567c1fec90022573acf7a6dbe18a06191a088c90
This commit is contained in:
David Majnemer 2020-08-03 16:20:05 -07:00 committed by TensorFlower Gardener
parent 474d3df724
commit 7e10ef560d
3 changed files with 41 additions and 0 deletions

View File

@ -1952,6 +1952,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"ParallelDynamicStitch",
"ParameterizedTruncatedNormal",
"PartitionedCall",
"PopulationCount",
"Qr",
"QuantizeAndDequantizeV2",
"QuantizeAndDequantizeV3",

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import unittest
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests import xla_test
@ -90,6 +91,10 @@ class UnaryOpsTest(xla_test.XLATestCase):
self.assertAllClose(result, expected, rtol, atol)
self.assertAllEqual(np.sort(result), result)
def AssertAllEqual(self, result, expected, rtol, atol):
"""Tests that result and expeted are exactly equal."""
self.assertAllEqual(result, expected)
@test_util.disable_mlir_bridge(
"MlirHloBuilder::Iota missing required for xla::Diag")
def testAllTypeOps(self):
@ -779,6 +784,10 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
@test_util.disable_mlir_bridge(
"TF_PopulationCount is missing and is required to translate to "
"xla::PopulationCount."
)
def testIntOps(self):
for dtype in self.int_types:
self._assertOpOutputMatchesExpected(
@ -786,6 +795,34 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([0, -1, 1, 16, 42], dtype=dtype),
expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
# Test population_count for array inputs.
raw_inputs = [
0, 1, -1, 3, -3, 5, -5, 14, -14, 127, 128, 255, 256, 65535, 65536,
2**31 - 1, 2**31, 2**32 - 1, 2**32, -2**32 + 1, -2**32, -2**63 + 1,
2**63 - 1
]
inputs = np.array(raw_inputs, dtype=dtype)
def count_bits(x):
return sum(bin(z).count("1") for z in six.iterbytes(x.tobytes()))
truth = [count_bits(x) for x in inputs]
self._assertOpOutputMatchesExpected(
bitwise_ops.population_count,
inputs,
expected=np.array(truth, dtype=np.uint8),
equality_test=self.AssertAllEqual)
# Test population_count for scalar inputs.
for raw_inp in raw_inputs:
inp = dtype(raw_inp)
truth = count_bits(inp)
self._assertOpOutputMatchesExpected(
bitwise_ops.population_count,
inp,
expected=np.uint8(truth),
equality_test=self.AssertAllEqual)
def testNumericOps(self):
for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@ -76,6 +77,8 @@ XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x));
XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
XLAJIT_MAKE_UNARY(PopulationCount,
xla::ConvertElementType(xla::PopulationCount(x), xla::U8));
XLAJIT_MAKE_UNARY(Neg, -x);
XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x));