Add support for rank >= 1 tensors for XLA top_k_v2.
PiperOrigin-RevId: 205146612
parent 7d79c72ba0
commit e3006b1d70
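For orientation, the user-visible behavior this enables: tf.nn.top_k (which lowers to the TopKV2 op handled by the kernel below) can now be compiled by XLA for inputs of any rank >= 1, selecting the top k along the last dimension. A minimal sketch of such a call; running it on an XLA device, as the new test does, additionally requires the usual XLA device or jit setup:

```python
import numpy as np
import tensorflow as tf

# Batched input: top_k reduces over the last dimension, row by row.
x = tf.constant(np.arange(12, dtype=np.float32).reshape(3, 4))
values, indices = tf.nn.top_k(x, k=2, sorted=True)
# values has shape (3, 2); indices holds positions within each row.
```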
@@ -88,6 +88,38 @@ class XlaSortOpTest(xla_test.XLATestCase):
           topk, [x.astype(dtype)],
           expected=[x[indices].astype(dtype), indices])
 
+  def testTopK2D(self):
+    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+    if self.device in ["XLA_CPU", "XLA_GPU"]:
+      return
+
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+    for dtype in supported_types.intersection(self.numeric_types):
+      # Use small input size for bfloat16. Otherwise, we'll get duplicate
+      # values after conversion to bfloat16, so the possible resulting index
+      # array is no longer unique.
+      if dtype == dtypes.bfloat16.as_numpy_dtype:
+        array_size = 10
+        k_options = [0, 1, 2, 10]
+      else:
+        array_size = 200 * 1000
+        k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+      batch = 16
+      for x in [np.arange(batch * array_size)]:
+        np.random.shuffle(x)
+        x = np.reshape(x, [batch, array_size])
+        for k in k_options:
+          indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
+          expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
+
+          def topk(v, k=k):
+            return nn_ops.top_k(v, k=k, sorted=True)
+
+          self._assertOpOutputMatchesExpected(
+              topk, [x.astype(dtype)],
+              expected=[expected.astype(dtype), indices])
+
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
     # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
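The expected outputs in testTopK2D are built with the reversed slice `[::, -1:-k - 1:-1]`: take the last k entries of an ascending sort in descending order (empty when k == 0). The resulting index array is only well defined when each row's values are distinct, which is why the bfloat16 case uses a small input. A standalone numpy illustration of the slice, with hypothetical values not taken from the test:

```python
import numpy as np

x = np.array([[3, 1, 4, 1, 5],
              [9, 2, 6, 5, 3]])
k = 2
# argsort ascending, then walk backwards over the last k positions to get
# the indices of the k largest entries per row, largest first.
indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
# indices  -> [[4, 2], [0, 2]]
# expected -> [[5, 4], [9, 6]]
```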
@@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel {
     OP_REQUIRES(context, input_shape.dims() >= 1,
                 errors::InvalidArgument("input must be >= 1-D, got shape ",
                                         input_shape.DebugString()));
+    int last_dim = input_shape.dims() - 1;
+    int last_dim_size = input_shape.dim_size(last_dim);
     OP_REQUIRES(
-        context, input_shape.dim_size(input_shape.dims() - 1) >= k,
+        context, last_dim_size >= k,
         errors::InvalidArgument("input must have at least k columns. Had ",
-                                input_shape.dim_size(input_shape.dims() - 1),
-                                ", needed ", k));
-
-    OP_REQUIRES(
-        context, input_shape.dims() == 1,
-        errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
-                              input_shape.DebugString()));
+                                last_dim_size, ", needed ", k));
 
     xla::XlaBuilder* const b = context->builder();
-    if (input_shape.dim_size(0) < k) {
-      k = input_shape.dim_size(0);
+    if (last_dim_size < k) {
+      k = last_dim_size;
     }
     const xla::XlaOp input = context->Input(0);
-    xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
-    xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
+
+    xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size);
+    auto input_dims = input_shape.dim_sizes();
+    std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
+    xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims);
+    xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32);
+
+    std::vector<int64> start_indices(input_shape.dims(), 0);
+    std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+    limit_indices[last_dim] = k;
+    std::vector<int64> strides(input_shape.dims(), 1);
+
     xla::XlaOp values =
-        xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
-                            /*start_indices=*/{0},
-                            /*limit_indices=*/{k},
-                            /*strides=*/{1}));
+        xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices,
+                            limit_indices, strides));
     xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
-                                    /*start_indices=*/{0},
-                                    /*limit_indices=*/{k},
-                                    /*strides=*/{1});
+                                    start_indices, limit_indices, strides);
     context->SetOutput(0, values);
     context->SetOutput(1, indices);
   }
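The reworked kernel pairs the negated input with a positional iota broadcast over all leading (batch) dimensions, sorts the two operands jointly along the last dimension, and slices the first k columns of each sorted row. A rough numpy analogue of that lowering, for intuition only (topk_last_dim and the stable-sort choice are illustrative assumptions, not the XLA implementation):

```python
import numpy as np


def topk_last_dim(x, k):
  """Hand-written numpy analogue of the generalized XLA TopK lowering."""
  last_dim_size = x.shape[-1]
  k = min(k, last_dim_size)  # mirrors the kernel's clamp of k
  # Positional iota over the last dimension, broadcast across all leading
  # (batch) dimensions -- mirrors xla::Broadcast(iota_s32, broadcast_dims).
  iota = np.broadcast_to(np.arange(last_dim_size), x.shape)
  # Sorting the negated values ascending is a descending sort of the values;
  # argsort gives the permutation that the joint xla::Sort applies to both
  # operands (values and iota indices).
  order = np.argsort(-x, axis=-1, kind="stable")
  sorted_values = -np.take_along_axis(-x, order, axis=-1)
  sorted_indices = np.take_along_axis(iota, order, axis=-1)
  # Slice the first k entries along the last dimension
  # (start_indices = 0, limit_indices[last_dim] = k, strides = 1).
  return sorted_values[..., :k], sorted_indices[..., :k]


values, indices = topk_last_dim(np.array([[3., 1., 4.], [1., 5., 9.]]), k=2)
# values -> [[4., 3.], [9., 5.]], indices -> [[2, 0], [2, 1]]
```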