Enable fallback legalization for Erfinv and Ndtri ops

PiperOrigin-RevId: 342333062
Change-Id: I5bd57860564db162274c59983adf00bbacde288b
This commit is contained in:
Smit Hinsu 2020-11-13 13:59:23 -08:00 committed by TensorFlower Gardener
parent 19e6114760
commit 242e920d7f
2 changed files with 16 additions and 0 deletions

View File

@ -307,6 +307,20 @@ func @set_dynamic_dimension_size(%input: tensor<4xf32>, %size: tensor<i32>) -> t
return %0 : tensor<4xf32>
}
// CHECK-LABEL: @erfinv
func @erfinv(%input: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: tf.Erfinv
%0 = "tf.Erfinv"(%input) : (tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// CHECK-LABEL: @ndtri
func @ndtri(%input: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: tf.Ndtri
%0 = "tf.Ndtri"(%input) : (tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.

View File

@ -132,6 +132,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::EluOp>(),
TypeID::get<TF::EqualOp>(),
TypeID::get<TF::ErfcOp>(),
TypeID::get<TF::ErfinvOp>(),
TypeID::get<TF::ErfOp>(),
TypeID::get<TF::Expm1Op>(),
TypeID::get<TF::ExtractImagePatchesOp>(),
@ -181,6 +182,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::MirrorPadGradOp>(),
TypeID::get<TF::MulOp>(),
TypeID::get<TF::MultinomialOp>(),
TypeID::get<TF::NdtriOp>(),
TypeID::get<TF::NegOp>(),
TypeID::get<TF::NextAfterOp>(),
TypeID::get<TF::NonMaxSuppressionV4Op>(),