Added TF to HLO lowering tests for unary operators with dynamic shapes
PiperOrigin-RevId: 271241649
This commit is contained in:
parent
372a0ae937
commit
b7ce325a0d
@ -497,6 +497,20 @@ func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @abs_dynamic
|
||||
func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @abs_rankless
|
||||
func @abs_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ceil
|
||||
func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -504,6 +518,20 @@ func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ceil_dynamic
|
||||
func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ceil_rankless
|
||||
func @ceil_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cos
|
||||
func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -511,6 +539,20 @@ func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cos_dynamic
|
||||
func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cos_rankless
|
||||
func @cos_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @exp
|
||||
func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -518,6 +560,20 @@ func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @exp_dynamic
|
||||
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @exp_rankless
|
||||
func @exp_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @floor
|
||||
func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -525,6 +581,20 @@ func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floor_dynamic
|
||||
func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floor_rankless
|
||||
func @floor_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @neg
|
||||
func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -532,6 +602,20 @@ func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @neg_dynamic
|
||||
func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @neg_rankless
|
||||
func @neg_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sigmoid
|
||||
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.constant"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
|
||||
@ -572,6 +656,21 @@ func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tanh_dynamic
|
||||
func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tanh_rankless
|
||||
func @tanh_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: reshape
|
||||
func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> {
|
||||
// CHECK: %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user