[XLA] Create helper for using reductions as slice-like operations.
PiperOrigin-RevId: 251630341
This commit is contained in:
		
							parent
							
								
									58c796df1d
								
							
						
					
					
						commit
						8b7052c4d8
					
				| @ -110,6 +110,15 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type, | ||||
|                                     const XlaOp& rhs) { return Or(lhs, rhs); }); | ||||
| } | ||||
| 
 | ||||
| XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, | ||||
|                                                        XlaBuilder* builder) { | ||||
|   XlaComputation reducer = | ||||
|       (primitive_util::IsIntegralType(type) || type == PRED) | ||||
|           ? CreateScalarOrComputation(type, builder) | ||||
|           : CreateScalarAddComputation(type, builder); | ||||
|   return reducer; | ||||
| } | ||||
| 
 | ||||
| XlaOp Any(XlaOp predicates) { | ||||
|   XlaBuilder* builder = predicates.builder(); | ||||
|   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { | ||||
|  | ||||
| @ -52,6 +52,18 @@ XlaComputation CreateScalarAndComputation(PrimitiveType type, | ||||
| XlaComputation CreateScalarOrComputation(PrimitiveType type, | ||||
|                                          XlaBuilder* builder); | ||||
| 
 | ||||
| // This is to be used for general purpose "identity" like reductions with zero
 | ||||
| // for any type (ie. boolean operations for PRED and Add for real numbers).
 | ||||
| // As an example, this operation can be used for a situation of:
 | ||||
| // x_type = type(x)
 | ||||
| // op = CreateScalarIdentityWithZeroComputation(x_type)
 | ||||
| // ASSERT_TRUE(op(x, 0) == x)
 | ||||
| //
 | ||||
| // This functionality is used for operations that are similar to a slice,
 | ||||
| // gather, or broadcast, but are created through a reduction.
 | ||||
| XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, | ||||
|                                                        XlaBuilder* builder); | ||||
| 
 | ||||
| // Returns whether any predicate in "predicates" is set.
 | ||||
| //
 | ||||
| // Note: if predicates is zero-sized, Any() vacuously returns false.
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user