TF.Rsqrt to xla_hlo.rsqrt lowering

PiperOrigin-RevId: 270795357
This commit is contained in:
A. Unique TensorFlower 2019-09-23 17:15:37 -07:00 committed by TensorFlower Gardener
parent 11ce7f6429
commit 953e80a3d1
2 changed files with 23 additions and 1 deletions

View File

@ -537,6 +537,27 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
return %0 : tensor<2xf32>
}
// CHECK-LABEL: func @rsqrt
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: func @rsqrt_dynamic
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @rsqrt_rankless
func @rsqrt_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @tanh
func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>

View File

@ -191,9 +191,10 @@ foreach Mapping = [
[TF_FloorOp, HLO_FloorOp],
[TF_LogOp, HLO_LogOp],
[TF_NegOp, HLO_NegOp],
[TF_RsqrtOp, HLO_RsqrtOp],
[TF_TanhOp, HLO_TanhOp],
] in {
def : Pat<(Mapping[0] AnyTensor:$input),
def : Pat<(Mapping[0] HLO_Tensor:$input),
(Mapping[1] $input)>;
}