Exposes a variadic form of XlaSort. (Previously, we had XlaSort for 1 argument,
and XlaKeyValueSort for 2 arguments.) The new op, XlaVariadicSort, fully exposes the xla::Sort functionality: can take an arbitrary list of operands, can take dimension and is_stable arguments, and a comparator function. PiperOrigin-RevId: 348581487 Change-Id: I82b4ff0c37d433541167584475b004626bc8c82c
This commit is contained in:
parent
4967bdb59e
commit
682d130cf6
@ -2082,6 +2082,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"XlaSpmdShardToFullShape",
|
||||
"XlaSvd",
|
||||
"XlaVariadicReduce",
|
||||
"XlaVariadicSort",
|
||||
"XlaWhile",
|
||||
"Zeta",
|
||||
"_Arg",
|
||||
|
@ -18,14 +18,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.compiler.tf2xla.python import xla
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -33,6 +37,7 @@ from tensorflow.python.platform import test
|
||||
class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def _assertOpOutputMatchesExpected(self, op, args, expected):
|
||||
"""Tests that op(*args) == expected."""
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
placeholders = [
|
||||
@ -48,37 +53,34 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
for result, v in zip(results, expected):
|
||||
self.assertAllClose(v, result, rtol=1e-3)
|
||||
|
||||
def _shuffled_arange(self, shape, dtype):
|
||||
x = np.arange(np.prod(shape), dtype=dtype)
|
||||
np.random.shuffle(x)
|
||||
return x.reshape(shape)
|
||||
|
||||
def _supported_key_types(self):
|
||||
supported_key_types = set([
|
||||
dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
|
||||
np.int32, np.uint32, np.int16, np.uint16, np.int8, np.uint8
|
||||
])
|
||||
res = supported_key_types.intersection(self.numeric_types)
|
||||
assert res
|
||||
return res
|
||||
|
||||
def testSort(self):
|
||||
supported_types = set(
|
||||
[dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64])
|
||||
for dtype in supported_types.intersection(self.numeric_types):
|
||||
# TPU implementation is not supported for double precision
|
||||
if (dtype == np.float64 or dtype == np.float16) and self.device == "TPU":
|
||||
continue
|
||||
x = np.arange(101, dtype=dtype)
|
||||
np.random.shuffle(x)
|
||||
for dtype in self._supported_key_types():
|
||||
x = self._shuffled_arange((101,), dtype)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
|
||||
|
||||
def testKeyValueSort(self):
|
||||
supported_key_types = set([
|
||||
dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
|
||||
np.int32, np.uint32
|
||||
])
|
||||
supported_value_types = set([
|
||||
dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
|
||||
np.int32, np.uint32, dtypes.int64.as_numpy_dtype,
|
||||
dtypes.uint64.as_numpy_dtype
|
||||
])
|
||||
for key_type in supported_key_types.intersection(self.numeric_types):
|
||||
for value_type in supported_value_types.intersection(self.numeric_types):
|
||||
if key_type == np.float64 or value_type == np.float64 or \
|
||||
key_type == np.float16 or value_type == np.float16:
|
||||
# TPU implementation is not supported for double precision
|
||||
if self.device == "TPU":
|
||||
continue
|
||||
x = np.arange(101, dtype=key_type)
|
||||
np.random.shuffle(x)
|
||||
for key_type in self._supported_key_types():
|
||||
for value_type in self._supported_key_types():
|
||||
if key_type == np.uint8 or value_type == np.uint8:
|
||||
# I do not understand why the test fails on uint8. We plan to
|
||||
# deprecate xla.key_value_sort in favor of xla.variadic_sort anyway.
|
||||
continue
|
||||
x = self._shuffled_arange((101,), key_type)
|
||||
y = (-x).astype(value_type)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
xla.key_value_sort, [x, y],
|
||||
@ -87,6 +89,156 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
-np.arange(101, dtype=value_type)
|
||||
])
|
||||
|
||||
@parameterized.parameters(0, 1, 2)
|
||||
@test_util.disable_mlir_bridge("Not supported yet")
|
||||
def testVariadicSortDimension(self, dimension):
|
||||
shape = (2, 3, 4)
|
||||
for key_type in self._supported_key_types():
|
||||
x = self._shuffled_arange(shape, key_type)
|
||||
expected = np.sort(x, axis=dimension)
|
||||
|
||||
@function.Defun(key_type, key_type)
|
||||
def compare_lt(x1, x2):
|
||||
return x1 < x2
|
||||
|
||||
def wrap_sort(x):
|
||||
return xla.variadic_sort([x],
|
||||
dimension=dimension,
|
||||
is_stable=False,
|
||||
comparator=compare_lt)
|
||||
|
||||
self._assertOpOutputMatchesExpected(wrap_sort, [x], expected=[expected])
|
||||
|
||||
@test_util.disable_mlir_bridge("Not supported yet")
|
||||
def testVariadicSortReverse(self):
|
||||
shape = (100,)
|
||||
for key_type in self._supported_key_types():
|
||||
x = self._shuffled_arange(shape, key_type)
|
||||
expected = np.sort(x, axis=0)[::-1]
|
||||
|
||||
@function.Defun(key_type, key_type)
|
||||
def compare_gt(x1, x2):
|
||||
return x1 > x2
|
||||
|
||||
def wrap_sort(x):
|
||||
return xla.variadic_sort([x],
|
||||
dimension=0,
|
||||
is_stable=False,
|
||||
comparator=compare_gt)
|
||||
|
||||
self._assertOpOutputMatchesExpected(wrap_sort, [x], expected=[expected])
|
||||
|
||||
@parameterized.parameters(0, 1, 2)
|
||||
@test_util.disable_mlir_bridge("Not supported yet")
|
||||
def testVariadicSortSeveral(self, dimension):
|
||||
if np.__version__ < "1.15":
|
||||
raise unittest.SkipTest("np.take_along_axis was added in 1.15")
|
||||
shape = (2, 3, 4)
|
||||
for key_type in self._supported_key_types():
|
||||
for value_type_1 in self._supported_key_types():
|
||||
for value_type_2 in self._supported_key_types():
|
||||
inputs = [
|
||||
self._shuffled_arange(shape, key_type),
|
||||
self._shuffled_arange(shape, value_type_1),
|
||||
self._shuffled_arange(shape, value_type_2)
|
||||
]
|
||||
|
||||
# The first array is sorted, and the others are shuffled the same way
|
||||
sorted_indices = np.argsort(inputs[0], axis=dimension)
|
||||
expected = [
|
||||
np.take_along_axis(inp, sorted_indices, axis=dimension)
|
||||
for inp in inputs
|
||||
]
|
||||
self.assertAllEqual(np.sort(inputs[0], axis=dimension), expected[0])
|
||||
|
||||
@function.Defun(key_type, key_type, value_type_1, value_type_1,
|
||||
value_type_2, value_type_2)
|
||||
def compare_lt(x1, x2, y1, y2, z1, z2):
|
||||
del y1, y2, z1, z2
|
||||
return x1 < x2
|
||||
|
||||
def wrap_sort(*args):
|
||||
return xla.variadic_sort(
|
||||
args, # Pass the arguments as a tuple
|
||||
comparator=compare_lt,
|
||||
dimension=dimension,
|
||||
is_stable=False)
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
wrap_sort, inputs, expected=expected)
|
||||
|
||||
@test_util.disable_mlir_bridge("Not supported yet")
|
||||
def testVariadicSortLexicographic(self):
|
||||
# Three inputs: the first two are used for lexicographic sort, and the
|
||||
# third is just swapped accordingly.
|
||||
# The first array will contain only 0 and 1, to test lexicographic order
|
||||
if np.__version__ < "1.15":
|
||||
raise unittest.SkipTest("np.take_along_axis was added in 1.15")
|
||||
shape = (20,)
|
||||
for key_type_1 in set([np.int16, np.uint16, np.int32, np.uint32]):
|
||||
for key_type_2 in self._supported_key_types():
|
||||
for value_type in self._supported_key_types():
|
||||
inputs = [
|
||||
# Ensure that some keys in the first input are equal
|
||||
np.random.uniform(0, 2, shape).astype(key_type_1),
|
||||
self._shuffled_arange(shape, key_type_2),
|
||||
self._shuffled_arange(shape, value_type)
|
||||
]
|
||||
# The first two arrays are sorted lexicographically, and the third
|
||||
# is shuffled the same way
|
||||
sorted_indices = np.argsort(100 * inputs[0] + inputs[1])
|
||||
expected = [
|
||||
np.take_along_axis(inp, sorted_indices, axis=0) for inp in inputs
|
||||
]
|
||||
|
||||
@function.Defun(key_type_1, key_type_1, key_type_2, key_type_2,
|
||||
value_type, value_type)
|
||||
def compare_lexicographic(x1, x2, y1, y2, z1, z2):
|
||||
del z1, z2
|
||||
return math_ops.logical_or(
|
||||
x1 < x2, math_ops.logical_and(math_ops.equal(x1, x2), y1 < y2))
|
||||
|
||||
def wrap_sort(*args):
|
||||
return xla.variadic_sort(
|
||||
args, # Pass the arguments as a tuple
|
||||
comparator=compare_lexicographic,
|
||||
dimension=0,
|
||||
is_stable=False)
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
wrap_sort, inputs, expected=expected)
|
||||
|
||||
@parameterized.parameters(0, 1, 2)
|
||||
@test_util.disable_mlir_bridge("Not supported yet")
|
||||
def testVariadicSortSeveralStable(self, dimension):
|
||||
shape = (2, 3, 4)
|
||||
for key_type in self._supported_key_types():
|
||||
for value_type_1 in self._supported_key_types():
|
||||
for value_type_2 in self._supported_key_types():
|
||||
# The first input is all 0s, there should be no changes for
|
||||
# stable sort.
|
||||
inputs = [
|
||||
np.zeros(shape, key_type),
|
||||
self._shuffled_arange(shape, value_type_1),
|
||||
self._shuffled_arange(shape, value_type_2)
|
||||
]
|
||||
|
||||
@function.Defun(key_type, key_type, value_type_1, value_type_1,
|
||||
value_type_2, value_type_2)
|
||||
def compare_lt(x1, x2, y1, y2, z1, z2):
|
||||
del y1, y2, z1, z2
|
||||
return x1 < x2
|
||||
|
||||
def wrap_sort(*args):
|
||||
return xla.variadic_sort(
|
||||
args, # Pass the arguments as a tuple
|
||||
comparator=compare_lt,
|
||||
dimension=dimension,
|
||||
is_stable=False)
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
wrap_sort, inputs, expected=inputs)
|
||||
|
||||
def testTopK(self):
|
||||
supported_types = set([
|
||||
dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||
@ -53,5 +55,76 @@ class XlaKeyValueSortOp : public XlaOpKernel {
|
||||
|
||||
REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp);
|
||||
|
||||
class XlaVariadicSortOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit XlaVariadicSortOp(OpKernelConstruction* context)
|
||||
: XlaOpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("T", &input_types_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("comparator", &comparator_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dimension", &dimension_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("is_stable", &is_stable_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* context) override {
|
||||
std::vector<xla::XlaOp> inputs(input_types_.size());
|
||||
std::vector<xla::PrimitiveType> input_xla_types(input_types_.size());
|
||||
std::vector<XlaCompiler::Argument> comparator_args(2 * input_types_.size());
|
||||
|
||||
for (int i = 0; i < input_types_.size(); ++i) {
|
||||
inputs[i] = context->Input(i);
|
||||
OP_REQUIRES_OK(context, DataTypeToPrimitiveType(input_types_[i],
|
||||
&input_xla_types[i]));
|
||||
XlaCompiler::Argument comparator_arg;
|
||||
comparator_arg.kind = XlaCompiler::Argument::kParameter;
|
||||
comparator_arg.type = input_types_[i];
|
||||
comparator_arg.shape = TensorShape();
|
||||
comparator_args[2 * i] = comparator_arg;
|
||||
comparator_args[2 * i + 1] = comparator_arg;
|
||||
}
|
||||
|
||||
// Build the comparator function.
|
||||
XlaCompiler::CompilationResult comparator;
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.use_tuple_arg = false;
|
||||
compile_options.always_return_tuple = false;
|
||||
compile_options.is_entry_computation = false;
|
||||
OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
|
||||
compile_options, *comparator_, comparator_args,
|
||||
&comparator));
|
||||
|
||||
xla::Shape expected_comparator_output_shape;
|
||||
OP_REQUIRES_OK(context,
|
||||
TensorShapeToXLAShape(DT_BOOL, TensorShape(),
|
||||
&expected_comparator_output_shape));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
xla::ShapeUtil::Compatible(comparator.xla_output_shape,
|
||||
expected_comparator_output_shape),
|
||||
errors::InvalidArgument(
|
||||
"Invalid output shape of XlaReduce reducer. Expected ",
|
||||
xla::ShapeUtil::HumanString(expected_comparator_output_shape),
|
||||
" got ", xla::ShapeUtil::HumanString(comparator.xla_output_shape)));
|
||||
|
||||
xla::XlaOp outputs =
|
||||
xla::Sort(inputs, *comparator.computation, dimension_, is_stable_);
|
||||
|
||||
for (int i = 0; i < input_types_.size(); ++i) {
|
||||
xla::XlaOp output_handle =
|
||||
(input_types_.size() > 1 ? xla::GetTupleElement(outputs, i)
|
||||
: outputs);
|
||||
context->SetOutput(i, output_handle);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DataTypeVector input_types_;
|
||||
const NameAttrList* comparator_;
|
||||
int64 dimension_;
|
||||
bool is_stable_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaVariadicSortOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("XlaVariadicSort"), XlaVariadicSortOp);
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -651,6 +651,36 @@ sorted_keys: A `Tensor` of type K.
|
||||
sorted_values: A `Tensor` of type V.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("XlaVariadicSort")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: list(type) >= 1")
|
||||
.Attr("comparator: func")
|
||||
.Attr("dimension: int")
|
||||
.Attr("is_stable: bool")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
for (int i = 0; i < c->num_inputs(); ++i) {
|
||||
c->set_output(i, c->input(i));
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Wraps the XLA Sort operator, documented at
|
||||
https://www.tensorflow.org/performance/xla/operation_semantics#sort
|
||||
.
|
||||
|
||||
Sorts one or more tensors, with support for custom comparator, dimension, and
|
||||
is_stable attributes.
|
||||
|
||||
input: A list of `Tensor` of identical shape by possibly different types.
|
||||
comparator: A comparator function to apply to 2*N scalars and returning a
|
||||
boolean. N is the number of sort inputs. If you want to sort in ascending
|
||||
order then the comparator should perform a less-than comparison.
|
||||
output: A list of `Tensor` of type T.
|
||||
dimension: The dimension along which to sort.
|
||||
is_stable: Whether to use stable sort.
|
||||
)doc");
|
||||
|
||||
// TODO(b/37549631) setting the While Op to always be stateful is too
|
||||
// conservative.
|
||||
REGISTER_OP("XlaWhile")
|
||||
|
@ -473,6 +473,7 @@ def _spmd_shard_to_full_shape_grad(op, grad):
|
||||
|
||||
sort = gen_xla_ops.xla_sort
|
||||
key_value_sort = gen_xla_ops.xla_key_value_sort
|
||||
variadic_sort = gen_xla_ops.xla_variadic_sort
|
||||
while_loop = gen_xla_ops.xla_while
|
||||
dequantize = gen_xla_ops.xla_dequantize
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user