Add more tests for internal passes used by mlir_gpu and the kernel generator.

PiperOrigin-RevId: 334104569
Change-Id: Icecfbb0a57b22a96ff0b3df7db0d2f7f646ff75b
This commit is contained in:
Stephan Herhut 2020-09-28 01:14:44 -07:00 committed by TensorFlower Gardener
parent fc72453228
commit 1c3145aeba
3 changed files with 217 additions and 0 deletions

View File

@ -0,0 +1,20 @@
// RUN: xla-mlir-gpu-opt --mlir-gpu-fusion-op-remover %s | FileCheck %s
// Feeds the fusion-op-remover pass an lmhlo.fusion region that loads three
// memrefs, adds the first two, multiplies by the third and stores the result.
// The directives below require that neither the fusion op nor its terminator
// shows up in the pass output.
// CHECK-LABEL: func @fusion_memref
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>,
                    %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
  // CHECK-NOT: lmhlo.fusion
  "lmhlo.fusion"() ( {
    %lhs = tensor_load %input1 : memref<10xf32>
    %rhs = tensor_load %input2 : memref<10xf32>
    %sum = "mhlo.add"(%lhs, %rhs) {name = "add"}
        : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
    %factor = tensor_load %input3 : memref<10xf32>
    %product = "mhlo.multiply"(%sum, %factor) {name = "multiply"}
        : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
    tensor_store %product, %out : memref<10xf32>
    // CHECK-NOT: lmhlo.terminator
    "lmhlo.terminator"() : () -> ()
  } ) : () -> ()
  return
}

View File

@ -0,0 +1,138 @@
// RUN: xla-mlir-gpu-opt --mlir-gpu-rewrite-signatures %s --split-input-file --verify-diagnostics | FileCheck %s
// Positive case. The caller takes memref<32xf32> and memref<16xf32> and
// returns memref<8xf32>, while the kernel is declared and launched with those
// operands in the opposite order. The directives expect both the kernel
// signature and the launch operand types to be printed in the order
// memref<32xf32>, memref<16xf32>, memref<8xf32>.
module attributes {gpu.container_module} {
  // CHECK-LABEL: @kernel_module
  gpu.module @kernel_module {
    // CHECK-LABEL: gpu.func @kernel
    // CHECK-SAME: %{{.*}}: memref<32xf32>, %{{.*}}: memref<16xf32>,
    // CHECK-SAME: %{{.*}}: memref<8xf32>
    gpu.func @kernel(%arg0: memref<8xf32>, %arg1: memref<16xf32>,
                     %arg2: memref<32xf32>) kernel {
      gpu.return
    }
  }
  // CHECK-LABEL: @caller
  func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
    // A single index constant is reused for all six launch dimensions.
    %c8 = constant 8 : index
    %result = alloc() : memref<8xf32>
    // CHECK: gpu.launch_func
    // CHECK-SAME: index, memref<32xf32>, memref<16xf32>, memref<8xf32>)
    "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %result, %arg1, %arg0)
        { kernel = @kernel_module::@kernel }
        : (index, index, index, index, index, index,
           memref<8xf32>, memref<16xf32>, memref<32xf32>) -> ()
    return %result : memref<8xf32>
  }
}
// -----
// Negative case. The kernel takes two arguments, whereas the caller has two
// arguments plus one result; a count-mismatch diagnostic is expected on the
// kernel.
// NOTE(review): the "numberof" spelling in the diagnostic text below
// presumably mirrors the message emitted by the pass verbatim - confirm
// against the pass before "fixing" it here.
module attributes {gpu.container_module} {
  gpu.module @kernel_module {
    // expected-error @+1 {{number of kernel arguments does not match numberof arguments and results of surrounding function}}
    gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>) kernel {
      gpu.return
    }
  }
  func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
    %c8 = constant 8 : index
    // The returned buffer is never handed to the kernel launch.
    %result = alloc() : memref<8xf32>
    "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %arg1, %arg0)
        { kernel = @kernel_module::@kernel }
        : (index, index, index, index, index, index,
           memref<16xf32>, memref<32xf32>) -> ()
    return %result : memref<8xf32>
  }
}
// -----
// Negative case. The argument counts line up, but the launch receives a
// separate allocation instead of the buffer the caller returns, so the
// function result is not among the kernel operands. A diagnostic about
// result 0 is expected on the kernel.
module attributes {gpu.container_module} {
  gpu.module @kernel_module {
    // expected-error @+1 {{result 0 of containing function is not an argument to the kernel}}
    gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>,
                     %arg2: memref<8xf32>) kernel {
      gpu.return
    }
  }
  func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
    %c8 = constant 8 : index
    %result = alloc() : memref<8xf32>
    // Same type as the result, but a distinct value.
    %decoy = alloc() : memref<8xf32>
    "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %arg1, %arg0, %decoy)
        { kernel = @kernel_module::@kernel }
        : (index, index, index, index, index, index,
           memref<16xf32>, memref<32xf32>, memref<8xf32>) -> ()
    return %result : memref<8xf32>
  }
}
// -----
// Negative case. The launch receives a fresh memref<16xf32> allocation in
// place of the caller's second argument, so that argument is not among the
// kernel operands. A diagnostic about argument 1 is expected on the kernel.
module attributes {gpu.container_module} {
  gpu.module @kernel_module {
    // expected-error @+1 {{argument 1 to containing function is not an argument to the kernel}}
    gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>,
                     %arg2: memref<8xf32>) kernel {
      gpu.return
    }
  }
  func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
    %c8 = constant 8 : index
    %result = alloc() : memref<8xf32>
    // Same type as %arg1, but a distinct value.
    %decoy = alloc() : memref<16xf32>
    "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %decoy, %arg0, %result)
        { kernel = @kernel_module::@kernel }
        : (index, index, index, index, index, index,
           memref<16xf32>, memref<32xf32>, memref<8xf32>) -> ()
    return %result : memref<8xf32>
  }
}
// -----
// Negative case. Kernel and launch operands are the same as in the positive
// case, but the caller body is split into two blocks by an unconditional
// branch. A diagnostic is expected on the caller itself.
module attributes {gpu.container_module} {
  gpu.module @kernel_module {
    gpu.func @kernel(%arg0: memref<8xf32>, %arg1: memref<16xf32>,
                     %arg2: memref<32xf32>) kernel {
      gpu.return
    }
  }
  // expected-error @+1 {{surrounding function has more than one block}}
  func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> {
    %c8 = constant 8 : index
    %result = alloc() : memref<8xf32>
    br ^launch
  ^launch:
    "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %result, %arg1, %arg0)
        { kernel = @kernel_module::@kernel }
        : (index, index, index, index, index, index,
           memref<8xf32>, memref<16xf32>, memref<32xf32>) -> ()
    return %result : memref<8xf32>
  }
}

