Exposes variadic reduce to TF python via xla.py
PiperOrigin-RevId: 336936774 Change-Id: Iee8bf17a18594abe2a301802eade9e7c4f4c5e34
parent 9a04369ca8
commit 4e330bfbd1
Changed paths: tensorflow/compiler/jit, tensorflow/compiler/tests, tensorflow/compiler/tf2xla
@@ -2066,6 +2066,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
       "XlaSpmdFullToShardShape",
       "XlaSpmdShardToFullShape",
       "XlaSvd",
+      "XlaVariadicReduce",
       "XlaWhile",
       "Zeta",
       "_Arg",
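The allowlist addition above lets the JIT's auto-clustering pass treat XlaVariadicReduce as a known XLA-compilable op alongside the other Xla* ops.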
@@ -18,12 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.compiler.tf2xla.python import xla
 from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
@@ -299,6 +302,78 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
         expected=np.array([0, 45, 120, 231], dtype=dtype))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
+  def testVariadicReduce(self):
+    for dtype in set(self.numeric_types).intersection(
+        set([np.float32, np.complex64])):
+
+      @def_function.function
+      def kahan_sum_reducer(t0, t1):
+        (s0, c0), (s1, c1) = t0, t1
+        s0minusc = s0 - (c0 + c1)
+        t = s1 + s0minusc
+        c = (t - s1) - s0minusc
+        s = t
+        return s, c
+
+      def kahan_sum_reduction(dims, output_idx):
+
+        def fn(x):
+          arg = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
+          reducer = kahan_sum_reducer.get_concrete_function(
+              (arg, arg), (arg, arg))
+
+          return xla.variadic_reduce(
+              (x, array_ops.zeros_like(x)),
+              init_value=(arg, arg),
+              dimensions_to_reduce=dims,
+              reducer=reducer)[output_idx]
+
+        return fn
+
+      xs = np.array([1e5, np.pi, -1e5, np.exp(1.)])
+      xs = np.array([xs, xs[::-1] / 3, xs / 7], dtype)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[], output_idx=0),
+          args=(xs,), expected=xs)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[], output_idx=1),
+          args=(xs,), expected=np.zeros_like(xs))
+      shuffle_indices = np.argsort(np.random.randn(xs.shape[0]))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0], output_idx=0),
+          args=(xs[shuffle_indices],),
+          expected=np.array([np.exp(1) / 3 + 1e5 * 8 / 7,
+                             np.pi * 8 / 7 - 1e5 / 3,
+                             -1e5 * 8 / 7 + np.pi / 3,
+                             np.exp(1) * 8 / 7 + 1e5 / 3], dtype=dtype))
+      error_term_equality = functools.partial(self.assertAllClose, atol=.005)
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0], output_idx=1),
+          args=(xs[shuffle_indices],), expected=np.zeros_like(xs[0]),
+          equality_fn=error_term_equality)
+      shuffle_indices = np.argsort(np.random.randn(xs.shape[1]))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[1], output_idx=0),
+          args=(xs[:, shuffle_indices],),
+          expected=np.array([np.pi + np.exp(1.),
+                             (np.pi + np.exp(1.)) / 3,
+                             (np.pi + np.exp(1.)) / 7], dtype=dtype))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[1], output_idx=1),
+          args=(xs[:, shuffle_indices],), expected=np.zeros_like(xs[:, 0]),
+          equality_fn=error_term_equality)
+      # Now, shuffle both dims.
+      xs = xs[np.argsort(np.random.randn(xs.shape[0]))]
+      xs = xs[:, np.argsort(np.random.randn(xs.shape[1]))]
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0, 1], output_idx=0),
+          args=(xs,), expected=dtype((np.pi + np.exp(1.)) * 31 / 21))
+      self._assertOpOutputMatchesExpected(
+          kahan_sum_reduction(dims=[0, 1], output_idx=1),
+          args=(xs,), expected=dtype(0),
+          equality_fn=error_term_equality)
+
   @test_util.disable_mlir_bridge('Not supported yet')
   def testSelectAndScatter(self):
     for dtype in set(self.numeric_types).intersection(
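The reducer in the new test is a merge-style variant of Kahan (compensated) summation: besides the running sum s it carries a compensation term c holding the low-order bits lost to rounding, which is exactly why the reduction needs two tensors flowing through it. A minimal sequential NumPy sketch of the same idea (illustrative only, not part of the change):

    import numpy as np

    def kahan_sum(values):
      s = np.float32(0)  # running sum
      c = np.float32(0)  # compensation: low-order bits lost so far
      for x in np.asarray(values, np.float32):
        y = x - c        # fold the lost bits back into the next addend
        t = s + y        # this addition may round y's low-order bits away
        c = (t - s) - y  # recover exactly what was rounded away
        s = t
      return s, c        # the test's variadic reduce carries this same pair

    # Summing 0.01 ten thousand times in float32: the compensated sum stays
    # much closer to the exact sum than a naive running float32 sum does.
    xs = np.full(10000, 0.01, dtype=np.float32)
    naive = np.float32(0)
    for x in xs:
      naive = naive + x
    print(naive, kahan_sum(xs)[0])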
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace {
@@ -38,16 +39,28 @@ class XlaReduceOp : public XlaOpKernel {
         context, dims_set.size() == dimensions_to_reduce_.size(),
         errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
                                 "argument to XlaReduce"));
+    if (context->HasAttr("N")) {  // variadic reduce
+      use_tuples_ = true;
+      OP_REQUIRES_OK(context, context->GetAttr("N", &n_));
+    } else {
+      use_tuples_ = false;
+      n_ = 1;
+    }
   }
 
   void Compile(XlaOpKernelContext* context) override {
-    const TensorShape input_shape = context->InputShape("input");
-    const TensorShape init_value_shape = context->InputShape("init_value");
+    OP_REQUIRES(context, n_ * 2 == context->num_inputs(),
+                errors::InvalidArgument("Expected ", n_ * 2, " inputs but got ",
+                                        context->num_inputs()));
+
+    const TensorShape input_shape = context->InputShape(0);
+    const TensorShape init_value_shape = context->InputShape(n_);
     const DataType dtype = context->input_type(0);
 
     const int rank = input_shape.dims();
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
-                errors::InvalidArgument("init_value must be a scalar"));
+                errors::InvalidArgument("init_value must be a scalar but got ",
+                                        init_value_shape.DebugString()));
 
     auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
     OP_REQUIRES(context,
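Note on this hunk: "N" is declared only on XlaVariadicReduce, so HasAttr("N") is what switches the shared kernel into tuple mode, while XlaReduce keeps behaving as the n_ == 1 case. In the Compile body, inputs 0..N-1 are the operands and inputs N..2N-1 are their scalar init values, which is why exactly n_ * 2 inputs are required.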
@@ -67,35 +80,58 @@ class XlaReduceOp : public XlaOpKernel {
     compile_options.always_return_tuple = false;
     compile_options.is_entry_computation = false;
     XlaCompiler::CompilationResult reducer;
-    OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
-                                compile_options, *reducer_,
-                                {reducer_arg, reducer_arg}, &reducer));
+    OP_REQUIRES_OK(
+        context,
+        context->compiler()->CompileFunction(
+            compile_options, *reducer_,
+            std::vector<XlaCompiler::Argument>(n_ * 2, reducer_arg), &reducer));
 
-    xla::Shape scalar_shape;
-    OP_REQUIRES_OK(context,
-                   TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+    xla::Shape expected_shape;
+    OP_REQUIRES_OK(
+        context, TensorShapeToXLAShape(dtype, TensorShape(), &expected_shape));
+    if (use_tuples_) {
+      expected_shape = xla::ShapeUtil::MakeTupleShape(
+          std::vector<xla::Shape>(n_, expected_shape));
+    }
     OP_REQUIRES(
         context,
-        xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+        xla::ShapeUtil::Compatible(reducer.xla_output_shape, expected_shape),
         errors::InvalidArgument(
             "Invalid output shape of XlaReduce reducer. Expected ",
-            xla::ShapeUtil::HumanString(scalar_shape), " got ",
+            xla::ShapeUtil::HumanString(expected_shape), " got ",
            xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
 
+    std::vector<xla::XlaOp> inputs;
+    std::vector<xla::XlaOp> inits;
+    inputs.reserve(n_);
+    inits.reserve(n_);
+    for (int i = 0; i < n_; i++) {
+      inputs.emplace_back(context->Input(i));
+      inits.emplace_back(context->Input(n_ + i));
+    }
     xla::XlaOp output =
-        xla::Reduce(context->Input("input"), context->Input("init_value"),
-                    *reducer.computation, dimensions_to_reduce_);
-    context->SetOutput(0, output);
+        xla::Reduce(context->builder(), inputs, inits, *reducer.computation,
+                    dimensions_to_reduce_);
+    if (use_tuples_) {
+      for (int i = 0; i < n_; i++) {
+        context->SetOutput(i, xla::GetTupleElement(output, i));
+      }
+    } else {
+      context->SetOutput(0, output);
+    }
   }
 
  private:
   const NameAttrList* reducer_;
   std::vector<int64> dimensions_to_reduce_;
+  bool use_tuples_;
+  int n_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
 };
 
 REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
+REGISTER_XLA_OP(Name("XlaVariadicReduce"), XlaReduceOp);
 
 }  // namespace
 }  // namespace tensorflow
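For intuition about what the kernel now emits: xla::Reduce over N operands folds the arrays in lockstep, with a reducer that combines an N-tuple accumulator and an N-tuple of elements into an N-tuple; in tuple mode each element of the resulting tuple becomes one op output via GetTupleElement. A rough NumPy model of a full (all-dims) variadic reduction, assuming a commutative/associative reducer since XLA does not fix the reduction order (names here are illustrative, not part of the change):

    import functools
    import numpy as np

    def variadic_reduce_model(operands, init_values, reducer):
      """operands: N same-shape arrays; init_values: N scalars."""
      elements = zip(*(np.ravel(x) for x in operands))  # N-tuples, in lockstep
      return functools.reduce(reducer, elements, tuple(init_values))

    # Reduce two arrays at once: running (sum, max) over paired elements.
    sum_and_max = lambda acc, el: (acc[0] + el[0], max(acc[1], el[1]))
    x = np.arange(6.0).reshape(2, 3)
    print(variadic_reduce_model((x, x), (0.0, -np.inf), sum_and_max))
    # -> (15.0, 5.0): one result per operand, just like the N op outputs.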
@@ -475,6 +475,60 @@ reducer: a reducer function to apply
 dimensions_to_reduce: dimension numbers over which to reduce
 )doc");
 
+REGISTER_OP("XlaVariadicReduce")
+    .Input("input: N * T")
+    .Input("init_value: N * T")
+    .Attr("N: int >= 1")
+    .Attr("T: numbertype")
+    .Attr("dimensions_to_reduce: list(int)")
+    .Attr("reducer: func")
+    .Output("output: N * T")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      int n;
+      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+      for (int i = 0; i < n; i++) {
+        for (int j = 0; j < n; j++) {
+          c->MergeInput(i, c->input(j));
+        }
+      }
+      if (c->RankKnown(c->input(0))) {
+        int rank = c->Rank(c->input(0));
+        std::vector<int64> dimensions_to_reduce;
+        TF_RETURN_IF_ERROR(
+            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
+        std::set<int64> dims_set(dimensions_to_reduce.begin(),
+                                 dimensions_to_reduce.end());
+        auto dim_in_range = [rank](int64 dim) {
+          return dim >= 0 && dim < rank;
+        };
+        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
+        if (rank < dimensions_to_reduce_size ||
+            dims_set.size() != dimensions_to_reduce.size() ||
+            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
+          return errors::InvalidArgument(
+              "Invalid dimensions_to_reduce argument to XlaVariadicReduce");
+        }
+        for (int i = 0; i < n; i++) {
+          c->set_output(
+              i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
+        }
+      } else {
+        for (int i = 0; i < n; i++) {
+          c->set_output(i, c->input(i));
+        }
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Wraps the variadic XLA Reduce operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.
+
+input: the input tensor(s)
+init_value: scalar initial value(s) for the reduction
+reducer: a reducer function to apply
+dimensions_to_reduce: dimension numbers over which to reduce
+)doc");
+
 REGISTER_OP("XlaReduceWindow")
     .Input("input: T")
     .Input("init_value: T")
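A note on the shape function above: it merges all N inputs (so they must agree in shape), and when the rank is known it pins down only the output rank, rank(input) - len(dimensions_to_reduce), leaving every dimension size unknown. For example, inputs of shape [3, 4] with dimensions_to_reduce = [0] give N outputs of unknown shape with rank 1 (concretely [4] at runtime), and reducing over [0, 1] gives N scalars.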
@@ -340,6 +340,7 @@ def random_uniform(minval, maxval, dims, name=None):
 
 recv = gen_xla_ops.xla_recv
 reduce = gen_xla_ops.xla_reduce
+variadic_reduce = gen_xla_ops.xla_variadic_reduce
 
 
 def reduce_window(operand,
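With the alias in place, the op is reachable as xla.variadic_reduce. A minimal usage sketch distilled from the new test (a sketch under these assumptions: a TF build exposing these internal modules, and execution inside an XLA-compiled computation such as the xla_test harness, since the op has no plain-TF kernel; pair_sum is an illustrative name):

    import numpy as np
    from tensorflow.compiler.tf2xla.python import xla
    from tensorflow.python.eager import def_function
    from tensorflow.python.ops import array_ops

    @def_function.function
    def pair_sum(t0, t1):
      # Reducer over N = 2 streams: component-wise sums.
      (a0, a1), (b0, b1) = t0, t1
      return a0 + b0, a1 + b1

    x = np.arange(12, dtype=np.float32).reshape(3, 4)
    zero = array_ops.zeros([], np.float32)
    reducer = pair_sum.get_concrete_function((zero, zero), (zero, zero))
    # Two operands reduced in lockstep over dim 0 -> two outputs of shape [4].
    col_sums, col_zeros = xla.variadic_reduce(
        (x, np.zeros_like(x)),
        init_value=(zero, zero),
        dimensions_to_reduce=[0],
        reducer=reducer)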