[XLA] Create helper for using reductions as slice-like operations.
PiperOrigin-RevId: 251630341
This commit is contained in:
parent
58c796df1d
commit
8b7052c4d8
@ -110,6 +110,15 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
|||||||
const XlaOp& rhs) { return Or(lhs, rhs); });
|
const XlaOp& rhs) { return Or(lhs, rhs); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
|
||||||
|
XlaBuilder* builder) {
|
||||||
|
XlaComputation reducer =
|
||||||
|
(primitive_util::IsIntegralType(type) || type == PRED)
|
||||||
|
? CreateScalarOrComputation(type, builder)
|
||||||
|
: CreateScalarAddComputation(type, builder);
|
||||||
|
return reducer;
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Any(XlaOp predicates) {
|
XlaOp Any(XlaOp predicates) {
|
||||||
XlaBuilder* builder = predicates.builder();
|
XlaBuilder* builder = predicates.builder();
|
||||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
@ -52,6 +52,18 @@ XlaComputation CreateScalarAndComputation(PrimitiveType type,
|
|||||||
XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
||||||
XlaBuilder* builder);
|
XlaBuilder* builder);
|
||||||
|
|
||||||
|
// This is to be used for general purpose "identity" like reductions with zero
|
||||||
|
// for any type (ie. boolean operations for PRED and Add for real numbers).
|
||||||
|
// As an example, this operation can be used for a situation of:
|
||||||
|
// x_type = type(x)
|
||||||
|
// op = CreateScalarIdentityWithZeroComputation(x_type)
|
||||||
|
// ASSERT_TRUE(op(x, 0) == x)
|
||||||
|
//
|
||||||
|
// This functionality is used for operations that are similar to a slice,
|
||||||
|
// gather, or broadcast, but are created through a reduction.
|
||||||
|
XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
|
||||||
|
XlaBuilder* builder);
|
||||||
|
|
||||||
// Returns whether any predicate in "predicates" is set.
|
// Returns whether any predicate in "predicates" is set.
|
||||||
//
|
//
|
||||||
// Note: if predicates is zero-sized, Any() vacuously returns false.
|
// Note: if predicates is zero-sized, Any() vacuously returns false.
|
||||||
|
Loading…
Reference in New Issue
Block a user