Add op definitions for unary kernels which already can be generated.
PiperOrigin-RevId: 329907244 Change-Id: Ibda80685eeb6bbe34818fba12a2f46500a038ee3
This commit is contained in:
parent
73a37969bf
commit
d74cd4a2a0
@ -132,3 +132,99 @@ gen_kernel_library(
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "ceil",
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "cos",
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "exp",
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "floor",
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "log",
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "neg",
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "rsqrt",
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "sqrt",
|
||||
same_shape = "0,1",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
@ -0,0 +1,5 @@
|
||||
func @ceil(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
|
||||
%0 = "tf.Ceil"(%arg0) { }
|
||||
: (tensor<?xelem_type>) -> tensor<?xelem_type>
|
||||
return %0 : tensor<?xelem_type>
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
func @cos(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
|
||||
%0 = "tf.Cos"(%arg0) { }
|
||||
: (tensor<?xelem_type>) -> tensor<?xelem_type>
|
||||
return %0 : tensor<?xelem_type>
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
func @exp(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
|
||||
%0 = "tf.Exp"(%arg0) { }
|
||||
: (tensor<?xelem_type>) -> tensor<?xelem_type>
|
||||
return %0 : tensor<?xelem_type>
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
func @floor(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
|
||||
%0 = "tf.Floor"(%arg0) { }
|
||||
: (tensor<?xelem_type>) -> tensor<?xelem_type>
|
||||
return %0 : tensor<?xelem_type>
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
func @log(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
|
||||
%0 = "tf.Log"(%arg0) { }
|
||||
: (tensor<?xelem_type>) -> tensor<?xelem_type>
|
||||
return %0 : tensor<?xelem_type>
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
func @neg(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
|
||||
%0 = "tf.Neg"(%arg0) { }
|
||||
: (tensor<?xelem_type>) -> tensor<?xelem_type>
|
||||
return %0 : tensor<?xelem_type>
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
func @rsqrt(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
|
||||
%0 = "tf.Rsqrt"(%arg0) { }
|
||||
: (tensor<?xelem_type>) -> tensor<?xelem_type>
|
||||
return %0 : tensor<?xelem_type>
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
func @sqrt(%arg0: tensor<?xelem_type>) -> tensor<?xelem_type> {
|
||||
%0 = "tf.Sqrt"(%arg0) { }
|
||||
: (tensor<?xelem_type>) -> tensor<?xelem_type>
|
||||
return %0 : tensor<?xelem_type>
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user