Add more tests for the internal passes used by mlir_gpu and the kernel generator.
PiperOrigin-RevId: 334104569 Change-Id: Icecfbb0a57b22a96ff0b3df7db0d2f7f646ff75b
This commit is contained in:
parent
fc72453228
commit
1c3145aeba
@ -0,0 +1,20 @@
|
||||
// RUN: xla-mlir-gpu-opt --mlir-gpu-fusion-op-remover %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @fusion_memref
|
||||
// Test input for the --mlir-gpu-fusion-op-remover pass named on the lit
// command line at the top of this file.
// The function wraps tensor_load / mhlo.add / mhlo.multiply / tensor_store in
// a single "lmhlo.fusion" region ending in "lmhlo.terminator". The FileCheck
// directives below only assert that neither "lmhlo.fusion" nor
// "lmhlo.terminator" appears in the pass output; they do not inspect what
// happens to the ops that were inside the region.
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>,
|
||||
%input3: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
// CHECK-NOT: lmhlo.fusion
|
||||
"lmhlo.fusion"() ( {
|
||||
%0 = tensor_load %input1 : memref<10xf32>
|
||||
%1 = tensor_load %input2 : memref<10xf32>
|
||||
%2 = "mhlo.add"(%0, %1) {name = "add"}
|
||||
: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
%3 = tensor_load %input3 : memref<10xf32>
|
||||
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"}
|
||||
: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
tensor_store %4, %out : memref<10xf32>
|
||||
// CHECK-NOT: lmhlo.terminator
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) : () -> ()
|
||||
return
|
||||
}
|
@ -0,0 +1,138 @@
|
||||
// RUN: xla-mlir-gpu-opt --mlir-gpu-rewrite-signatures %s --split-input-file --verify-diagnostics | FileCheck %s
|
||||
|
||||
// Positive case for --mlir-gpu-rewrite-signatures.
// Input: @kernel takes (memref<8xf32>, memref<16xf32>, memref<32xf32>) and
// @caller launches it with (%res, %arg1, %arg0), i.e. the caller's result
// buffer first and its two arguments in reverse order.
// The FileCheck directives require that, after the pass, both the gpu.func
// signature and the gpu.launch_func operand types end in
// (memref<32xf32>, memref<16xf32>, memref<8xf32>) -- the same order as
// @caller's argument types followed by its result type.
module attributes {gpu.container_module} {
|
||||
|
||||
// CHECK-LABEL: @kernel_module
|
||||
gpu.module @kernel_module {
|
||||
// CHECK-LABEL: gpu.func @kernel
|
||||
// CHECK-SAME: %{{.*}}: memref<32xf32>, %{{.*}}: memref<16xf32>,
|
||||
// CHECK-SAME: %{{.*}}: memref<8xf32>
|
||||
gpu.func @kernel(%arg0: memref<8xf32>, %arg1: memref<16xf32>,
|
||||
%arg2: memref<32xf32>) kernel {
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @caller
|
||||
func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
|
||||
%cst = constant 8 : index
|
||||
%res = alloc() : memref<8xf32>
|
||||
|
||||
// CHECK: gpu.launch_func
|
||||
// CHECK-SAME: index, memref<32xf32>, memref<16xf32>, memref<8xf32>)
|
||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %res, %arg1, %arg0)
|
||||
{ kernel = @kernel_module::@kernel }
|
||||
: (index, index, index, index, index, index,
|
||||
memref<8xf32>, memref<16xf32>, memref<32xf32>) -> ()
|
||||
|
||||
return %res : memref<8xf32>
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: the kernel has two arguments while @caller has two arguments
// plus one result, and the launch passes only %arg1 and %arg0 (the %res buffer
// is never handed to the kernel). The diagnostic annotation on gpu.func below
// requires an argument-count mismatch error to be reported there.
// NOTE(review): the asserted message text contains "numberof" (no space);
// presumably this mirrors the pass's diagnostic string verbatim -- confirm
// against the pass before editing either side.
module attributes {gpu.container_module} {
|
||||
|
||||
gpu.module @kernel_module {
|
||||
// expected-error @+1 {{number of kernel arguments does not match numberof arguments and results of surrounding function}}
|
||||
gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>) kernel {
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
|
||||
%cst = constant 8 : index
|
||||
%res = alloc() : memref<8xf32>
|
||||
|
||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %arg1, %arg0)
|
||||
{ kernel = @kernel_module::@kernel }
|
||||
: (index, index, index, index, index, index,
|
||||
memref<16xf32>, memref<32xf32>) -> ()
|
||||
|
||||
return %res : memref<8xf32>
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: the argument count matches (three kernel arguments for two
// caller arguments plus one result), but the third launch operand is %fake, a
// separate memref<8xf32> allocation, rather than the returned %res. The
// diagnostic annotation on gpu.func below requires the pass to report that
// result 0 of the containing function is not a kernel argument.
module attributes {gpu.container_module} {
|
||||
|
||||
gpu.module @kernel_module {
|
||||
// expected-error @+1 {{result 0 of containing function is not an argument to the kernel}}
|
||||
gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>,
|
||||
%arg2: memref<8xf32>) kernel {
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
|
||||
%cst = constant 8 : index
|
||||
%res = alloc() : memref<8xf32>
|
||||
%fake = alloc() : memref<8xf32>
|
||||
|
||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %arg1, %arg0, %fake)
|
||||
{ kernel = @kernel_module::@kernel }
|
||||
: (index, index, index, index, index, index,
|
||||
memref<16xf32>, memref<32xf32>, memref<8xf32>) -> ()
|
||||
|
||||
return %res : memref<8xf32>
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: the launch passes %fake (a local memref<16xf32> allocation)
// in place of the caller's %arg1, while %arg0 and %res are passed as usual.
// The diagnostic annotation on gpu.func below requires the pass to report that
// argument 1 of the containing function is not a kernel argument.
module attributes {gpu.container_module} {
|
||||
|
||||
gpu.module @kernel_module {
|
||||
// expected-error @+1 {{argument 1 to containing function is not an argument to the kernel}}
|
||||
gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>,
|
||||
%arg2: memref<8xf32>) kernel {
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
|
||||
%cst = constant 8 : index
|
||||
%res = alloc() : memref<8xf32>
|
||||
%fake = alloc() : memref<16xf32>
|
||||
|
||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %fake, %arg0, %res)
|
||||
{ kernel = @kernel_module::@kernel }
|
||||
: (index, index, index, index, index, index,
|
||||
memref<16xf32>, memref<32xf32>, memref<8xf32>) -> ()
|
||||
|
||||
return %res : memref<8xf32>
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: the kernel and the launch operands are the same as in the
// positive case at the top of this file, but @caller is split into two blocks
// (the entry block branches to ^bb1, which holds the launch and the return).
// The diagnostic annotation on func @caller below requires the pass to reject
// a surrounding function that has more than one block.
module attributes {gpu.container_module} {
|
||||
|
||||
gpu.module @kernel_module {
|
||||
gpu.func @kernel(%arg0: memref<8xf32>, %arg1: memref<16xf32>,
|
||||
%arg2: memref<32xf32>) kernel {
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
// expected-error @+1 {{surrounding function has more than one block}}
|
||||
func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
|
||||
%cst = constant 8 : index
|
||||
%res = alloc() : memref<8xf32>
|
||||
br ^bb1
|
||||
|
||||
^bb1:
|
||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %res, %arg1, %arg0)
|
||||
{ kernel = @kernel_module::@kernel }
|
||||
: (index, index, index, index, index, index,
|
||||
memref<8xf32>, memref<16xf32>, memref<32xf32>) -> ()
|
||||
|
||||
return %res : memref<8xf32>
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
// RUN: xla-mlir-gpu-opt --mlir-gpu-store-forwarding %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @forward
|
||||
// Positive case for --mlir-gpu-store-forwarding: a value is stored to
// %0[%c42] and immediately loaded back from the same memref and index.
// The FileCheck directives require that no load remains between the captured
// 1.0 constant and the return, and that the return uses that constant directly.
// NOTE(review): %c42 is bound to the index constant 24, so the SSA name does
// not match its value; this has no effect on what the test verifies.
func @forward() -> f32 {
|
||||
%0 = alloc() : memref<1024xf32>
|
||||
%c42 = constant 24 : index
|
||||
// CHECK: %[[CST:.*]] = constant 1.0
|
||||
%c1 = constant 1.0 : f32
|
||||
store %c1, %0[%c42] : memref<1024xf32>
|
||||
// CHECK-NOT: load
|
||||
%1 = load %0[%c42] : memref<1024xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @wrong_index
|
||||
// Negative case for store forwarding: the store writes %0[%c42] (index
// constant 24) but the load reads %0[%c12] (index constant 12), so the indices
// differ. The FileCheck directives require that a load is still present in the
// output and that its result is what the function returns.
func @wrong_index() -> f32 {
|
||||
%0 = alloc() : memref<1024xf32>
|
||||
%c42 = constant 24 : index
|
||||
%c12 = constant 12 : index
|
||||
%c1 = constant 1.0 : f32
|
||||
store %c1, %0[%c42] : memref<1024xf32>
|
||||
// CHECK: %[[RES:.*]] = load
|
||||
%1 = load %0[%c12] : memref<1024xf32>
|
||||
// CHECK: return %[[RES]]
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @wrong_memref
|
||||
// Negative case for store forwarding: the store and the load use the same
// index (%c42) but different allocations (%0 is written, %1 is read).
// The FileCheck directives require that a load is still present in the output
// and that its result is what the function returns.
func @wrong_memref() -> f32 {
|
||||
%0 = alloc() : memref<1024xf32>
|
||||
%1 = alloc() : memref<1024xf32>
|
||||
%c42 = constant 24 : index
|
||||
%c1 = constant 1.0 : f32
|
||||
store %c1, %0[%c42] : memref<1024xf32>
|
||||
// CHECK: %[[RES:.*]] = load
|
||||
%2 = load %1[%c42] : memref<1024xf32>
|
||||
// CHECK: return %[[RES]]
|
||||
return %2 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @with_parallel_loop
|
||||
// Store forwarding across a region boundary: the constant %c11 is stored to
// %0[%c42] before the scf.parallel loop, and the loop body loads %0[%c42] and
// stores the loaded value to %0[%c0].
// The FileCheck directives require that no load appears after the scf.parallel
// line and that the line immediately following it is a store of the captured
// 1.1e+01 constant.
func @with_parallel_loop() {
|
||||
%0 = alloc() : memref<1024xf32>
|
||||
%c0 = constant 0 : index
|
||||
%c42 = constant 24 : index
|
||||
%c1 = constant 1 : index
|
||||
// CHECK: %[[CST:.*]] = constant 1.100000e+01 : f32
|
||||
%c11 = constant 1.100000e+01 : f32
|
||||
store %c11, %0[%c42] : memref<1024xf32>
|
||||
// CHECK: scf.parallel
|
||||
scf.parallel (%i0) = (%c0) to (%c42) step (%c1) {
|
||||
// CHECK-NOT: load
|
||||
%1 = load %0[%c42] : memref<1024xf32>
|
||||
// CHECK-NEXT: store %[[CST]]
|
||||
store %1, %0[%c0] : memref<1024xf32>
|
||||
}
|
||||
return
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user