Legalize tf.Multinomial with tf2xla.
PiperOrigin-RevId: 327692741 Change-Id: Icf3443597a3277c728efa9460ada59739f646835
This commit is contained in:
parent
0dc35b4d7d
commit
139a1f1860
@ -298,6 +298,14 @@ func @random_uniform_int(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<1000x
|
||||
return %1 : tensor<1000xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: multinomial
|
||||
func @multinomial(%arg0: tensor<2x4xf32>, %seed: tensor<i32>, %seed2: tensor<i32>) -> tensor<2x10xi32> {
|
||||
// CHECK-NOT: tf.Multinomial
|
||||
%samples = "tf.Const"() { value = dense<10> : tensor<i32> } : () -> tensor<i32>
|
||||
%1 = "tf.Multinomial"(%arg0, %samples) {seed = 0, seed2 = 0}: (tensor<2x4xf32>, tensor<i32>) -> tensor<2x10xi32>
|
||||
return %1 : tensor<2x10xi32>
|
||||
}
|
||||
|
||||
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
||||
// available but doesn't support this instance.
|
||||
}
|
||||
|
@ -177,10 +177,10 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::MatrixTriangularSolveOp>(),
|
||||
TypeID::get<TF::MirrorPadOp>(),
|
||||
TypeID::get<TF::MulOp>(),
|
||||
TypeID::get<TF::MultinomialOp>(),
|
||||
TypeID::get<TF::NegOp>(),
|
||||
TypeID::get<TF::NonMaxSuppressionV4Op>(),
|
||||
TypeID::get<TF::NotEqualOp>(),
|
||||
TypeID::get<TF::MultinomialOp>(),
|
||||
TypeID::get<TF::PadOp>(),
|
||||
TypeID::get<TF::PlaceholderWithDefaultOp>(),
|
||||
TypeID::get<TF::PowOp>(),
|
||||
|
Loading…
Reference in New Issue
Block a user