[XLA] Create helper for using reductions as slice-like operations.

PiperOrigin-RevId: 251630341
This commit is contained in:
A. Unique TensorFlower 2019-06-05 06:26:45 -07:00 committed by TensorFlower Gardener
parent 58c796df1d
commit 8b7052c4d8
2 changed files with 21 additions and 0 deletions

View File

@ -110,6 +110,15 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type,
const XlaOp& rhs) { return Or(lhs, rhs); });
}
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
XlaBuilder* builder) {
XlaComputation reducer =
(primitive_util::IsIntegralType(type) || type == PRED)
? CreateScalarOrComputation(type, builder)
: CreateScalarAddComputation(type, builder);
return reducer;
}
XlaOp Any(XlaOp predicates) {
XlaBuilder* builder = predicates.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {

View File

@ -52,6 +52,18 @@ XlaComputation CreateScalarAndComputation(PrimitiveType type,
XlaComputation CreateScalarOrComputation(PrimitiveType type,
XlaBuilder* builder);
// This is to be used for general purpose "identity" like reductions with zero
// for any type (ie. boolean operations for PRED and Add for real numbers).
// As an example, this operation can be used for a situation of:
// x_type = type(x)
// op = CreateScalarIdentityWithZeroComputation(x_type)
// ASSERT_TRUE(op(x, 0) == x)
//
// This functionality is used for operations that are similar to a slice,
// gather, or broadcast, but are created through a reduction.
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
XlaBuilder* builder);
// Returns whether any predicate in "predicates" is set.
//
// Note: if predicates is zero-sized, Any() vacuously returns false.