[TF:XLA] Handle predicates for matrix diagonal operations.
PiperOrigin-RevId: 251821349
commit 28b65d088c
parent 326111b961
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 
 namespace tensorflow {
 
@@ -68,15 +69,21 @@ class MatrixSetDiagOp : public XlaOpKernel {
                            /*broadcast_dimensions=*/{0});
     indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());
 
-    // Broadcast diag up to the input shape. Use an implicit broadcast (Add)
+    // Broadcast diag up to the input shape. Use an implicit broadcast (Add/Or)
     // because we need to broadcast on the right.
     std::vector<int64> diag_broadcast_dims(rank - 1);
     std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
     if (min_dim != m) {
      diag_broadcast_dims.back() = rank - 1;
    }
-    diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
-                    /*broadcast_dimensions=*/diag_broadcast_dims);
+    if (context->input_xla_type(0) == xla::PRED) {
+      diag = xla::Or(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
+                     /*broadcast_dimensions=*/diag_broadcast_dims);
+
+    } else {
+      diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
+                      /*broadcast_dimensions=*/diag_broadcast_dims);
+    }
 
     auto output = xla::Select(indicator, diag, input);
     context->SetOutput(0, output);
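Note on the hunk above: xla::Broadcast can only add new major (leftmost) dimensions, so the kernel broadcasts diag into the trailing dimension implicitly, by combining it with a zero-filled tensor through a binary op that accepts broadcast_dimensions. Add(x, 0) is an identity for numeric types, but XLA does not define Add for PRED; Or(x, false) is the boolean analogue, hence the new branch. A minimal sketch of the idea, not taken from the commit (the helper name BroadcastOnRight is hypothetical):

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Broadcast `x` up to `target_dims` by combining it with a zero/false
// filled operand; `broadcast_dims` maps each dimension of `x` onto a
// dimension of the target, which is what allows broadcasting "on the
// right" as well as on the left.
xla::XlaOp BroadcastOnRight(xla::XlaOp x, xla::XlaOp zero, bool is_pred,
                            absl::Span<const xla::int64> target_dims,
                            absl::Span<const xla::int64> broadcast_dims) {
  xla::XlaOp zeros = xla::Broadcast(zero, target_dims);
  // Or(x, false) == x and Add(x, 0) == x: the binary op is used purely
  // for its implicit broadcast, not for its arithmetic.
  return is_pred ? xla::Or(x, zeros, broadcast_dims)
                 : xla::Add(x, zeros, broadcast_dims);
}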
@@ -46,6 +46,8 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
                         PrimitiveType_Name(type)));
   }
   switch (type) {
+    case PRED:
+      return ConstantR0<bool>(builder, static_cast<bool>(value));
     case F16:
       return ConstantR0<half>(builder, static_cast<half>(value));
     case BF16:
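With the new PRED case, ConstantR0WithType covers boolean scalars the same way it covers the numeric types. A small usage sketch under that assumption (function, builder, and variable names are illustrative):

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

void ConstantsExample() {
  xla::XlaBuilder builder("constants_example");
  // A zero value converts to false, any nonzero value to true.
  xla::XlaOp pred_false = xla::ConstantR0WithType(&builder, xla::PRED, 0);
  xla::XlaOp pred_true = xla::ConstantR0WithType(&builder, xla::PRED, 1);
  xla::XlaOp f32_zero = xla::ConstantR0WithType(&builder, xla::F32, 0.0f);
}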
@@ -78,9 +78,7 @@ XlaOp GetMatrixDiagonal(XlaOp x, int k) {
-  // TPUs don't support S64 add reduction at the moment. But fortunately
-  // OR-reductions work just as well for integers.
   XlaComputation reducer =
-      primitive_util::IsIntegralType(shape.element_type())
-          ? CreateScalarOrComputation(shape.element_type(), builder)
-          : CreateScalarAddComputation(shape.element_type(), builder);
+      CreateScalarIdentityWithZeroComputation(shape.element_type(), builder);
   // k == 0, we can save one slice op.
   if (k == 0) {
     return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
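The reducer swap works because the Select leaves at most one non-fill element per reduced row, so any reduction whose identity is zero recovers the diagonal; the identity-with-zero computation picks that reduction per element type (Add for floats, Or for integers and predicates), which also sidesteps the missing S64 add reduction on TPUs and the missing Add on PRED. A simplified sketch of the surrounding pattern, assuming a 2D input (the function name is hypothetical):

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Keep only the masked diagonal entries, fill the rest with zero/false,
// then reduce away the last dimension to obtain the diagonal vector.
xla::XlaOp GetDiagonalSketch(xla::XlaBuilder* builder, xla::XlaOp x,
                             xla::XlaOp mask, const xla::Shape& shape,
                             const xla::XlaComputation& reducer) {
  return xla::Reduce(xla::Select(mask, x, xla::Zeros(builder, shape)),
                     xla::ScalarLike(x, 0), reducer,
                     /*dimensions_to_reduce=*/{1});
}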
@@ -133,7 +133,6 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
 
   @test_util.run_deprecated_v1
-  @test_util.disable_xla("Diagonal operations do not support bool in XLA")
   def testSquareBatch(self):
     self._testSquareBatch(np.float32)
     self._testSquareBatch(np.float64)
@@ -247,7 +246,6 @@ class MatrixDiagPartTest(test.TestCase):
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
 
   @test_util.run_deprecated_v1
-  @test_util.disable_xla("Diagonal operations do not support bool in XLA")
   def testSquareBatch(self):
     self._testSquareBatch(np.float32)
     self._testSquareBatch(np.float64)