Legalize tf.Multinomial with tf2xla.

PiperOrigin-RevId: 327692741
Change-Id: Icf3443597a3277c728efa9460ada59739f646835
This commit is contained in:
A. Unique TensorFlower 2020-08-20 13:59:20 -07:00 committed by TensorFlower Gardener
parent 0dc35b4d7d
commit 139a1f1860
2 changed files with 9 additions and 1 deletions

View File

@ -298,6 +298,14 @@ func @random_uniform_int(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<1000x
return %1 : tensor<1000xi32>
}
// CHECK-LABEL: multinomial
func @multinomial(%arg0: tensor<2x4xf32>, %seed: tensor<i32>, %seed2: tensor<i32>) -> tensor<2x10xi32> {
// CHECK-NOT: tf.Multinomial
%samples = "tf.Const"() { value = dense<10> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Multinomial"(%arg0, %samples) {seed = 0, seed2 = 0}: (tensor<2x4xf32>, tensor<i32>) -> tensor<2x10xi32>
return %1 : tensor<2x10xi32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@ -177,10 +177,10 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::MatrixTriangularSolveOp>(),
TypeID::get<TF::MirrorPadOp>(),
TypeID::get<TF::MulOp>(),
TypeID::get<TF::MultinomialOp>(),
TypeID::get<TF::NegOp>(),
TypeID::get<TF::NonMaxSuppressionV4Op>(),
TypeID::get<TF::NotEqualOp>(),
TypeID::get<TF::MultinomialOp>(),
TypeID::get<TF::PadOp>(),
TypeID::get<TF::PlaceholderWithDefaultOp>(),
TypeID::get<TF::PowOp>(),