From 139a1f18603d8fdc050b576efcc2f376b94a3040 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Aug 2020 13:59:20 -0700 Subject: [PATCH] Legalize tf.Multinomial with tf2xla. PiperOrigin-RevId: 327692741 Change-Id: Icf3443597a3277c728efa9460ada59739f646835 --- .../compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir | 8 ++++++++ .../mlir/xla/transforms/legalize_tf_with_tf2xla.cc | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index de1e592157e..df4f0303a84 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -298,6 +298,14 @@ func @random_uniform_int(%arg0: tensor, %arg1: tensor) -> tensor<1000x return %1 : tensor<1000xi32> } +// CHECK-LABEL: multinomial +func @multinomial(%arg0: tensor<2x4xf32>, %seed: tensor, %seed2: tensor) -> tensor<2x10xi32> { + // CHECK-NOT: tf.Multinomial + %samples = "tf.Const"() { value = dense<10> : tensor } : () -> tensor + %1 = "tf.Multinomial"(%arg0, %samples) {seed = 0, seed2 = 0}: (tensor<2x4xf32>, tensor) -> tensor<2x10xi32> + return %1 : tensor<2x10xi32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 3ab89e49cb2..1eb2292ba20 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -177,10 +177,10 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(),