From e3006b1d706fb171525cdd5cfe3a2305d6a5d879 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Jul 2018 15:21:45 -0700 Subject: [PATCH] Add support for rank >= 1 tensors for XLA top_k_v2. PiperOrigin-RevId: 205146612 --- tensorflow/compiler/tests/sort_ops_test.py | 32 +++++++++++++++ tensorflow/compiler/tf2xla/kernels/topk_op.cc | 40 ++++++++++--------- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 9e2ef964a1f..7ff01be3cb4 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -88,6 +88,38 @@ class XlaSortOpTest(xla_test.XLATestCase): topk, [x.astype(dtype)], expected=[x[indices].astype(dtype), indices]) + def testTopK2D(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 10 + k_options = [0, 1, 2, 10] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + batch = 16 + for x in [np.arange(batch * array_size)]: + np.random.shuffle(x) + x = np.reshape(x, [batch, array_size]) + for k in k_options: + indices = x.argsort(axis=1)[::, -1:-k - 1:-1] + expected = np.sort(x, axis=1)[::, -1:-k - 1:-1] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(dtype)], + expected=[expected.astype(dtype), indices]) + def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 1ddcb08c8e1..82d4a69777b 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel { OP_REQUIRES(context, input_shape.dims() >= 1, errors::InvalidArgument("input must be >= 1-D, got shape ", input_shape.DebugString())); + int last_dim = input_shape.dims() - 1; + int last_dim_size = input_shape.dim_size(last_dim); OP_REQUIRES( - context, input_shape.dim_size(input_shape.dims() - 1) >= k, + context, last_dim_size >= k, errors::InvalidArgument("input must have at least k columns. Had ", - input_shape.dim_size(input_shape.dims() - 1), - ", needed ", k)); - - OP_REQUIRES( - context, input_shape.dims() == 1, - errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ", - input_shape.DebugString())); + last_dim_size, ", needed ", k)); xla::XlaBuilder* const b = context->builder(); - if (input_shape.dim_size(0) < k) { - k = input_shape.dim_size(0); + if (last_dim_size < k) { + k = last_dim_size; } const xla::XlaOp input = context->Input(0); - xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0)); - xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32); + + xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size); + auto input_dims = input_shape.dim_sizes(); + std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); + xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims); + xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32); + + std::vector start_indices(input_shape.dims(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + limit_indices[last_dim] = k; + std::vector strides(input_shape.dims(), 1); + xla::XlaOp values = - xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1})); + xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides)); xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1}); + start_indices, limit_indices, strides); context->SetOutput(0, values); context->SetOutput(1, indices); }