Add support for rank >= 1 tensors for XLA top_k_v2.
PiperOrigin-RevId: 205146612
This commit is contained in:
parent
7d79c72ba0
commit
e3006b1d70
@ -88,6 +88,38 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
|||||||
topk, [x.astype(dtype)],
|
topk, [x.astype(dtype)],
|
||||||
expected=[x[indices].astype(dtype), indices])
|
expected=[x[indices].astype(dtype), indices])
|
||||||
|
|
||||||
|
def testTopK2D(self):
|
||||||
|
# TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
|
||||||
|
if self.device in ["XLA_CPU", "XLA_GPU"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
supported_types = set(
|
||||||
|
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
|
||||||
|
for dtype in supported_types.intersection(self.numeric_types):
|
||||||
|
# Use small input size for bfloat16. Otherwise, we'll get duplicate values
|
||||||
|
# after conversion to bfloat16, so the possible resulting index array is
|
||||||
|
# no longer unique.
|
||||||
|
if dtype == dtypes.bfloat16.as_numpy_dtype:
|
||||||
|
array_size = 10
|
||||||
|
k_options = [0, 1, 2, 10]
|
||||||
|
else:
|
||||||
|
array_size = 200 * 1000
|
||||||
|
k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
|
||||||
|
batch = 16
|
||||||
|
for x in [np.arange(batch * array_size)]:
|
||||||
|
np.random.shuffle(x)
|
||||||
|
x = np.reshape(x, [batch, array_size])
|
||||||
|
for k in k_options:
|
||||||
|
indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
|
||||||
|
expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
|
||||||
|
|
||||||
|
def topk(v, k=k):
|
||||||
|
return nn_ops.top_k(v, k=k, sorted=True)
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
topk, [x.astype(dtype)],
|
||||||
|
expected=[expected.astype(dtype), indices])
|
||||||
|
|
||||||
def testTopKZeros(self):
|
def testTopKZeros(self):
|
||||||
"""Tests that positive and negative zeros sort correctly."""
|
"""Tests that positive and negative zeros sort correctly."""
|
||||||
# TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
|
# TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
|
||||||
|
@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel {
|
|||||||
OP_REQUIRES(context, input_shape.dims() >= 1,
|
OP_REQUIRES(context, input_shape.dims() >= 1,
|
||||||
errors::InvalidArgument("input must be >= 1-D, got shape ",
|
errors::InvalidArgument("input must be >= 1-D, got shape ",
|
||||||
input_shape.DebugString()));
|
input_shape.DebugString()));
|
||||||
|
int last_dim = input_shape.dims() - 1;
|
||||||
|
int last_dim_size = input_shape.dim_size(last_dim);
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, input_shape.dim_size(input_shape.dims() - 1) >= k,
|
context, last_dim_size >= k,
|
||||||
errors::InvalidArgument("input must have at least k columns. Had ",
|
errors::InvalidArgument("input must have at least k columns. Had ",
|
||||||
input_shape.dim_size(input_shape.dims() - 1),
|
last_dim_size, ", needed ", k));
|
||||||
", needed ", k));
|
|
||||||
|
|
||||||
OP_REQUIRES(
|
|
||||||
context, input_shape.dims() == 1,
|
|
||||||
errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
|
|
||||||
input_shape.DebugString()));
|
|
||||||
|
|
||||||
xla::XlaBuilder* const b = context->builder();
|
xla::XlaBuilder* const b = context->builder();
|
||||||
if (input_shape.dim_size(0) < k) {
|
if (last_dim_size < k) {
|
||||||
k = input_shape.dim_size(0);
|
k = last_dim_size;
|
||||||
}
|
}
|
||||||
const xla::XlaOp input = context->Input(0);
|
const xla::XlaOp input = context->Input(0);
|
||||||
xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
|
|
||||||
xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
|
xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size);
|
||||||
|
auto input_dims = input_shape.dim_sizes();
|
||||||
|
std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
|
||||||
|
xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims);
|
||||||
|
xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32);
|
||||||
|
|
||||||
|
std::vector<int64> start_indices(input_shape.dims(), 0);
|
||||||
|
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
|
||||||
|
limit_indices[last_dim] = k;
|
||||||
|
std::vector<int64> strides(input_shape.dims(), 1);
|
||||||
|
|
||||||
xla::XlaOp values =
|
xla::XlaOp values =
|
||||||
xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
|
xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices,
|
||||||
/*start_indices=*/{0},
|
limit_indices, strides));
|
||||||
/*limit_indices=*/{k},
|
|
||||||
/*strides=*/{1}));
|
|
||||||
xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
|
xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
|
||||||
/*start_indices=*/{0},
|
start_indices, limit_indices, strides);
|
||||||
/*limit_indices=*/{k},
|
|
||||||
/*strides=*/{1});
|
|
||||||
context->SetOutput(0, values);
|
context->SetOutput(0, values);
|
||||||
context->SetOutput(1, indices);
|
context->SetOutput(1, indices);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user