TF.Rsqrt to xla_hlo.rsqrt lowering
PiperOrigin-RevId: 270795357
This commit is contained in:
parent
11ce7f6429
commit
953e80a3d1
@ -537,6 +537,27 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @rsqrt
|
||||
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @rsqrt_dynamic
|
||||
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @rsqrt_rankless
|
||||
func @rsqrt_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tanh
|
||||
func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
|
@ -191,9 +191,10 @@ foreach Mapping = [
|
||||
[TF_FloorOp, HLO_FloorOp],
|
||||
[TF_LogOp, HLO_LogOp],
|
||||
[TF_NegOp, HLO_NegOp],
|
||||
[TF_RsqrtOp, HLO_RsqrtOp],
|
||||
[TF_TanhOp, HLO_TanhOp],
|
||||
] in {
|
||||
def : Pat<(Mapping[0] AnyTensor:$input),
|
||||
def : Pat<(Mapping[0] HLO_Tensor:$input),
|
||||
(Mapping[1] $input)>;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user