diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index d34ecaf99c8..0940a873fa4 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -110,6 +110,15 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type, const XlaOp& rhs) { return Or(lhs, rhs); }); } +XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, + XlaBuilder* builder) { + XlaComputation reducer = + (primitive_util::IsIntegralType(type) || type == PRED) + ? CreateScalarOrComputation(type, builder) + : CreateScalarAddComputation(type, builder); + return reducer; +} + XlaOp Any(XlaOp predicates) { XlaBuilder* builder = predicates.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 6f64d587fa8..270076a1586 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -52,6 +52,18 @@ XlaComputation CreateScalarAndComputation(PrimitiveType type, XlaComputation CreateScalarOrComputation(PrimitiveType type, XlaBuilder* builder); +// This is to be used for general purpose "identity" like reductions with zero +// for any type (ie. boolean operations for PRED and Add for real numbers). +// As an example, this operation can be used for a situation of: +// x_type = type(x) +// op = CreateScalarIdentityWithZeroComputation(x_type) +// ASSERT_TRUE(op(x, 0) == x) +// +// This functionality is used for operations that are similar to a slice, +// gather, or broadcast, but are created through a reduction. +XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, + XlaBuilder* builder); + // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false.