diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 317e29d4a84..ada7766fcbb 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2066,6 +2066,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
             "XlaSpmdFullToShardShape",
             "XlaSpmdShardToFullShape",
             "XlaSvd",
+            "XlaVariadicReduce",
             "XlaWhile",
             "Zeta",
             "_Arg",
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 3e9f5e8c5dd..b80b6263992 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -18,12 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.compiler.tf2xla.python import xla
 from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
@@ -299,6 +302,78 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
         expected=np.array([0, 45, 120, 231], dtype=dtype))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
+  def testVariadicReduce(self):
+    for dtype in set(self.numeric_types).intersection(
+        set([np.float32, np.complex64])):
+
+      @def_function.function
+      def kahan_sum_reducer(t0, t1):
+        (s0, c0), (s1, c1) = t0, t1
+        s0minusc = s0 - (c0 + c1)
+        t = s1 + s0minusc
+        c = (t - s1) - s0minusc
+        s = t
+        return s, c
+
+      def kahan_sum_reduction(dims, output_idx):
+
+        def fn(x):
+          arg = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
+          reducer = kahan_sum_reducer.get_concrete_function(
+              (arg, arg), (arg, arg))
+
+          return xla.variadic_reduce(
+              (x, array_ops.zeros_like(x)),
+              init_value=(arg, arg),
+              dimensions_to_reduce=dims,
+              reducer=reducer)[output_idx]
+
+        return fn
+
+      xs = np.array([1e5, np.pi, -1e5, np.exp(1.)])
+      xs = np.array([xs, xs[::-1] / 3, xs / 7], dtype)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[], output_idx=0),
+          args=(xs,), expected=xs)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[], output_idx=1),
+          args=(xs,), expected=np.zeros_like(xs))
+      shuffle_indices = np.argsort(np.random.randn(xs.shape[0]))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0], output_idx=0),
+          args=(xs[shuffle_indices],),
+          expected=np.array([np.exp(1) / 3 + 1e5 * 8 / 7,
+                             np.pi * 8 / 7 - 1e5 / 3,
+                             -1e5 * 8 / 7 + np.pi / 3,
+                             np.exp(1) * 8 / 7 + 1e5 / 3], dtype=dtype))
+      error_term_equality = functools.partial(self.assertAllClose, atol=.005)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0], output_idx=1),
+          args=(xs[shuffle_indices],), expected=np.zeros_like(xs[0]),
+          equality_fn=error_term_equality)
+      shuffle_indices = np.argsort(np.random.randn(xs.shape[1]))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[1], output_idx=0),
+          args=(xs[:, shuffle_indices],),
+          expected=np.array([np.pi + np.exp(1.),
+                             (np.pi + np.exp(1.)) / 3,
+                             (np.pi + np.exp(1.)) / 7], dtype=dtype))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[1], output_idx=1),
+          args=(xs[:, shuffle_indices],),
+          expected=np.zeros_like(xs[:, 0]),
+          equality_fn=error_term_equality)
+      # Now, shuffle both dims.
+      xs = xs[np.argsort(np.random.randn(xs.shape[0]))]
+      xs = xs[:, np.argsort(np.random.randn(xs.shape[1]))]
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0, 1], output_idx=0),
+          args=(xs,), expected=dtype((np.pi + np.exp(1.)) * 31 / 21))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0, 1], output_idx=1),
+          args=(xs,), expected=dtype(0),
+          equality_fn=error_term_equality)
+
   @test_util.disable_mlir_bridge('Not supported yet')
   def testSelectAndScatter(self):
     for dtype in set(self.numeric_types).intersection(
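A quick numeric check of the constants the new test asserts (a reader's sketch, not part of the patch): the 1e5 terms cancel within each row of xs, so the full reduction should come to (pi + e) * (1 + 1/3 + 1/7) = (pi + e) * 31/21, and in float32 a naive left-to-right sum loses low-order bits next to the 1e5 terms, which is exactly the rounding error the Kahan reduction surfaces through its second output (output_idx=1).

import numpy as np

row = np.array([1e5, np.pi, -1e5, np.exp(1.)])
xs = np.array([row, row[::-1] / 3, row / 7])
# The large terms cancel row-wise, leaving (pi + e) * (1 + 1/3 + 1/7).
assert np.isclose(xs.sum(), (np.pi + np.exp(1.)) * 31 / 21)

# Naive float32 accumulation of one row shows the error that the
# compensated (Kahan) reduction tracks in its second output.
naive = np.float32(0)
for v in row.astype(np.float32):
  naive = np.float32(naive + v)
print(abs(naive - np.float32(np.pi + np.exp(1.))))  # ~1e-3, not 0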
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
index 8b481d55a80..555905ebe6b 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace {
@@ -38,16 +39,28 @@ class XlaReduceOp : public XlaOpKernel {
         context, dims_set.size() == dimensions_to_reduce_.size(),
         errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
                                 "argument to XlaReduce"));
+    if (context->HasAttr("N")) {  // variadic reduce
+      use_tuples_ = true;
+      OP_REQUIRES_OK(context, context->GetAttr("N", &n_));
+    } else {
+      use_tuples_ = false;
+      n_ = 1;
+    }
   }
 
   void Compile(XlaOpKernelContext* context) override {
-    const TensorShape input_shape = context->InputShape("input");
-    const TensorShape init_value_shape = context->InputShape("init_value");
+    OP_REQUIRES(context, n_ * 2 == context->num_inputs(),
+                errors::InvalidArgument("Expected ", n_ * 2, " inputs but got ",
+                                        context->num_inputs()));
+
+    const TensorShape input_shape = context->InputShape(0);
+    const TensorShape init_value_shape = context->InputShape(n_);
     const DataType dtype = context->input_type(0);
     const int rank = input_shape.dims();
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
-                errors::InvalidArgument("init_value must be a scalar"));
+                errors::InvalidArgument("init_value must be a scalar but got ",
+                                        init_value_shape.DebugString()));
 
     auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
     OP_REQUIRES(context,
@@ -67,35 +80,58 @@ class XlaReduceOp : public XlaOpKernel {
     compile_options.always_return_tuple = false;
     compile_options.is_entry_computation = false;
     XlaCompiler::CompilationResult reducer;
-    OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
-                                compile_options, *reducer_,
-                                {reducer_arg, reducer_arg}, &reducer));
+    OP_REQUIRES_OK(
+        context,
+        context->compiler()->CompileFunction(
+            compile_options, *reducer_,
+            std::vector<XlaCompiler::Argument>(n_ * 2, reducer_arg),
+            &reducer));
 
-    xla::Shape scalar_shape;
-    OP_REQUIRES_OK(context,
-                   TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+    xla::Shape expected_shape;
+    OP_REQUIRES_OK(
+        context, TensorShapeToXLAShape(dtype, TensorShape(), &expected_shape));
+    if (use_tuples_) {
+      expected_shape = xla::ShapeUtil::MakeTupleShape(
+          std::vector<xla::Shape>(n_, expected_shape));
+    }
     OP_REQUIRES(
         context,
-        xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+        xla::ShapeUtil::Compatible(reducer.xla_output_shape, expected_shape),
         errors::InvalidArgument(
             "Invalid output shape of XlaReduce reducer. Expected ",
-            xla::ShapeUtil::HumanString(scalar_shape), " got ",
+            xla::ShapeUtil::HumanString(expected_shape), " got ",
             xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
 
+    std::vector<xla::XlaOp> inputs;
+    std::vector<xla::XlaOp> inits;
+    inputs.reserve(n_);
+    inits.reserve(n_);
+    for (int i = 0; i < n_; i++) {
+      inputs.emplace_back(context->Input(i));
+      inits.emplace_back(context->Input(n_ + i));
+    }
     xla::XlaOp output =
-        xla::Reduce(context->Input("input"), context->Input("init_value"),
-                    *reducer.computation, dimensions_to_reduce_);
-    context->SetOutput(0, output);
+        xla::Reduce(context->builder(), inputs, inits, *reducer.computation,
+                    dimensions_to_reduce_);
+    if (use_tuples_) {
+      for (int i = 0; i < n_; i++) {
+        context->SetOutput(i, xla::GetTupleElement(output, i));
+      }
+    } else {
+      context->SetOutput(0, output);
+    }
   }
 
  private:
   const NameAttrList* reducer_;
   std::vector<int64> dimensions_to_reduce_;
+  bool use_tuples_;
+  int n_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
 };
 
 REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
+REGISTER_XLA_OP(Name("XlaVariadicReduce"), XlaReduceOp);
 
 }  // namespace
 }  // namespace tensorflow
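The kernel above now serves both ops: XlaReduce keeps n_ == 1 with use_tuples_ == false, while XlaVariadicReduce packs its 2N inputs as operands 0..N-1 followed by init values N..2N-1. A minimal sketch of the unchanged single-operand path, assuming the same imports and XLA-compiled harness as xla_ops_test.py (row_sums is a hypothetical helper, not from this patch):

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.eager import def_function
from tensorflow.python.ops import array_ops

@def_function.function
def add_reducer(t0, t1):
  return t0 + t1

def row_sums(x):  # x: a float32 [m, n] tensor -> [m]
  zero = array_ops.zeros([], x.dtype)
  reducer = add_reducer.get_concrete_function(zero, zero)
  # Hits the n_ == 1 branch: one operand, one init value, plain output.
  return xla.reduce(x, init_value=zero, dimensions_to_reduce=[1],
                    reducer=reducer)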
Expected ", - xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(expected_shape), " got ", xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + std::vector inputs; + std::vector inits; + inputs.reserve(n_); + inits.reserve(n_); + for (int i = 0; i < n_; i++) { + inputs.emplace_back(context->Input(i)); + inits.emplace_back(context->Input(n_ + i)); + } xla::XlaOp output = - xla::Reduce(context->Input("input"), context->Input("init_value"), - *reducer.computation, dimensions_to_reduce_); - context->SetOutput(0, output); + xla::Reduce(context->builder(), inputs, inits, *reducer.computation, + dimensions_to_reduce_); + if (use_tuples_) { + for (int i = 0; i < n_; i++) { + context->SetOutput(i, xla::GetTupleElement(output, i)); + } + } else { + context->SetOutput(0, output); + } } private: const NameAttrList* reducer_; std::vector dimensions_to_reduce_; + bool use_tuples_; + int n_; TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp); }; REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp); +REGISTER_XLA_OP(Name("XlaVariadicReduce"), XlaReduceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index f73d2b109a1..00cbe7bc9bc 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -475,6 +475,60 @@ reducer: a reducer function to apply dimensions_to_reduce: dimension numbers over which to reduce )doc"); +REGISTER_OP("XlaVariadicReduce") + .Input("input: N * T") + .Input("init_value: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .Attr("dimensions_to_reduce: list(int)") + .Attr("reducer: func") + .Output("output: N * T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + c->MergeInput(i, c->input(j)); + } + } + if (c->RankKnown(c->input(0))) { + int rank = c->Rank(c->input(0)); + std::vector dimensions_to_reduce; + TF_RETURN_IF_ERROR( + c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce)); + std::set dims_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + auto dim_in_range = [rank](int64 dim) { + return dim >= 0 && dim < rank; + }; + const int dimensions_to_reduce_size = dimensions_to_reduce.size(); + if (rank < dimensions_to_reduce_size || + dims_set.size() != dimensions_to_reduce.size() || + !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { + return errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaVariadicReduce"); + } + for (int i = 0; i < n; i++) { + c->set_output( + i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size())); + } + } else { + for (int i = 0; i < n; i++) { + c->set_output(i, c->input(i)); + } + } + return Status::OK(); + }) + .Doc(R"doc( +Wraps the variadic XLA Reduce operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. 
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 6278ea1a3a9..2e5667bc02f 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -340,6 +340,7 @@ def random_uniform(minval, maxval, dims, name=None):
 
 recv = gen_xla_ops.xla_recv
 reduce = gen_xla_ops.xla_reduce
+variadic_reduce = gen_xla_ops.xla_variadic_reduce
 
 
 def reduce_window(operand,
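Finally, a hedged usage sketch of the new binding (illustrative, not from this patch): a single-pass max-and-argmax, a textbook use of variadic reduce. The N * T signature above constrains every operand to one dtype, so positions are carried as floats here, and XLA's reduction order is unspecified, so ties are only deterministic for distinct values. max_and_argmax is a hypothetical helper, assumed to run inside an XLA-compiled function on a rank-1 float32 tensor:

import numpy as np

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.eager import def_function
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

@def_function.function
def argmax_reducer(t0, t1):
  (v0, i0), (v1, i1) = t0, t1
  keep_left = v0 >= v1  # on ties, keeps the left operand
  return (array_ops.where_v2(keep_left, v0, v1),
          array_ops.where_v2(keep_left, i0, i1))

def max_and_argmax(x):  # x: a rank-1 float32 tensor
  scalar = array_ops.zeros([], x.dtype)
  reducer = argmax_reducer.get_concrete_function(
      (scalar, scalar), (scalar, scalar))
  positions = math_ops.cast(math_ops.range(array_ops.size(x)), x.dtype)
  # Init pair: -inf is the identity for max; -1 marks the empty case.
  return xla.variadic_reduce(
      (x, positions),
      init_value=(scalar - np.inf, scalar - 1.),
      dimensions_to_reduce=[0],
      reducer=reducer)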