Add support for rank >= 1 tensors for XLA top_k_v2.

PiperOrigin-RevId: 205146612
Authored by A. Unique TensorFlower on 2018-07-18 15:21:45 -07:00; committed by TensorFlower Gardener
parent 7d79c72ba0
commit e3006b1d70
2 changed files with 53 additions and 19 deletions
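
In effect, top_k under XLA is no longer limited to vectors: for a rank >= 2 input it now selects the top k entries independently along the last dimension. A minimal NumPy sketch of those batched semantics (the helper name topk_2d is illustrative, not part of this change):

import numpy as np

def topk_2d(x, k):
  # Top k values and their indices, taken independently along the last
  # axis, in descending order (matching nn_ops.top_k with sorted=True).
  indices = x.argsort(axis=1)[:, -1:-k - 1:-1]
  values = np.sort(x, axis=1)[:, -1:-k - 1:-1]
  return values, indices

values, indices = topk_2d(np.array([[1, 9, 3, 7],
                                    [8, 2, 6, 4]]), k=2)
# values  -> [[9, 7], [8, 6]]
# indices -> [[1, 3], [0, 2]]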


@@ -88,6 +88,38 @@ class XlaSortOpTest(xla_test.XLATestCase):
           topk, [x.astype(dtype)],
           expected=[x[indices].astype(dtype), indices])
 
+  def testTopK2D(self):
+    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+    if self.device in ["XLA_CPU", "XLA_GPU"]:
+      return
+
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+    for dtype in supported_types.intersection(self.numeric_types):
+      # Use small input size for bfloat16. Otherwise, we'll get duplicate
+      # values after conversion to bfloat16, so the possible resulting
+      # index array is no longer unique.
+      if dtype == dtypes.bfloat16.as_numpy_dtype:
+        array_size = 10
+        k_options = [0, 1, 2, 10]
+      else:
+        array_size = 200 * 1000
+        k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+      batch = 16
+      for x in [np.arange(batch * array_size)]:
+        np.random.shuffle(x)
+        x = np.reshape(x, [batch, array_size])
+        for k in k_options:
+          indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
+          expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
+
+          def topk(v, k=k):
+            return nn_ops.top_k(v, k=k, sorted=True)
+
+          self._assertOpOutputMatchesExpected(
+              topk, [x.astype(dtype)],
+              expected=[expected.astype(dtype), indices])
+
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
     # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
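
The bfloat16 branch in the new test deserves a note: bfloat16 keeps only an 8-bit significand (7 stored bits), so consecutive integers start colliding once values grow past 256, and the argsort-based expected indices would stop being unique for a large arange. A quick standalone check (a sketch, assuming TensorFlow's bfloat16 NumPy dtype is available, as the test itself assumes):

import numpy as np
from tensorflow.python.framework import dtypes

bf16 = dtypes.bfloat16.as_numpy_dtype
x = np.arange(200 * 1000).astype(bf16).astype(np.float32)
# Round-tripping through bfloat16 collapses most large integers onto
# shared values, so far fewer than 200000 distinct entries survive.
print(np.unique(x).size)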


@@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel {
     OP_REQUIRES(context, input_shape.dims() >= 1,
                 errors::InvalidArgument("input must be >= 1-D, got shape ",
                                         input_shape.DebugString()));
+    int last_dim = input_shape.dims() - 1;
+    int last_dim_size = input_shape.dim_size(last_dim);
     OP_REQUIRES(
-        context, input_shape.dim_size(input_shape.dims() - 1) >= k,
+        context, last_dim_size >= k,
         errors::InvalidArgument("input must have at least k columns. Had ",
-                                input_shape.dim_size(input_shape.dims() - 1),
-                                ", needed ", k));
-
-    OP_REQUIRES(
-        context, input_shape.dims() == 1,
-        errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
-                              input_shape.DebugString()));
+                                last_dim_size, ", needed ", k));
 
     xla::XlaBuilder* const b = context->builder();
-    if (input_shape.dim_size(0) < k) {
-      k = input_shape.dim_size(0);
+    if (last_dim_size < k) {
+      k = last_dim_size;
     }
     const xla::XlaOp input = context->Input(0);
-    xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
-    xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
+    xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size);
+    auto input_dims = input_shape.dim_sizes();
+    std::vector<int64> broadcast_dims(input_dims.begin(),
+                                      input_dims.end() - 1);
+    xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims);
+    xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32);
+
+    std::vector<int64> start_indices(input_shape.dims(), 0);
+    std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+    limit_indices[last_dim] = k;
+    std::vector<int64> strides(input_shape.dims(), 1);
+
     xla::XlaOp values =
-        xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
-                            /*start_indices=*/{0},
-                            /*limit_indices=*/{k},
-                            /*strides=*/{1}));
+        xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices,
+                            limit_indices, strides));
     xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
-                                    /*start_indices=*/{0},
-                                    /*limit_indices=*/{k},
-                                    /*strides=*/{1});
+                                    start_indices, limit_indices, strides);
     context->SetOutput(0, values);
     context->SetOutput(1, indices);
   }
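
For readers following the kernel change, its data flow generalizes the old 1-D trick: negate the input so an ascending sort yields descending values, sort an iota of positions along with it as a payload (the iota is broadcast across all batch dimensions), then slice out the first k entries of the last dimension. A NumPy mirror of that flow (a sketch of the algorithm only; xla_style_topk is an illustrative name, not the XLA client API):

import numpy as np

def xla_style_topk(x, k):
  # Broadcast an iota of positions over the batch dims, like
  # xla::Broadcast(iota_s32, broadcast_dims).
  iota = np.broadcast_to(np.arange(x.shape[-1]), x.shape)
  # Sorting -x ascending orders the original values descending; the key's
  # permutation is applied to the index payload, like the two-operand
  # xla::Sort(Neg(input), broadcast_s32). mergesort is stable, so ties
  # keep the lower index first.
  order = np.argsort(-x, axis=-1, kind="mergesort")
  sorted_neg = np.take_along_axis(-x, order, axis=-1)
  sorted_idx = np.take_along_axis(iota, order, axis=-1)
  # Slice [0, k) along the last dim and undo the negation, like the
  # xla::Slice calls with limit_indices[last_dim] = k.
  return -sorted_neg[..., :k], sorted_idx[..., :k]

v, i = xla_style_topk(np.array([[1, 9, 3, 7],
                                [8, 2, 6, 4]]), k=2)
# v -> [[9, 7], [8, 6]], i -> [[1, 3], [0, 2]]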