diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index 4f980b6d14e..ee9764c0c35 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 
 namespace tensorflow {
 
@@ -68,15 +69,20 @@ class MatrixSetDiagOp : public XlaOpKernel {
                          /*broadcast_dimensions=*/{0});
     indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());
 
-    // Broadcast diag up to the input shape. Use an implicit broadcast (Add)
+    // Broadcast diag up to the input shape. Use an implicit broadcast (Add/Or)
     // because we need to broadcast on the right.
     std::vector<int64> diag_broadcast_dims(rank - 1);
     std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
     if (min_dim != m) {
       diag_broadcast_dims.back() = rank - 1;
     }
-    diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
-                    /*broadcast_dimensions=*/diag_broadcast_dims);
+    if (context->input_xla_type(0) == xla::PRED) {
+      diag = xla::Or(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
+                     /*broadcast_dimensions=*/diag_broadcast_dims);
+    } else {
+      diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
+                      /*broadcast_dimensions=*/diag_broadcast_dims);
+    }
 
     auto output = xla::Select(indicator, diag, input);
     context->SetOutput(0, output);
diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h
index 47b8f1b44ff..03ebe4e0098 100644
--- a/tensorflow/compiler/xla/client/lib/constants.h
+++ b/tensorflow/compiler/xla/client/lib/constants.h
@@ -46,6 +46,8 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
         PrimitiveType_Name(type)));
   }
   switch (type) {
+    case PRED:
+      return ConstantR0<bool>(builder, static_cast<bool>(value));
     case F16:
       return ConstantR0<half>(builder, static_cast<half>(value));
     case BF16:
diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc
index 902269d9412..5b9a94ba36c 100644
--- a/tensorflow/compiler/xla/client/lib/matrix.cc
+++ b/tensorflow/compiler/xla/client/lib/matrix.cc
@@ -78,9 +78,7 @@ XlaOp GetMatrixDiagonal(XlaOp x, int k) {
     // TPUs don't support S64 add reduction at the moment. But fortunately
     // OR-reductions work just as well for integers.
     XlaComputation reducer =
-        primitive_util::IsIntegralType(shape.element_type())
-            ? CreateScalarOrComputation(shape.element_type(), builder)
-            : CreateScalarAddComputation(shape.element_type(), builder);
+        CreateScalarIdentityWithZeroComputation(shape.element_type(), builder);
     // k == 0, we can save one slice op.
     if (k == 0) {
       return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index e84d60d2c24..0bf48fd228f 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -133,7 +133,6 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
 
   @test_util.run_deprecated_v1
-  @test_util.disable_xla("Diagonal operations do not support bool in XLA")
   def testSquareBatch(self):
     self._testSquareBatch(np.float32)
     self._testSquareBatch(np.float64)
@@ -247,7 +246,6 @@ class MatrixDiagPartTest(test.TestCase):
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
 
   @test_util.run_deprecated_v1
-  @test_util.disable_xla("Diagonal operations do not support bool in XLA")
   def testSquareBatch(self):
     self._testSquareBatch(np.float32)
     self._testSquareBatch(np.float64)
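Note: the removed disable_xla decorators above mean the existing tests now also run the bool case through XLA. As a rough standalone repro sketch (hypothetical, not part of the patch; assumes a TF build where tf.function(jit_compile=True) is available), the following used to fail under XLA compilation because the implicit broadcast used Add, which is undefined for PRED, and works once Or is used instead:

import numpy as np
import tensorflow as tf

@tf.function(jit_compile=True)  # force XLA compilation of the op
def set_diag_bool(mat, diag):
  # tf.linalg.set_diag lowers to MatrixSetDiag; a bool input takes the
  # PRED branch added in matrix_set_diag_op.cc above.
  return tf.linalg.set_diag(mat, diag)

mat = tf.constant(np.ones((2, 3, 3), dtype=np.bool_))
diag = tf.constant(np.zeros((2, 3), dtype=np.bool_))
print(set_diag_bool(mat, diag))  # diagonals replaced with False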