Lower TensorFlow random generator ops in the fallback path.

* Override the RngBitGenerator op in MlirHloBuilder.
* Enable the relevant compiler tests.

PiperOrigin-RevId: 327093293
Change-Id: Ib124c0b08c25255edb008cfdd350acaa0067e64c
This commit is contained in:
parent
eb377d252e
commit
589587081e
@ -312,6 +312,16 @@ StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
|
||||
return CreateOp(op_name, shape, operands);
|
||||
}
|
||||
|
||||
// Lowers the RngBitGenerator op into the MHLO dialect. `full_result_shape`
// contains both the data and the state shape as a tuple (see the declaration
// in the XlaBuilder header), and is converted to the corresponding MLIR type
// before the op is created.
StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  TF_ASSIGN_OR_RETURN(mlir::Type result_ty,
                      ConvertShapeToType<mlir::RankedTensorType>(
                          full_result_shape, builder_));
  // The random algorithm enum is carried on the op as an i32 attribute.
  auto algorithm_attr = builder_.getI32IntegerAttr(algorithm);
  auto rng_op = builder_.create<mlir::mhlo::RngBitGeneratorOp>(
      loc_, result_ty, algorithm_attr, GetValue(initial_state));
  return MakeXlaOp(rng_op);
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
|
||||
XlaOp operand,
|
||||
int64 inferred_dimension) {
|
||||
|
@ -183,6 +183,9 @@ class MlirHloBuilder : public XlaBuilder {
|
||||
StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
|
||||
absl::Span<const XlaOp> parameters,
|
||||
const Shape& shape) override;
|
||||
StatusOr<XlaOp> RngBitGeneratorInternal(const Shape& full_result_shape,
|
||||
RandomAlgorithm algorithm,
|
||||
XlaOp initial_state) override;
|
||||
|
||||
StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
|
||||
int64 inferred_dimension) override;
|
||||
|
@ -290,6 +290,14 @@ func @diag(%arg0: tensor<2xf32>) -> tensor<2x2xf32> {
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: random_uniform_int
|
||||
func @random_uniform_int(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<1000xi32> {
|
||||
%0 = "tf.Const"() {value = dense<1000> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NOT: tf.RandomUniformInt
|
||||
%1 = "tf.RandomUniformInt"(%0, %arg0, %arg1) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<1xi32>, tensor<i32>, tensor<i32>) -> tensor<1000xi32>
|
||||
return %1 : tensor<1000xi32>
|
||||
}
|
||||
|
||||
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
||||
// available but doesn't support this instance.
|
||||
}
|
||||
|
@ -167,12 +167,14 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::NegOp>(),
|
||||
TypeID::get<TF::NonMaxSuppressionV4Op>(),
|
||||
TypeID::get<TF::NotEqualOp>(),
|
||||
TypeID::get<TF::MultinomialOp>(),
|
||||
TypeID::get<TF::PadOp>(),
|
||||
TypeID::get<TF::PlaceholderWithDefaultOp>(),
|
||||
TypeID::get<TF::PowOp>(),
|
||||
TypeID::get<TF::RFFT2DOp>(),
|
||||
TypeID::get<TF::RFFT3DOp>(),
|
||||
TypeID::get<TF::RGBToHSVOp>(),
|
||||
TypeID::get<TF::RandomUniformIntOp>(),
|
||||
TypeID::get<TF::RealDivOp>(),
|
||||
TypeID::get<TF::ReciprocalOp>(),
|
||||
TypeID::get<TF::ReciprocalGradOp>(),
|
||||
@ -199,6 +201,11 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::SparseToDenseOp>(),
|
||||
TypeID::get<TF::SqrtGradOp>(),
|
||||
TypeID::get<TF::SquareOp>(),
|
||||
TypeID::get<TF::StatelessMultinomialOp>(),
|
||||
TypeID::get<TF::StatelessRandomNormalOp>(),
|
||||
TypeID::get<TF::StatelessRandomUniformOp>(),
|
||||
TypeID::get<TF::StatelessRandomUniformIntOp>(),
|
||||
TypeID::get<TF::StatelessTruncatedNormalOp>(),
|
||||
TypeID::get<TF::SubOp>(),
|
||||
TypeID::get<TF::TanOp>(),
|
||||
TypeID::get<TF::TransposeOp>(),
|
||||
|
@ -265,6 +265,7 @@ tf_xla_py_test(
|
||||
name = "categorical_op_test",
|
||||
size = "small",
|
||||
srcs = ["categorical_op_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -1285,6 +1286,7 @@ tf_xla_py_test(
|
||||
name = "stateless_random_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["stateless_random_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
@ -96,7 +96,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"Handle complex element types in DiagPart op lowering")
|
||||
"Handle complex element type in DiagPart lowering")
|
||||
def testAllTypeOps(self):
|
||||
for dtype in self.numeric_types - {np.int8, np.uint8}:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
|
@ -1984,7 +1984,6 @@ XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
|
||||
XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
|
||||
XlaOp initial_state, const Shape& shape) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
|
||||
TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
|
||||
Shape output_shape = shape;
|
||||
@ -2003,14 +2002,22 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
|
||||
return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
|
||||
PrimitiveType_Name(output_shape.element_type()));
|
||||
}
|
||||
*instr.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape({state_shape, output_shape}).ToProto();
|
||||
instr.set_rng_algorithm(algorithm);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
|
||||
{initial_state});
|
||||
return RngBitGeneratorInternal(
|
||||
ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm,
|
||||
initial_state);
|
||||
});
|
||||
}
|
||||
|
||||
// Emits the kRngBitGenerator HLO instruction from an already-computed full
// result shape; shape validation is performed by the public RngBitGenerator
// wrapper before it delegates here. Virtual so that alternative builders
// (e.g. MlirHloBuilder) can override the lowering.
StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  HloInstructionProto rng_instr;
  rng_instr.set_rng_algorithm(algorithm);
  *rng_instr.mutable_shape() = full_result_shape.ToProto();
  return AddInstruction(std::move(rng_instr), HloOpcode::kRngBitGenerator,
                        {initial_state});
}
|
||||
|
||||
XlaOp XlaBuilder::While(const XlaComputation& condition,
|
||||
const XlaComputation& body, XlaOp init) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
|
@ -712,6 +712,11 @@ class XlaBuilder {
|
||||
|
||||
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
|
||||
const Shape& shape);
|
||||
// Internal variant for the op with the full result shape containing both data
|
||||
// and state shape as a tuple.
|
||||
virtual StatusOr<XlaOp> RngBitGeneratorInternal(
|
||||
const Shape& full_result_shape, RandomAlgorithm algorithm,
|
||||
XlaOp initial_state);
|
||||
|
||||
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
|
||||
XlaOp init);
|
||||
|
Loading…
Reference in New Issue
Block a user