View File

@ -0,0 +1,59 @@
// RUN: xla-mlir-gpu-opt --mlir-gpu-store-forwarding %s | FileCheck %s
// A constant is stored to a buffer and immediately loaded back from the same
// buffer at the same index. The directives expect the load to be gone and the
// stored constant to be returned directly.
// CHECK-LABEL: @forward
func @forward() -> f32 {
  %buffer = alloc() : memref<1024xf32>
  %c24 = constant 24 : index
  // CHECK: %[[CST:.*]] = constant 1.0
  %one = constant 1.0 : f32
  store %one, %buffer[%c24] : memref<1024xf32>
  // CHECK-NOT: load
  %reloaded = load %buffer[%c24] : memref<1024xf32>
  // CHECK: return %[[CST]]
  return %reloaded : f32
}
// The store writes index 24 but the load reads index 12 of the same buffer.
// The directives expect the load to remain and its result to be returned.
// CHECK-LABEL: @wrong_index
func @wrong_index() -> f32 {
  %buffer = alloc() : memref<1024xf32>
  %c24 = constant 24 : index
  %c12 = constant 12 : index
  %one = constant 1.0 : f32
  store %one, %buffer[%c24] : memref<1024xf32>
  // CHECK: %[[RES:.*]] = load
  %loaded = load %buffer[%c12] : memref<1024xf32>
  // CHECK: return %[[RES]]
  return %loaded : f32
}
// The store and the load use the same index but two different allocations.
// The directives expect the load to remain and its result to be returned.
// CHECK-LABEL: @wrong_memref
func @wrong_memref() -> f32 {
  %written = alloc() : memref<1024xf32>
  %other = alloc() : memref<1024xf32>
  %c24 = constant 24 : index
  %one = constant 1.0 : f32
  store %one, %written[%c24] : memref<1024xf32>
  // CHECK: %[[RES:.*]] = load
  %loaded = load %other[%c24] : memref<1024xf32>
  // CHECK: return %[[RES]]
  return %loaded : f32
}
// The store happens before an scf.parallel loop; inside the loop body the
// same buffer is read at the same index and the value is written to index 0.
// The directives expect the load to be gone, with the store that directly
// follows the scf.parallel line using the constant itself.
// CHECK-LABEL: @with_parallel_loop
func @with_parallel_loop() {
  %buffer = alloc() : memref<1024xf32>
  %c0 = constant 0 : index
  %c24 = constant 24 : index
  %c1 = constant 1 : index
  // CHECK: %[[CST:.*]] = constant 1.100000e+01 : f32
  %eleven = constant 1.100000e+01 : f32
  store %eleven, %buffer[%c24] : memref<1024xf32>
  // CHECK: scf.parallel
  scf.parallel (%iv) = (%c0) to (%c24) step (%c1) {
    // CHECK-NOT: load
    %reloaded = load %buffer[%c24] : memref<1024xf32>
    // CHECK-NEXT: store %[[CST]]
    store %reloaded, %buffer[%c0] : memref<1024xf32>
  }
  return
}