Exposes a variadic form of XlaSort. (Previously, we had XlaSort for 1
argument, and XlaKeyValueSort for 2 arguments.)

The new op, XlaVariadicSort, fully exposes the xla::Sort functionality: it can
take an arbitrary list of operands, a dimension argument, an is_stable
argument, and a comparator function.

PiperOrigin-RevId: 348581487
Change-Id: I82b4ff0c37d433541167584475b004626bc8c82c
A. Unique TensorFlower 2020-12-21 22:29:39 -08:00 committed by TensorFlower Gardener
parent 4967bdb59e
commit 682d130cf6
5 changed files with 283 additions and 26 deletions
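
For a quick sense of the new API, here is a minimal Python sketch (the tensor
values and the compare_lt helper are hypothetical; as in the tests below, the
op is meant to run under an XLA compilation scope):

import numpy as np
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import function

# The comparator sees two scalars per operand; only the key pair matters here.
@function.Defun(np.float32, np.float32, np.int32, np.int32)
def compare_lt(k1, k2, v1, v2):
  del v1, v2  # Values are reordered along with the keys but not compared.
  return k1 < k2

keys = np.array([3.0, 1.0, 2.0], dtype=np.float32)
values = np.array([30, 10, 20], dtype=np.int32)
# Sorts keys ascending and permutes values the same way.
sorted_keys, sorted_values = xla.variadic_sort(
    [keys, values], comparator=compare_lt, dimension=0, is_stable=False)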


@@ -2082,6 +2082,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaVariadicReduce",
"XlaVariadicSort",
"XlaWhile",
"Zeta",
"_Arg",


@@ -18,14 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
@@ -33,6 +37,7 @@ from tensorflow.python.platform import test
class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected):
"""Tests that op(*args) == expected."""
with self.session() as session:
with self.test_scope():
placeholders = [
@@ -48,37 +53,34 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
for result, v in zip(results, expected):
self.assertAllClose(v, result, rtol=1e-3)
def _shuffled_arange(self, shape, dtype):
x = np.arange(np.prod(shape), dtype=dtype)
np.random.shuffle(x)
return x.reshape(shape)
def _supported_key_types(self):
supported_key_types = set([
dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
np.int32, np.uint32, np.int16, np.uint16, np.int8, np.uint8
])
res = supported_key_types.intersection(self.numeric_types)
assert res
return res
def testSort(self):
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
for dtype in supported_types.intersection(self.numeric_types):
# TPU implementation is not supported for double precision
if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
continue
x = np.arange(101, dtype=dtype)
np.random.shuffle(x)
for dtype in self._supported_key_types():
x = self._shuffled_arange((101,), dtype)
self._assertOpOutputMatchesExpected(
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
def testKeyValueSort(self):
supported_key_types = set([
dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
np.int32, np.uint32
])
supported_value_types = set([
dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
np.int32, np.uint32, dtypes.int64.as_numpy_dtype,
dtypes.uint64.as_numpy_dtype
])
for key_type in supported_key_types.intersection(self.numeric_types):
for value_type in supported_value_types.intersection(self.numeric_types):
if key_type == np.float64 or value_type == np.float64 or \
key_type == np.float16 or value_type == np.float16:
# TPU implementation is not supported for double precision
if self.device == "TPU":
continue
x = np.arange(101, dtype=key_type)
np.random.shuffle(x)
for key_type in self._supported_key_types():
for value_type in self._supported_key_types():
if key_type == np.uint8 or value_type == np.uint8:
# I do not understand why the test fails on uint8. We plan to
# deprecate xla.key_value_sort in favor of xla.variadic_sort anyway.
continue
x = self._shuffled_arange((101,), key_type)
y = (-x).astype(value_type)
self._assertOpOutputMatchesExpected(
xla.key_value_sort, [x, y],
@@ -87,6 +89,156 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
-np.arange(101, dtype=value_type)
])
@parameterized.parameters(0, 1, 2)
@test_util.disable_mlir_bridge("Not supported yet")
def testVariadicSortDimension(self, dimension):
shape = (2, 3, 4)
for key_type in self._supported_key_types():
x = self._shuffled_arange(shape, key_type)
expected = np.sort(x, axis=dimension)
@function.Defun(key_type, key_type)
def compare_lt(x1, x2):
return x1 < x2
def wrap_sort(x):
return xla.variadic_sort([x],
dimension=dimension,
is_stable=False,
comparator=compare_lt)
self._assertOpOutputMatchesExpected(wrap_sort, [x], expected=[expected])
@test_util.disable_mlir_bridge("Not supported yet")
def testVariadicSortReverse(self):
shape = (100,)
for key_type in self._supported_key_types():
x = self._shuffled_arange(shape, key_type)
expected = np.sort(x, axis=0)[::-1]
@function.Defun(key_type, key_type)
def compare_gt(x1, x2):
return x1 > x2
def wrap_sort(x):
return xla.variadic_sort([x],
dimension=0,
is_stable=False,
comparator=compare_gt)
self._assertOpOutputMatchesExpected(wrap_sort, [x], expected=[expected])
@parameterized.parameters(0, 1, 2)
@test_util.disable_mlir_bridge("Not supported yet")
def testVariadicSortSeveral(self, dimension):
if np.__version__ < "1.15":
raise unittest.SkipTest("np.take_along_axis was added in 1.15")
shape = (2, 3, 4)
for key_type in self._supported_key_types():
for value_type_1 in self._supported_key_types():
for value_type_2 in self._supported_key_types():
inputs = [
self._shuffled_arange(shape, key_type),
self._shuffled_arange(shape, value_type_1),
self._shuffled_arange(shape, value_type_2)
]
# The first array is sorted, and the others are shuffled the same way
sorted_indices = np.argsort(inputs[0], axis=dimension)
expected = [
np.take_along_axis(inp, sorted_indices, axis=dimension)
for inp in inputs
]
self.assertAllEqual(np.sort(inputs[0], axis=dimension), expected[0])
@function.Defun(key_type, key_type, value_type_1, value_type_1,
value_type_2, value_type_2)
def compare_lt(x1, x2, y1, y2, z1, z2):
del y1, y2, z1, z2
return x1 < x2
def wrap_sort(*args):
return xla.variadic_sort(
args, # Pass the arguments as a tuple
comparator=compare_lt,
dimension=dimension,
is_stable=False)
self._assertOpOutputMatchesExpected(
wrap_sort, inputs, expected=expected)
@test_util.disable_mlir_bridge("Not supported yet")
def testVariadicSortLexicographic(self):
# Three inputs: the first two are used for lexicographic sort, and the
# third is just swapped accordingly.
# The first array will contain only 0 and 1, to test lexicographic order
if np.__version__ < "1.15":
raise unittest.SkipTest("np.take_along_axis was added in 1.15")
shape = (20,)
for key_type_1 in set([np.int16, np.uint16, np.int32, np.uint32]):
for key_type_2 in self._supported_key_types():
for value_type in self._supported_key_types():
inputs = [
# Ensure that some keys in the first input are equal
np.random.uniform(0, 2, shape).astype(key_type_1),
self._shuffled_arange(shape, key_type_2),
self._shuffled_arange(shape, value_type)
]
# The first two arrays are sorted lexicographically, and the third
# is shuffled the same way
sorted_indices = np.argsort(100 * inputs[0] + inputs[1])
expected = [
np.take_along_axis(inp, sorted_indices, axis=0) for inp in inputs
]
@function.Defun(key_type_1, key_type_1, key_type_2, key_type_2,
value_type, value_type)
def compare_lexicographic(x1, x2, y1, y2, z1, z2):
del z1, z2
return math_ops.logical_or(
x1 < x2, math_ops.logical_and(math_ops.equal(x1, x2), y1 < y2))
def wrap_sort(*args):
return xla.variadic_sort(
args, # Pass the arguments as a tuple
comparator=compare_lexicographic,
dimension=0,
is_stable=False)
self._assertOpOutputMatchesExpected(
wrap_sort, inputs, expected=expected)
@parameterized.parameters(0, 1, 2)
@test_util.disable_mlir_bridge("Not supported yet")
def testVariadicSortSeveralStable(self, dimension):
shape = (2, 3, 4)
for key_type in self._supported_key_types():
for value_type_1 in self._supported_key_types():
for value_type_2 in self._supported_key_types():
# The first input is all 0s, so a stable sort should leave every input
# unchanged.
inputs = [
np.zeros(shape, key_type),
self._shuffled_arange(shape, value_type_1),
self._shuffled_arange(shape, value_type_2)
]
@function.Defun(key_type, key_type, value_type_1, value_type_1,
value_type_2, value_type_2)
def compare_lt(x1, x2, y1, y2, z1, z2):
del y1, y2, z1, z2
return x1 < x2
def wrap_sort(*args):
return xla.variadic_sort(
args, # Pass the arguments as a tuple
comparator=compare_lt,
dimension=dimension,
is_stable=True)
self._assertOpOutputMatchesExpected(
wrap_sort, inputs, expected=inputs)
def testTopK(self):
supported_types = set([
dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,


@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
@@ -53,5 +55,76 @@ class XlaKeyValueSortOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp);
class XlaVariadicSortOp : public XlaOpKernel {
public:
explicit XlaVariadicSortOp(OpKernelConstruction* context)
: XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("T", &input_types_));
OP_REQUIRES_OK(context, context->GetAttr("comparator", &comparator_));
OP_REQUIRES_OK(context, context->GetAttr("dimension", &dimension_));
OP_REQUIRES_OK(context, context->GetAttr("is_stable", &is_stable_));
}
void Compile(XlaOpKernelContext* context) override {
std::vector<xla::XlaOp> inputs(input_types_.size());
std::vector<xla::PrimitiveType> input_xla_types(input_types_.size());
std::vector<XlaCompiler::Argument> comparator_args(2 * input_types_.size());
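// The comparator is invoked with 2*N scalar arguments: one (lhs, rhs) pair of
// rank-0 values for each of the N sort operands.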
for (int i = 0; i < input_types_.size(); ++i) {
inputs[i] = context->Input(i);
OP_REQUIRES_OK(context, DataTypeToPrimitiveType(input_types_[i],
&input_xla_types[i]));
XlaCompiler::Argument comparator_arg;
comparator_arg.kind = XlaCompiler::Argument::kParameter;
comparator_arg.type = input_types_[i];
comparator_arg.shape = TensorShape();
comparator_args[2 * i] = comparator_arg;
comparator_args[2 * i + 1] = comparator_arg;
}
// Build the comparator function.
XlaCompiler::CompilationResult comparator;
XlaCompiler::CompileOptions compile_options;
compile_options.use_tuple_arg = false;
compile_options.always_return_tuple = false;
compile_options.is_entry_computation = false;
OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
compile_options, *comparator_, comparator_args,
&comparator));
xla::Shape expected_comparator_output_shape;
OP_REQUIRES_OK(context,
TensorShapeToXLAShape(DT_BOOL, TensorShape(),
&expected_comparator_output_shape));
OP_REQUIRES(
context,
xla::ShapeUtil::Compatible(comparator.xla_output_shape,
expected_comparator_output_shape),
errors::InvalidArgument(
"Invalid output shape of XlaReduce reducer. Expected ",
xla::ShapeUtil::HumanString(expected_comparator_output_shape),
" got ", xla::ShapeUtil::HumanString(comparator.xla_output_shape)));
xla::XlaOp outputs =
xla::Sort(inputs, *comparator.computation, dimension_, is_stable_);
for (int i = 0; i < input_types_.size(); ++i) {
xla::XlaOp output_handle =
(input_types_.size() > 1 ? xla::GetTupleElement(outputs, i)
: outputs);
context->SetOutput(i, output_handle);
}
}
private:
DataTypeVector input_types_;
const NameAttrList* comparator_;
int64 dimension_;
bool is_stable_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaVariadicSortOp);
};
REGISTER_XLA_OP(Name("XlaVariadicSort"), XlaVariadicSortOp);
} // namespace
} // namespace tensorflow


@@ -651,6 +651,36 @@ sorted_keys: A `Tensor` of type K.
sorted_values: A `Tensor` of type V.
)doc");
REGISTER_OP("XlaVariadicSort")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 1")
.Attr("comparator: func")
.Attr("dimension: int")
.Attr("is_stable: bool")
.SetShapeFn([](shape_inference::InferenceContext* c) {
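// Sorting permutes values along one dimension, so each output has the same
// shape as the corresponding input.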
for (int i = 0; i < c->num_inputs(); ++i) {
c->set_output(i, c->input(i));
}
return Status::OK();
})
.Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
Sorts one or more tensors, with support for custom comparator, dimension, and
is_stable attributes.
input: A list of `Tensor` of identical shape but possibly different types.
comparator: A comparator function to apply to 2*N scalars, returning a
boolean. N is the number of sort inputs. To sort in ascending order, the
comparator should perform a less-than comparison.
output: A list of `Tensor` of type T.
dimension: The dimension along which to sort.
is_stable: Whether to use stable sort.
)doc");
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")


@@ -473,6 +473,7 @@ def _spmd_shard_to_full_shape_grad(op, grad):
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize