Add missing lowering step for IsFiniteOp.
Also add a BUILD target for generating the GPU kernel. PiperOrigin-RevId: 334993362 Change-Id: I712f9d138c7050503d98a53412cd9c923a1ff60d
This commit is contained in:
parent
81d9fcbf83
commit
fed53e76a7
@ -506,6 +506,7 @@ void populateHLOToLHLOConversionPattern(
|
|||||||
HloToLhloOpConverter<mhlo::GatherOp>,
|
HloToLhloOpConverter<mhlo::GatherOp>,
|
||||||
HloToLhloOpConverter<mhlo::ImagOp>,
|
HloToLhloOpConverter<mhlo::ImagOp>,
|
||||||
HloToLhloOpConverter<mhlo::IotaOp>,
|
HloToLhloOpConverter<mhlo::IotaOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::IsFiniteOp>,
|
||||||
HloToLhloOpConverter<mhlo::LogOp>,
|
HloToLhloOpConverter<mhlo::LogOp>,
|
||||||
HloToLhloOpConverter<mhlo::MaxOp>,
|
HloToLhloOpConverter<mhlo::MaxOp>,
|
||||||
HloToLhloOpConverter<mhlo::MinOp>,
|
HloToLhloOpConverter<mhlo::MinOp>,
|
||||||
|
@ -601,3 +601,14 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memre
|
|||||||
tensor_store %result_tensor, %result: memref<4x4xf16>
|
tensor_store %result_tensor, %result: memref<4x4xf16>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----
|
||||||
|
|
||||||
|
// BOTH-LABEL: func @isfinite
|
||||||
|
func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||||
|
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
|
||||||
|
// BOTH: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
|
||||||
|
%result_tensor = "mhlo.is_finite"(%arg0_tensor) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||||
|
tensor_store %result_tensor, %result: memref<2x2xi1>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -243,6 +243,18 @@ gen_kernel_library(
|
|||||||
unroll_factors = "4",
|
unroll_factors = "4",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gen_kernel_library(
|
||||||
|
name = "isfinite",
|
||||||
|
same_shape = "0,1",
|
||||||
|
tile_size = "256",
|
||||||
|
types = [
|
||||||
|
"f16",
|
||||||
|
"f32",
|
||||||
|
"f64",
|
||||||
|
],
|
||||||
|
unroll_factors = "4",
|
||||||
|
)
|
||||||
|
|
||||||
gen_kernel_library(
|
gen_kernel_library(
|
||||||
name = "log",
|
name = "log",
|
||||||
same_shape = "0,1",
|
same_shape = "0,1",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user