[XLA] Create helper for using reductions as slice-like operations.
PiperOrigin-RevId: 251630341
This commit is contained in:
parent
58c796df1d
commit
8b7052c4d8
@ -110,6 +110,15 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
||||
const XlaOp& rhs) { return Or(lhs, rhs); });
|
||||
}
|
||||
|
||||
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
|
||||
XlaBuilder* builder) {
|
||||
XlaComputation reducer =
|
||||
(primitive_util::IsIntegralType(type) || type == PRED)
|
||||
? CreateScalarOrComputation(type, builder)
|
||||
: CreateScalarAddComputation(type, builder);
|
||||
return reducer;
|
||||
}
|
||||
|
||||
XlaOp Any(XlaOp predicates) {
|
||||
XlaBuilder* builder = predicates.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
|
@ -52,6 +52,18 @@ XlaComputation CreateScalarAndComputation(PrimitiveType type,
|
||||
XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
||||
XlaBuilder* builder);
|
||||
|
||||
// This is to be used for general purpose "identity" like reductions with zero
|
||||
// for any type (ie. boolean operations for PRED and Add for real numbers).
|
||||
// As an example, this operation can be used for a situation of:
|
||||
// x_type = type(x)
|
||||
// op = CreateScalarIdentityWithZeroComputation(x_type)
|
||||
// ASSERT_TRUE(op(x, 0) == x)
|
||||
//
|
||||
// This functionality is used for operations that are similar to a slice,
|
||||
// gather, or broadcast, but are created through a reduction.
|
||||
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
|
||||
XlaBuilder* builder);
|
||||
|
||||
// Returns whether any predicate in "predicates" is set.
|
||||
//
|
||||
// Note: if predicates is zero-sized, Any() vacuously returns false.
|
||||
|
Loading…
Reference in New Issue
Block a user