diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 3fa3746598e..ac5e01a0abf 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -312,6 +312,16 @@ StatusOr<XlaOp> MlirHloBuilder::RngOpInternal( return CreateOp(op_name, shape, operands); } +StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>( + full_result_shape, builder_)); + auto op = builder_.create<mlir::xla_hlo::RngBitGeneratorOp>( + loc_, ty, builder_.getI32IntegerAttr(algorithm), GetValue(initial_state)); + return MakeXlaOp(op); +} + StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) { diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 3884689e48d..00b7aa4d0b0 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -183,6 +183,9 @@ class MlirHloBuilder : public XlaBuilder { StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution, absl::Span<const XlaOp> parameters, const Shape& shape) override; + StatusOr<XlaOp> RngBitGeneratorInternal(const Shape& full_result_shape, + RandomAlgorithm algorithm, + XlaOp initial_state) override; StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) override; diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 221f01ece8c..de1e592157e 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -290,6 +290,14 @@ func @diag(%arg0: tensor<2xf32>) -> tensor<2x2xf32> { return %0 : tensor<2x2xf32> } +// CHECK-LABEL: random_uniform_int +func 
@random_uniform_int(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<1000xi32> { + %0 = "tf.Const"() {value = dense<1000> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NOT: tf.RandomUniformInt + %1 = "tf.RandomUniformInt"(%0, %arg0, %arg1) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<1xi32>, tensor<i32>, tensor<i32>) -> tensor<1000xi32> + return %1 : tensor<1000xi32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index af4a5cb45bf..93b1f5c3397 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -167,12 +167,14 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -199,6 +201,11 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index cc7fb3e1ab4..805f2d2da82 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -265,6 +265,7 @@ tf_xla_py_test( name = "categorical_op_test", size = "small", srcs = ["categorical_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1285,6 +1286,7 @@ tf_xla_py_test( name = "stateless_random_ops_test", size = "medium", srcs = ["stateless_random_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", 
# TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index f3f12d32e40..f0ac86d5444 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -96,7 +96,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllEqual(result, expected) @test_util.disable_mlir_bridge( - "Handle complex element types in DiagPart op lowering") + "Handle complex element type in DiagPart lowering") def testAllTypeOps(self): for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 8ab851fe0eb..33038ddfd04 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1984,7 +1984,6 @@ XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) { XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - HloInstructionProto instr; TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state)); Shape output_shape = shape; @@ -2003,14 +2002,22 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, return InvalidArgument("Unsupported shape for RngBitGenerator: %s", PrimitiveType_Name(output_shape.element_type())); } - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({state_shape, output_shape}).ToProto(); - instr.set_rng_algorithm(algorithm); - return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator, - {initial_state}); + return RngBitGeneratorInternal( + ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm, + initial_state); }); } +StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal( + const Shape& full_result_shape, 
RandomAlgorithm algorithm, + XlaOp initial_state) { + HloInstructionProto instr; + *instr.mutable_shape() = full_result_shape.ToProto(); + instr.set_rng_algorithm(algorithm); + return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator, + {initial_state}); +} + XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, XlaOp init) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index d812b35f7a0..f841a1a75a0 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -712,6 +712,11 @@ class XlaBuilder { XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, const Shape& shape); + // Internal variant for the op with the full result shape containing both data + // and state shape as a tuple. + virtual StatusOr<XlaOp> RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state); XlaOp While(const XlaComputation& condition, const XlaComputation& body, XlaOp init);