Exposes variadic reduce to TF python via xla.py

PiperOrigin-RevId: 336936774
Change-Id: Iee8bf17a18594abe2a301802eade9e7c4f4c5e34
Brian Patton 2020-10-13 12:56:14 -07:00 committed by TensorFlower Gardener
parent 9a04369ca8
commit 4e330bfbd1
5 changed files with 181 additions and 14 deletions
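The new Python entry point follows the pattern of the test added below: trace the reducer to a concrete function, then pass a tuple of operands together with a matching tuple of scalar initial values. A minimal sketch, assuming two same-dtype operands and helper names of our own (paired_sum_reducer, paired_sum, not from this commit); like the test, it is expected to run under an XLA compilation scope:

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.eager import def_function
from tensorflow.python.ops import array_ops

@def_function.function
def paired_sum_reducer(t0, t1):
  # A variadic reducer maps two N-tuples of scalars to one N-tuple;
  # here N = 2 and the two streams are summed independently.
  (a0, b0), (a1, b1) = t0, t1
  return a0 + a1, b0 + b1

def paired_sum(x, y, dims):
  init = array_ops.zeros([], x.dtype)
  reducer = paired_sum_reducer.get_concrete_function(
      (init, init), (init, init))
  return xla.variadic_reduce((x, y), init_value=(init, init),
                             dimensions_to_reduce=dims, reducer=reducer)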

View File

@@ -2066,6 +2066,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaVariadicReduce",
"XlaWhile",
"Zeta",
"_Arg",

View File

@@ -18,12 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@@ -299,6 +302,78 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
        args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
        expected=np.array([0, 45, 120, 231], dtype=dtype))

  @test_util.disable_mlir_bridge('Not supported yet')
  def testVariadicReduce(self):
    for dtype in set(self.numeric_types).intersection(
        set([np.float32, np.complex64])):

      # Kahan (compensated) summation: each operand pair carries a running
      # sum s and an error term c; a plain-NumPy sketch of the sequential
      # form appears at the end of this file's diff.
      @def_function.function
      def kahan_sum_reducer(t0, t1):
        (s0, c0), (s1, c1) = t0, t1
        s0minusc = s0 - (c0 + c1)
        t = s1 + s0minusc
        c = (t - s1) - s0minusc
        s = t
        return s, c

      def kahan_sum_reduction(dims, output_idx):

        def fn(x):
          arg = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
          reducer = kahan_sum_reducer.get_concrete_function(
              (arg, arg), (arg, arg))
          return xla.variadic_reduce(
              (x, array_ops.zeros_like(x)),
              init_value=(arg, arg),
              dimensions_to_reduce=dims,
              reducer=reducer)[output_idx]

        return fn

      xs = np.array([1e5, np.pi, -1e5, np.exp(1.)])
      xs = np.array([xs, xs[::-1] / 3, xs / 7], dtype)
      self._assertOpOutputMatchesExpected(
          kahan_sum_reduction(dims=[], output_idx=0),
          args=(xs,), expected=xs)
      self._assertOpOutputMatchesExpected(
          kahan_sum_reduction(dims=[], output_idx=1),
          args=(xs,), expected=np.zeros_like(xs))
      shuffle_indices = np.argsort(np.random.randn(xs.shape[0]))
      self._assertOpOutputMatchesExpected(
          kahan_sum_reduction(dims=[0], output_idx=0),
          args=(xs[shuffle_indices],),
          expected=np.array([np.exp(1) / 3 + 1e5 * 8 / 7,
                             np.pi * 8 / 7 - 1e5 / 3,
                             -1e5 * 8 / 7 + np.pi / 3,
                             np.exp(1) * 8 / 7 + 1e5 / 3], dtype=dtype))
      error_term_equality = functools.partial(self.assertAllClose, atol=.005)
      self._assertOpOutputMatchesExpected(
          kahan_sum_reduction(dims=[0], output_idx=1),
          args=(xs[shuffle_indices],), expected=np.zeros_like(xs[0]),
          equality_fn=error_term_equality)
      shuffle_indices = np.argsort(np.random.randn(xs.shape[1]))
      self._assertOpOutputMatchesExpected(
          kahan_sum_reduction(dims=[1], output_idx=0),
          args=(xs[:, shuffle_indices],),
          expected=np.array([np.pi + np.exp(1.),
                             (np.pi + np.exp(1.)) / 3,
                             (np.pi + np.exp(1.)) / 7], dtype=dtype))
      self._assertOpOutputMatchesExpected(
          kahan_sum_reduction(dims=[1], output_idx=1),
          args=(xs[:, shuffle_indices],), expected=np.zeros_like(xs[:, 0]),
          equality_fn=error_term_equality)
      # Now, shuffle both dims.
      xs = xs[np.argsort(np.random.randn(xs.shape[0]))]
      xs = xs[:, np.argsort(np.random.randn(xs.shape[1]))]
      self._assertOpOutputMatchesExpected(
          kahan_sum_reduction(dims=[0, 1], output_idx=0),
          args=(xs,), expected=dtype((np.pi + np.exp(1.)) * 31 / 21))
      self._assertOpOutputMatchesExpected(
          kahan_sum_reduction(dims=[0, 1], output_idx=1),
          args=(xs,), expected=dtype(0),
          equality_fn=error_term_equality)

  @test_util.disable_mlir_bridge('Not supported yet')
  def testSelectAndScatter(self):
    for dtype in set(self.numeric_types).intersection(
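The kahan_sum_reducer in testVariadicReduce above implements Kahan (compensated) summation, merging (sum, error) operand pairs. For reference, a plain-NumPy sketch of the sequential form of the same algorithm (ours, for illustration only):

import numpy as np

def kahan_sum(xs):
  s = np.zeros((), xs.dtype)  # running sum
  c = np.zeros((), xs.dtype)  # compensation: error rounded away so far
  for x in xs:
    y = x - c        # correct the next addend by the carried error
    t = s + y        # low-order bits of y may be rounded away here
    c = (t - s) - y  # algebraically zero; captures what was just lost
    s = t
  return s

xs = np.float32([1e5, np.pi, -1e5, np.e])
print(np.add.reduce(xs), kahan_sum(xs))  # kahan_sum lands closer to pi + e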

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace {
@@ -38,16 +39,28 @@ class XlaReduceOp : public XlaOpKernel {
        context, dims_set.size() == dimensions_to_reduce_.size(),
        errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
                                "argument to XlaReduce"));
    if (context->HasAttr("N")) {  // variadic reduce
      use_tuples_ = true;
      OP_REQUIRES_OK(context, context->GetAttr("N", &n_));
    } else {
      use_tuples_ = false;
      n_ = 1;
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape("input");
    const TensorShape init_value_shape = context->InputShape("init_value");
    OP_REQUIRES(context, n_ * 2 == context->num_inputs(),
                errors::InvalidArgument("Expected ", n_ * 2, " inputs but got ",
                                        context->num_inputs()));
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape init_value_shape = context->InputShape(n_);
    const DataType dtype = context->input_type(0);
    const int rank = input_shape.dims();

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
                errors::InvalidArgument("init_value must be a scalar"));
                errors::InvalidArgument("init_value must be a scalar but got ",
                                        init_value_shape.DebugString()));

    auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
    OP_REQUIRES(context,
@@ -67,35 +80,58 @@ class XlaReduceOp : public XlaOpKernel {
    compile_options.always_return_tuple = false;
    compile_options.is_entry_computation = false;
    XlaCompiler::CompilationResult reducer;
    OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
                                compile_options, *reducer_,
                                {reducer_arg, reducer_arg}, &reducer));
    OP_REQUIRES_OK(
        context,
        context->compiler()->CompileFunction(
            compile_options, *reducer_,
            std::vector<XlaCompiler::Argument>(n_ * 2, reducer_arg), &reducer));

    xla::Shape scalar_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
    xla::Shape expected_shape;
    OP_REQUIRES_OK(
        context, TensorShapeToXLAShape(dtype, TensorShape(), &expected_shape));
    if (use_tuples_) {
      expected_shape = xla::ShapeUtil::MakeTupleShape(
          std::vector<xla::Shape>(n_, expected_shape));
    }
    OP_REQUIRES(
        context,
        xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
        xla::ShapeUtil::Compatible(reducer.xla_output_shape, expected_shape),
        errors::InvalidArgument(
            "Invalid output shape of XlaReduce reducer. Expected ",
            xla::ShapeUtil::HumanString(scalar_shape), " got ",
            xla::ShapeUtil::HumanString(expected_shape), " got ",
            xla::ShapeUtil::HumanString(reducer.xla_output_shape)));

    std::vector<xla::XlaOp> inputs;
    std::vector<xla::XlaOp> inits;
    inputs.reserve(n_);
    inits.reserve(n_);
    for (int i = 0; i < n_; i++) {
      inputs.emplace_back(context->Input(i));
      inits.emplace_back(context->Input(n_ + i));
    }
    xla::XlaOp output =
        xla::Reduce(context->Input("input"), context->Input("init_value"),
                    *reducer.computation, dimensions_to_reduce_);
    context->SetOutput(0, output);
        xla::Reduce(context->builder(), inputs, inits, *reducer.computation,
                    dimensions_to_reduce_);
    if (use_tuples_) {
      for (int i = 0; i < n_; i++) {
        context->SetOutput(i, xla::GetTupleElement(output, i));
      }
    } else {
      context->SetOutput(0, output);
    }
  }

 private:
  const NameAttrList* reducer_;
  std::vector<int64> dimensions_to_reduce_;
  bool use_tuples_;
  int n_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
};

REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
REGISTER_XLA_OP(Name("XlaVariadicReduce"), XlaReduceOp);

}  // namespace
}  // namespace tensorflow

View File

@@ -475,6 +475,60 @@ reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");
REGISTER_OP("XlaVariadicReduce")
.Input("input: N * T")
.Input("init_value: N * T")
.Attr("N: int >= 1")
.Attr("T: numbertype")
.Attr("dimensions_to_reduce: list(int)")
.Attr("reducer: func")
.Output("output: N * T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
int n;
TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
c->MergeInput(i, c->input(j));
}
}
if (c->RankKnown(c->input(0))) {
int rank = c->Rank(c->input(0));
std::vector<int64> dimensions_to_reduce;
TF_RETURN_IF_ERROR(
c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
std::set<int64> dims_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
auto dim_in_range = [rank](int64 dim) {
return dim >= 0 && dim < rank;
};
const int dimensions_to_reduce_size = dimensions_to_reduce.size();
if (rank < dimensions_to_reduce_size ||
dims_set.size() != dimensions_to_reduce.size() ||
!absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
return errors::InvalidArgument(
"Invalid dimensions_to_reduce argument to XlaVariadicReduce");
}
for (int i = 0; i < n; i++) {
c->set_output(
i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
}
} else {
for (int i = 0; i < n; i++) {
c->set_output(i, c->input(i));
}
}
return Status::OK();
})
.Doc(R"doc(
Wraps the variadic XLA Reduce operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.
input: the input tensor(s)
init_value: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");
REGISTER_OP("XlaReduceWindow")
.Input("input: T")
.Input("init_value: T")

View File

@@ -340,6 +340,7 @@ def random_uniform(minval, maxval, dims, name=None):
recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce
def reduce_window(operand,