Exposes variadic reduce to TF python via xla.py
PiperOrigin-RevId: 336936774
Change-Id: Iee8bf17a18594abe2a301802eade9e7c4f4c5e34
parent 9a04369ca8
commit 4e330bfbd1
Changed paths under tensorflow/compiler: jit, tests, tf2xla
@@ -2066,6 +2066,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
       "XlaSpmdFullToShardShape",
       "XlaSpmdShardToFullShape",
       "XlaSvd",
+      "XlaVariadicReduce",
       "XlaWhile",
       "Zeta",
       "_Arg",
@@ -18,12 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.compiler.tf2xla.python import xla
 from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
@@ -299,6 +302,78 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
           args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
           expected=np.array([0, 45, 120, 231], dtype=dtype))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
+  def testVariadicReduce(self):
+    for dtype in set(self.numeric_types).intersection(
+        set([np.float32, np.complex64])):
+
+      @def_function.function
+      def kahan_sum_reducer(t0, t1):
+        (s0, c0), (s1, c1) = t0, t1
+        s0minusc = s0 - (c0 + c1)
+        t = s1 + s0minusc
+        c = (t - s1) - s0minusc
+        s = t
+        return s, c
+
+      def kahan_sum_reduction(dims, output_idx):
+
+        def fn(x):
+          arg = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
+          reducer = kahan_sum_reducer.get_concrete_function(
+              (arg, arg), (arg, arg))
+
+          return xla.variadic_reduce(
+              (x, array_ops.zeros_like(x)),
+              init_value=(arg, arg),
+              dimensions_to_reduce=dims,
+              reducer=reducer)[output_idx]
+
+        return fn
+
+      xs = np.array([1e5, np.pi, -1e5, np.exp(1.)])
+      xs = np.array([xs, xs[::-1] / 3, xs / 7], dtype)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[], output_idx=0),
+          args=(xs,), expected=xs)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[], output_idx=1),
+          args=(xs,), expected=np.zeros_like(xs))
+      shuffle_indices = np.argsort(np.random.randn(xs.shape[0]))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0], output_idx=0),
+          args=(xs[shuffle_indices],),
+          expected=np.array([np.exp(1) / 3 + 1e5 * 8 / 7,
+                             np.pi * 8 / 7 - 1e5 / 3,
+                             -1e5 * 8 / 7 + np.pi / 3,
+                             np.exp(1) * 8 / 7 + 1e5 / 3], dtype=dtype))
+      error_term_equality = functools.partial(self.assertAllClose, atol=.005)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0], output_idx=1),
+          args=(xs[shuffle_indices],), expected=np.zeros_like(xs[0]),
+          equality_fn=error_term_equality)
+      shuffle_indices = np.argsort(np.random.randn(xs.shape[1]))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[1], output_idx=0),
+          args=(xs[:, shuffle_indices],),
+          expected=np.array([np.pi + np.exp(1.),
+                             (np.pi + np.exp(1.)) / 3,
+                             (np.pi + np.exp(1.)) / 7], dtype=dtype))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[1], output_idx=1),
+          args=(xs[:, shuffle_indices],), expected=np.zeros_like(xs[:, 0]),
+          equality_fn=error_term_equality)
+      # Now, shuffle both dims.
+      xs = xs[np.argsort(np.random.randn(xs.shape[0]))]
+      xs = xs[:, np.argsort(np.random.randn(xs.shape[1]))]
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0, 1], output_idx=0),
+          args=(xs,), expected=dtype((np.pi + np.exp(1.)) * 31 / 21))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0, 1], output_idx=1),
+          args=(xs,), expected=dtype(0),
+          equality_fn=error_term_equality)
+
   @test_util.disable_mlir_bridge('Not supported yet')
   def testSelectAndScatter(self):
     for dtype in set(self.numeric_types).intersection(
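Note on the algorithm in testVariadicReduce: the reducer threads a (sum, compensation) pair through the reduction, which is Kahan (compensated) summation; the shuffles above check that the compensated sum is insensitive to reduction order. A minimal sequential NumPy sketch of the same recurrence, for intuition only (the helper name is illustrative, not part of this change):

    import numpy as np

    def kahan_sum(xs):
      s, c = np.float32(0), np.float32(0)
      for x in xs:
        y = x - c        # re-inject low-order bits lost on earlier adds
        t = s + y
        c = (t - s) - y  # (t - s) is what was actually added; c is the shortfall
        s = t
      return s, c        # the two values the variadic reducer carries

    xs = np.float32([1e8] + [1.0] * 64)
    # A naive float32 sum absorbs the 1.0 terms (result 1e8); the compensated
    # estimate s - c recovers them (~1e8 + 64).
    print(np.float32(sum(xs)), kahan_sum(xs)[0])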
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace {
@@ -38,16 +39,28 @@ class XlaReduceOp : public XlaOpKernel {
         context, dims_set.size() == dimensions_to_reduce_.size(),
         errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
                                 "argument to XlaReduce"));
+    if (context->HasAttr("N")) {  // variadic reduce
+      use_tuples_ = true;
+      OP_REQUIRES_OK(context, context->GetAttr("N", &n_));
+    } else {
+      use_tuples_ = false;
+      n_ = 1;
+    }
   }
 
   void Compile(XlaOpKernelContext* context) override {
-    const TensorShape input_shape = context->InputShape("input");
-    const TensorShape init_value_shape = context->InputShape("init_value");
+    OP_REQUIRES(context, n_ * 2 == context->num_inputs(),
+                errors::InvalidArgument("Expected ", n_ * 2, " inputs but got ",
+                                        context->num_inputs()));
+
+    const TensorShape input_shape = context->InputShape(0);
+    const TensorShape init_value_shape = context->InputShape(n_);
     const DataType dtype = context->input_type(0);
 
     const int rank = input_shape.dims();
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
-                errors::InvalidArgument("init_value must be a scalar"));
+                errors::InvalidArgument("init_value must be a scalar but got ",
+                                        init_value_shape.DebugString()));
 
     auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
     OP_REQUIRES(context,
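Reviewer note on the calling convention checked above: the variadic op flattens its 2N tensors as the N operands followed by their N scalar init values, which is why the init shape is read from InputShape(n_). A hedged Python mirror of that indexing (the helper name is illustrative only):

    def split_variadic_inputs(flat_inputs, n):
      # Kernel contract: n_ * 2 == context->num_inputs().
      assert len(flat_inputs) == 2 * n
      # Input(i) yields the operands, Input(n_ + i) the init values.
      return flat_inputs[:n], flat_inputs[n:]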
@@ -67,35 +80,58 @@ class XlaReduceOp : public XlaOpKernel {
     compile_options.always_return_tuple = false;
     compile_options.is_entry_computation = false;
     XlaCompiler::CompilationResult reducer;
-    OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
-                                compile_options, *reducer_,
-                                {reducer_arg, reducer_arg}, &reducer));
+    OP_REQUIRES_OK(
+        context,
+        context->compiler()->CompileFunction(
+            compile_options, *reducer_,
+            std::vector<XlaCompiler::Argument>(n_ * 2, reducer_arg), &reducer));
 
-    xla::Shape scalar_shape;
-    OP_REQUIRES_OK(context,
-                   TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+    xla::Shape expected_shape;
+    OP_REQUIRES_OK(
+        context, TensorShapeToXLAShape(dtype, TensorShape(), &expected_shape));
+    if (use_tuples_) {
+      expected_shape = xla::ShapeUtil::MakeTupleShape(
+          std::vector<xla::Shape>(n_, expected_shape));
+    }
     OP_REQUIRES(
         context,
-        xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+        xla::ShapeUtil::Compatible(reducer.xla_output_shape, expected_shape),
         errors::InvalidArgument(
             "Invalid output shape of XlaReduce reducer. Expected ",
-            xla::ShapeUtil::HumanString(scalar_shape), " got ",
+            xla::ShapeUtil::HumanString(expected_shape), " got ",
             xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
 
+    std::vector<xla::XlaOp> inputs;
+    std::vector<xla::XlaOp> inits;
+    inputs.reserve(n_);
+    inits.reserve(n_);
+    for (int i = 0; i < n_; i++) {
+      inputs.emplace_back(context->Input(i));
+      inits.emplace_back(context->Input(n_ + i));
+    }
     xla::XlaOp output =
-        xla::Reduce(context->Input("input"), context->Input("init_value"),
-                    *reducer.computation, dimensions_to_reduce_);
-    context->SetOutput(0, output);
+        xla::Reduce(context->builder(), inputs, inits, *reducer.computation,
+                    dimensions_to_reduce_);
+    if (use_tuples_) {
+      for (int i = 0; i < n_; i++) {
+        context->SetOutput(i, xla::GetTupleElement(output, i));
+      }
+    } else {
+      context->SetOutput(0, output);
+    }
   }
 
  private:
   const NameAttrList* reducer_;
   std::vector<int64> dimensions_to_reduce_;
+  bool use_tuples_;
+  int n_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
 };
 
 REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
+REGISTER_XLA_OP(Name("XlaVariadicReduce"), XlaReduceOp);
 
 }  // namespace
 }  // namespace tensorflow
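The expected-shape check above pins down the reducer contract: it is compiled with 2N scalar arguments and, in the variadic case, must yield an N-tuple of scalars (always_return_tuple stays false, so plain XlaReduce keeps a bare scalar result). A hedged sketch of a conforming N = 2 reducer, mirroring the test's idiom (names here are illustrative):

    import numpy as np
    from tensorflow.python.eager import def_function
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import math_ops

    @def_function.function
    def sum_and_max(t0, t1):
      (s0, m0), (s1, m1) = t0, t1               # 2N = 4 scalar arguments
      return s0 + s1, math_ops.maximum(m0, m1)  # N = 2 scalar results

    arg = array_ops.zeros([], np.float32)
    reducer = sum_and_max.get_concrete_function((arg, arg), (arg, arg))
    # reducer's XLA output shape should be (f32[], f32[]), matching the
    # MakeTupleShape of N scalar shapes the kernel builds above.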
@@ -475,6 +475,60 @@ reducer: a reducer function to apply
 dimensions_to_reduce: dimension numbers over which to reduce
 )doc");
 
+REGISTER_OP("XlaVariadicReduce")
+    .Input("input: N * T")
+    .Input("init_value: N * T")
+    .Attr("N: int >= 1")
+    .Attr("T: numbertype")
+    .Attr("dimensions_to_reduce: list(int)")
+    .Attr("reducer: func")
+    .Output("output: N * T")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      int n;
+      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+      for (int i = 0; i < n; i++) {
+        for (int j = 0; j < n; j++) {
+          c->MergeInput(i, c->input(j));
+        }
+      }
+      if (c->RankKnown(c->input(0))) {
+        int rank = c->Rank(c->input(0));
+        std::vector<int64> dimensions_to_reduce;
+        TF_RETURN_IF_ERROR(
+            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
+        std::set<int64> dims_set(dimensions_to_reduce.begin(),
+                                 dimensions_to_reduce.end());
+        auto dim_in_range = [rank](int64 dim) {
+          return dim >= 0 && dim < rank;
+        };
+        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
+        if (rank < dimensions_to_reduce_size ||
+            dims_set.size() != dimensions_to_reduce.size() ||
+            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
+          return errors::InvalidArgument(
+              "Invalid dimensions_to_reduce argument to XlaVariadicReduce");
+        }
+        for (int i = 0; i < n; i++) {
+          c->set_output(
+              i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
+        }
+      } else {
+        for (int i = 0; i < n; i++) {
+          c->set_output(i, c->input(i));
+        }
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Wraps the variadic XLA Reduce operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.
+
+input: the input tensor(s)
+init_value: scalar initial value(s) for the reduction
+reducer: a reducer function to apply
+dimensions_to_reduce: dimension numbers over which to reduce
+)doc");
+
 REGISTER_OP("XlaReduceWindow")
     .Input("input: T")
     .Input("init_value: T")
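The shape function gives every output the same statically-known rank: the input rank minus the number of distinct, in-range reduced dimensions (dimension sizes stay unknown). A hedged Python mirror of that arithmetic (the function name is illustrative):

    def variadic_reduce_output_rank(input_rank, dimensions_to_reduce):
      dims = set(dimensions_to_reduce)
      # Same validity conditions as the shape function above:
      # no duplicates, every dimension in [0, rank).
      if len(dims) != len(dimensions_to_reduce) or not all(
          0 <= d < input_rank for d in dims):
        raise ValueError('Invalid dimensions_to_reduce argument')
      return input_rank - len(dims)

    assert variadic_reduce_output_rank(2, [0]) == 1     # [3, 4] -> rank-1 outputs
    assert variadic_reduce_output_rank(3, [0, 2]) == 1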
@@ -340,6 +340,7 @@ def random_uniform(minval, maxval, dims, name=None):
 
 recv = gen_xla_ops.xla_recv
 reduce = gen_xla_ops.xla_reduce
+variadic_reduce = gen_xla_ops.xla_variadic_reduce
 
 
 def reduce_window(operand,
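With this alias in place the op is reachable as xla.variadic_reduce. A hedged end-to-end sketch patterned on the new test, fusing a sum and a product into one reduction (N = 2); like the test, it has to run under XLA compilation, since the op has no plain-TF kernel:

    import numpy as np
    from tensorflow.compiler.tf2xla.python import xla
    from tensorflow.python.eager import def_function
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import array_ops

    @def_function.function
    def sum_prod_reducer(t0, t1):
      (s0, p0), (s1, p1) = t0, t1
      return s0 + s1, p0 * p1

    zero = array_ops.zeros([], np.float32)
    one = array_ops.ones([], np.float32)
    reducer = sum_prod_reducer.get_concrete_function((zero, zero), (zero, zero))

    x = constant_op.constant(np.arange(1., 7., dtype=np.float32).reshape(2, 3))
    total, product = xla.variadic_reduce(
        (x, x),                   # N = 2 operands of the same dtype
        init_value=(zero, one),   # one scalar init value per operand
        dimensions_to_reduce=[0, 1],
        reducer=reducer)
    # Expected when compiled via XLA: total == 21.0, product == 720.0.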