Run mlir_gpu tests from files.

This makes it easier to iterate over tests, because the test doesn't have to be
recompiled all the time.

PiperOrigin-RevId: 295913906
Change-Id: I975a8db086aceb862498f2d63138cb0bf4859c00
This commit is contained in:
Adrian Kuegel 2020-02-19 00:55:21 -08:00 committed by TensorFlower Gardener
parent 8a97955b84
commit 6d44543263
32 changed files with 572 additions and 478 deletions

View File

@ -193,7 +193,9 @@ cc_library(
"//tensorflow/compiler/xla/tests:codegen_test_base",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:verified_hlo_module",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core/platform:resource_loader",
"//tensorflow/core/platform:test",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",

View File

@ -32,6 +32,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@ -46,8 +49,10 @@ void MlirIrGenTestBase::CompileIr(std::unique_ptr<HloModule> hlo_module,
TF_ASSERT_OK(status);
}
void MlirIrGenTestBase::PatternMatch(const string& str, const string& pattern) {
StatusOr<bool> filecheck_result = RunFileCheck(str, pattern);
void MlirIrGenTestBase::PatternMatch(const std::string& str,
const std::string& pattern_file) {
StatusOr<bool> filecheck_result =
RunFileCheckWithPatternFile(str, pattern_file);
TF_ASSERT_OK(filecheck_result.status());
EXPECT_TRUE(filecheck_result.ValueOrDie());
}
@ -55,7 +60,7 @@ void MlirIrGenTestBase::PatternMatch(const string& str, const string& pattern) {
string MlirIrGenTestBase::CompileIr(
std::unique_ptr<HloModule> hlo_module,
MlirCompiler::IRHook::LoweringStage printing_stage) {
string ir;
std::string ir;
CompileIr(std::move(hlo_module),
{[&ir](mlir::ModuleOp module) -> Status {
std::string buffer_string;
@ -70,23 +75,21 @@ string MlirIrGenTestBase::CompileIr(
}
void MlirIrGenTestBase::CompileAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const string& pattern,
std::unique_ptr<HloModule> hlo_module, const std::string& pattern_file,
LoweringStage printing_stage) {
string ir = CompileIr(std::move(hlo_module), printing_stage);
PatternMatch(ir, pattern);
std::string ir = CompileIr(std::move(hlo_module), printing_stage);
PatternMatch(ir, pattern_file);
}
void MlirIrGenTestBase::CompileAndVerifyIr(const string& hlo_text,
const string& expected_llvm_ir,
void MlirIrGenTestBase::CompileAndVerifyIr(const std::string& hlo_text_filename,
LoweringStage printing_stage) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
auto module = absl::make_unique<VerifiedHloModule>(
"Module", config, /*verifier_layout_sensitive=*/true,
/*allow_mixed_precision_in_hlo_verifier=*/false,
/*shape_size_function=*/ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(module->ParseHloStringAndVerifyModule(hlo_text));
CompileAndVerifyIr(std::move(module), expected_llvm_ir, printing_stage);
std::string hlo_text_absolute_filename =
tensorflow::GetDataDependencyFilepath(hlo_text_filename);
TF_ASSERT_OK_AND_ASSIGN(auto module,
GetVerifiedHloModule(hlo_text_absolute_filename));
CompileAndVerifyIr(std::move(module),
/*pattern_file=*/hlo_text_absolute_filename,
printing_stage);
}
MlirCompiler::IRHook MlirIrGenTestBase::getIRHookBreakingLoweringStage(
@ -104,7 +107,7 @@ MlirCompiler::IRHook MlirIrGenTestBase::getIRHookBreakingLoweringStage(
StatusOr<string> MlirIrGenTestBase::CompileAndInjectErrors(
std::unique_ptr<HloModule> hlo_module, LoweringStage breaking_stage) {
string errors;
std::string errors;
auto error_handler = [&errors](const EmissionContext::ErrorMap& error_map,
HloModule* hlo_module) {
errors = "ERRORS FOUND: ";
@ -127,19 +130,32 @@ StatusOr<string> MlirIrGenTestBase::CompileAndInjectErrors(
return status;
}
void MlirIrGenTestBase::CompileAndVerifyErrors(const string& hlo_text,
const string& expected_errors,
LoweringStage breaking_stage) {
void MlirIrGenTestBase::CompileAndVerifyErrors(
const std::string& hlo_text_filename, LoweringStage breaking_stage) {
std::string test_srcdir = tensorflow::testing::TensorFlowSrcRoot();
std::string hlo_text_absolute_filename =
tensorflow::GetDataDependencyFilepath(hlo_text_filename);
TF_ASSERT_OK_AND_ASSIGN(auto module,
GetVerifiedHloModule(hlo_text_absolute_filename));
TF_ASSERT_OK_AND_ASSIGN(
std::string errors,
CompileAndInjectErrors(std::move(module), breaking_stage));
PatternMatch(errors, /*pattern_file=*/hlo_text_absolute_filename);
}
StatusOr<std::unique_ptr<VerifiedHloModule>>
MlirIrGenTestBase::GetVerifiedHloModule(const std::string& hlo_text_filename) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
auto module = absl::make_unique<VerifiedHloModule>(
"Module", config, /*verifier_layout_sensitive=*/true,
/*allow_mixed_precision_in_hlo_verifier=*/false,
/*shape_size_function=*/ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(module->ParseHloStringAndVerifyModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(
string errors, CompileAndInjectErrors(std::move(module), breaking_stage));
PatternMatch(errors, expected_errors);
std::string hlo_text;
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(
tensorflow::Env::Default(), hlo_text_filename, &hlo_text));
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
return std::move(module);
}
MlirCompiler* MlirIrGenTestBase::GetMLIRCompiler() {

View File

@ -39,38 +39,36 @@ class MlirIrGenTestBase : public CodegenTestBase {
// steps to LLVM IR are applied; otherwise, the IR before lowering is
// matched.
void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
const string& pattern, LoweringStage printing_stage);
const std::string& pattern_file,
LoweringStage printing_stage);
// A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create
// an HLO module.
void CompileAndVerifyIr(const string& hlo_text,
const string& expected_llvm_ir,
// A thin wrapper around CompileAndVerifyIr that parses the hlo text in
// `hlo_text_filename` to create an HLO module.
void CompileAndVerifyIr(const std::string& hlo_text_filename,
LoweringStage printing_stage = LoweringStage::LHLO);
// Compiles and returns module with optimizations from a given HLO.
StatusOr<std::unique_ptr<HloModule>> GetOptimizedModule(
absl::string_view hlo);
// Adds the InjectErrorsForTestingPass to MLIRCompiler on the provided
// lowering stage, compiles the given HLO module, and returns a string
// lowering stage, compiles the given HLO module, and returns a string
// representation of all the errors occurred during compiling.
StatusOr<string> CompileAndInjectErrors(std::unique_ptr<HloModule> hlo_module,
LoweringStage breaking_stage);
// Adds the InjectErrorsForTestingPass to MLIRCompiler on the provided
// lowering stage, parses and compiles `hlo_text`, and verifies that the
// string representation of all the errors occurred during compiling matches
// the given pattern.
void CompileAndVerifyErrors(const string& hlo_text,
const string& expected_errors,
// string representation of all the errors that occurred during compilation
// matches the given pattern.
void CompileAndVerifyErrors(const std::string& hlo_text_filename,
LoweringStage breaking_stage);
private:
StatusOr<std::unique_ptr<VerifiedHloModule>> GetVerifiedHloModule(
const std::string& hlo_text_filename);
void CompileIr(std::unique_ptr<HloModule> hlo_module,
const MlirCompiler::IRHook& ir_hook);
void PatternMatch(const string& str, const string& pattern);
string CompileIr(std::unique_ptr<HloModule> hlo_module,
LoweringStage printing_stage);
void PatternMatch(const std::string& str, const std::string& pattern_file);
std::string CompileIr(std::unique_ptr<HloModule> hlo_module,
LoweringStage printing_stage);
MlirCompiler::IRHook getIRHookBreakingLoweringStage(
LoweringStage breaking_stage);
MlirCompiler* GetMLIRCompiler();

View File

@ -25,11 +25,39 @@ package_group(
tf_cc_test(
name = "mlir_gpu_lhlo_gen_test",
srcs = if_cuda_is_configured(["mlir_gpu_lhlo_gen_test.cc"]),
data = [
"abs.hlo",
"add.hlo",
"add_as_kernel.hlo",
"add_in_gpu_dialect.hlo",
"add_multiply.hlo",
"add_multiply_gpu.hlo",
"add_reduce.hlo",
"broadcast.hlo",
"broken_add.hlo",
"ceil.hlo",
"compare.hlo",
"const.hlo",
"copy.hlo",
"cos.hlo",
"exp.hlo",
"fused_reduce.hlo",
"iota.hlo",
"iota_add_multiply.hlo",
"log.hlo",
"neg.hlo",
"rem.hlo",
"rsqrt.hlo",
"select.hlo",
"sign.hlo",
"tanh.hlo",
],
tags = tf_cuda_tests_tags() + ["no_rocm"],
deps = [
"//tensorflow/core:test_main",
"//tensorflow/core:test",
] + if_cuda_is_configured([
"//tensorflow/core:lib",
"//tensorflow/compiler/xla/service:gpu_plugin_mlir",
"//tensorflow/compiler/xla/service/mlir_gpu:mlir_irgen_test_base",
"//tensorflow/stream_executor/lib",

View File

@ -0,0 +1,9 @@
HloModule Abs
ENTRY %Abs (val: f32[2,2]) -> f32[2,2] {
%val = f32[2,2]{1,0} parameter(0)
ROOT %abs = f32[2,2]{1,0} abs(f32[2,2]{1,0} %val)
}
// CHECK: func @abs(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,11 @@
HloModule Add
ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
}
// CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,62 @@
HloModule Add
ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
}
// CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm<.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]
//
// Check that relevant sizes and strides are emitted.
//
// CHECK: %[[CAST0:.*]] = llvm.bitcast %[[ARG0:.*]] : !llvm<"i8*"> to !llvm<"float*">
// CHECK: %[[SIZE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
// CHECK: %[[SIZE01:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
// CHECK: %[[STRIDE01:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
// CHECK: %[[STRIDE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG1:.*]] : !llvm<"i8*"> to !llvm<"float*">
// CHECK: %[[SIZE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
// CHECK: %[[SIZE11:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
// CHECK: %[[STRIDE11:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
// CHECK: %[[STRIDE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[ARG2:.*]] : !llvm<"i8*"> to !llvm<"float*">
// CHECK: %[[SIZE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
// CHECK: %[[SIZE21:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
// CHECK: %[[STRIDE21:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
// CHECK: %[[STRIDE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
//
// Check that the emitted sizes and strides, as well as the pointers to HLO buffers,
// are inserted into the memref descriptors.
//
// CHECK: %[[DESC0:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC01:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC0]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC02:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC01]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC03:.*]] = llvm.insertvalue %{{.*}}, %[[DESC02]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC04:.*]] = llvm.insertvalue %[[SIZE00]], %[[DESC03]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC05:.*]] = llvm.insertvalue %[[STRIDE00]], %[[DESC04]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC06:.*]] = llvm.insertvalue %[[SIZE01]], %[[DESC05]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE01]], %[[DESC06]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC1:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC11]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC13:.*]] = llvm.insertvalue %{{.*}}, %[[DESC12]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC14:.*]] = llvm.insertvalue %[[SIZE10]], %[[DESC13]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC15:.*]] = llvm.insertvalue %[[STRIDE10]], %[[DESC14]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC16:.*]] = llvm.insertvalue %[[SIZE11]], %[[DESC15]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE11]], %[[DESC16]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC2:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC21:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC2]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC22:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC21]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC23:.*]] = llvm.insertvalue %{{.*}}, %[[DESC22]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC24:.*]] = llvm.insertvalue %[[SIZE20]], %[[DESC23]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC25:.*]] = llvm.insertvalue %[[STRIDE20]], %[[DESC24]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %[[DESC26:.*]] = llvm.insertvalue %[[SIZE21]], %[[DESC25]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE21]], %[[DESC26]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">

View File

@ -0,0 +1,19 @@
HloModule Add
ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
}
// CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
// CHECK: "gpu.launch_func"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[ARG2]]
// CHECK: }
// CHECK: func @add_kernel(%[[ARG0]]: [[TYPE]], %[[ARG1]]: [[TYPE]], %[[ARG2]]: [[TYPE]]
// CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]]
// CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]]
// CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]]
// CHECK: %[[VAL1:.*]] = load %{{.*\[}}[[INDEX:.*]]]
// CHECK: %[[VAL2:.*]] = load %{{.*\[}}[[INDEX]]]
// CHECK: %[[RES:.*]] = addf %[[VAL1]], %[[VAL2]]
// CHECK: store %[[RES]], %{{.*\[}}[[INDEX]]]

View File

@ -0,0 +1,21 @@
HloModule AddMultiply
ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
%z = f32[2,2]{1,0} parameter(2)
%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z)
}
// CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]])
// CHECK: "xla_lhlo.fusion"() ( {
// CHECK: %[[REF0:.*]] = tensor_load %[[ARG0]] : [[TYPE]]
// CHECK: %[[REF1:.*]] = tensor_load %[[ARG1]] : [[TYPE]]
// CHECK: %[[REF2:.*]] = tensor_load %[[ARG2]] : [[TYPE]]
// CHECK: %[[ADD:.*]] = xla_hlo.add %[[REF1]], %[[REF2]]
// CHECK: %[[MUL:.*]] = xla_hlo.mul %[[ADD]], %[[REF0]]
// CHECK: tensor_store %[[MUL]], %[[RESULT]]
// CHECK: "xla_lhlo.terminator"()
// CHECK-NEXT: }

View File

@ -0,0 +1,22 @@
HloModule AddMultiply
ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
%z = f32[2,2]{1,0} parameter(2)
%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z)
}
// CHECK: func @fusion_kernel(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]])
// CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]]
// CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]]
// CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]]
// CHECK-DAG: std.subview %[[RESULT]]{{\[}}[[INDEX]]]
// CHECK: %[[V0:.*]] = load %{{.*\[}}[[CSTIDX:.*]]]
// CHECK: %[[V1:.*]] = load %{{.*\[}}[[CSTIDX:.*]]]
// CHECK: %[[ADD:.*]] = addf %[[V0]], %[[V1]]
// CHECK: %[[V2:.*]] = load %{{.*\[}}[[CSTIDX:.*]]]
// CHECK: %[[MUL:.*]] = mulf %[[ADD]], %[[V2]]
// CHECK: store %[[MUL]], %{{.*\[}}[[CSTIDX:.*]]]
// CHECK-NEXT: return

View File

@ -0,0 +1,23 @@
HloModule AddReduce
%add (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %AddReduce (x: f32[100,10], c: f32[]) -> f32[100] {
%x = f32[100,10]{1,0} parameter(0)
%c = f32[] parameter(1)
ROOT %reduce = f32[100]{0} reduce(f32[100,10]{1,0} %x, f32[] %c), dimensions={1}, to_apply=%add
}
// CHECK: func @reduce(%[[ARG:.*]]: [[ARGT:.*]], %[[CST:.*]]: memref<f32>, %[[RES:.*]]: [[REST:.*]]) {
// CHECK: "xla_lhlo.reduce"(%[[ARG]], %[[CST]], %[[RES]]) ( {
// CHECK: ^bb0(%[[FARG0:.*]]: memref<f32>, %[[FARG1:.*]]: memref<f32>, %[[FRES:.*]]: memref<f32>):
// CHECK: %[[LHS:.*]] = tensor_load %[[FARG0]] : memref<f32>
// CHECK: %[[RHS:.*]] = tensor_load %[[FARG1]] : memref<f32>
// CHECK: %[[RES:.*]] = xla_hlo.add %[[LHS]], %[[RHS]] : tensor<f32>
// CHECK: tensor_store %[[RES]], %[[FRES]] : memref<f32>
// CHECK: "xla_lhlo.terminator"() : () -> ()
// CHECK-NEXT: }) {dimensions = dense<1> : tensor<1xi64>} : ([[ARGT]], memref<f32>, [[REST]]) -> ()

View File

@ -0,0 +1,13 @@
HloModule Broadcast
ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] {
%x = f32[10]{0} parameter(0)
ROOT %broadcast = f32[10, 5]{1,0} broadcast(f32[10]{0} %x), dimensions={0}
}
// CHECK: func @broadcast(%[[IN:.*]]: [[IN_T:.*]], %[[OUT:.*]]: [[OUT_T:.*]]) {
// CHECK: "xla_lhlo.broadcast_in_dim"(%[[IN]], %[[OUT]])
// CHECK: {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: : ([[IN_T]], [[OUT_T]]) -> ()
// CHECK: }

View File

@ -0,0 +1,9 @@
HloModule Add
ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] {
%x = f32[2,2,2]{2,1,0} parameter(0)
%y = f32[2,2,2]{2,1,0} parameter(1)
ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y)
}
// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: xla_lhlo.add; failed for testing: std.return]

View File

@ -0,0 +1,9 @@
HloModule Ceil
ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] {
%val = f32[2,2]{1,0} parameter(0)
ROOT %ceil = f32[2,2]{1,0} ceil(f32[2,2]{1,0} %val)
}
// CHECK: func @ceil(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,12 @@
HloModule Compare
ENTRY %Compare (x: f32[2,2], y: f32[2,2]) -> pred[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %compare = pred[2,2]{1,0} compare(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y), direction=EQ
}
// CHECK: func @compare(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[PRED:.*]]: [[PRED_TYPE:.*]]) {
// CHECK: "xla_lhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]])
// CHECK: {comparison_direction = "EQ"} : ([[TYPE]], [[TYPE]], [[PRED_TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,11 @@
HloModule Const
ENTRY %Const () -> s32[100] {
%const.0 = s32[] constant(10)
ROOT %broadcast.0 = s32[100]{0} broadcast(s32[] %const.0), dimensions={}
}
// CHECK: func @constant(%[[ARG0:.*]]: memref<i32>)
// CHECK: "xla_lhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor<i32>}
// CHECK: func @broadcast(%[[ARG1:.*]]: memref<i32>, %[[ARG2:.*]]: memref<100xi32>)
// CHECK: "xla_lhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>}

View File

@ -0,0 +1,9 @@
HloModule Copy
ENTRY %Copy (x: f32[2,4]) -> f32[2,4] {
%x = f32[2,4] parameter(0)
ROOT %copy = f32[2,4] copy(f32[2,4] %x)
}
// CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) {
// CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> ()

View File

@ -0,0 +1,9 @@
HloModule Cos
ENTRY %Cos (val: f32[2,2]) -> f32[2,2] {
%val = f32[2,2]{1,0} parameter(0)
ROOT %cos = f32[2,2]{1,0} cosine(f32[2,2]{1,0} %val)
}
// CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.cos"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,11 @@
HloModule Exp
ENTRY %Exp (x: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
ROOT %exp = f32[2,2]{1,0} exponential(f32[2,2]{1,0} %x)
}
// CHECK: func @exponential(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.exp"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,34 @@
HloModule FusedReduce
%add (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
%fused_computation (param: f32[100,10]) -> f32[10] {
%param = f32[100,10] parameter(0)
%constant = f32[] constant(0)
ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant),
dimensions={0}, to_apply=%add
}
ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] {
%x = f32[100,10] parameter(0)
ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput,
calls=%fused_computation
}
// CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]])
// CHECK: "xla_lhlo.fusion"() ( {
// CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]]
// CHECK: %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00>
// CHECK: %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( {
// CHECK: ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]])
// CHECK: %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]]
// CHECK: "xla_hlo.return"(%[[ADD]])
// CHECK: })
// CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]]
// CHECK: "xla_lhlo.terminator"()
// CHECK-NEXT: })

View File

@ -0,0 +1,10 @@
HloModule Iota
ENTRY %Iota() -> s64[10, 5] {
ROOT %iota = s64[10, 5]{1,0} iota(), iota_dimension=0
}
// CHECK: func @iota(%[[OUT:.*]]: [[OUT_T:.*]]) {
// CHECK: "xla_lhlo.iota"(%[[OUT]])
// CHECK: {iota_dimension = 0 : i64} : ([[OUT_T]]) -> ()
// CHECK: }

View File

@ -0,0 +1,15 @@
HloModule AddMultiply
ENTRY %AddMultiply (x: s32[2,2], y: s32[2,2]) -> s32[2,2] {
%x = s32[2,2]{1,0} parameter(0)
%y = s32[2,2]{1,0} parameter(1)
%add = s32[2,2]{1,0} add(s32[2,2]{1,0} %x, s32[2,2]{1,0} %y)
%iota = s32[2, 2]{1,0} iota(), iota_dimension=0
ROOT %mul = s32[2,2]{1,0} multiply(s32[2,2]{1,0} %add, s32[2,2]{1,0} %iota)
}
// CHECK-NOT: store
// CHECK: %[[RESULT:.*]] = muli %{{.*}}, %{{.*}}
// CHECK: store %[[RESULT]]

View File

@ -0,0 +1,10 @@
HloModule Log
ENTRY %Log (x: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x)
}
// CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h"
#include "tensorflow/core/platform/path.h"
namespace xla {
namespace mlir_gpu {
@ -21,513 +22,174 @@ namespace mlir_gpu {
class LhloGenTest : public MlirIrGenTestBase {};
TEST_F(LhloGenTest, Const) {
CompileAndVerifyIr(R"(
HloModule Const
ENTRY %Const () -> s32[100] {
%const.0 = s32[] constant(10)
ROOT %broadcast.0 = s32[100]{0} broadcast(s32[] %const.0), dimensions={}
})",
R"(
;CHECK: func @constant(%[[ARG0:.*]]: memref<i32>)
;CHECK: "xla_lhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor<i32>}
;CHECK: func @broadcast(%[[ARG1:.*]]: memref<i32>, %[[ARG2:.*]]: memref<100xi32>)
;CHECK: "xla_lhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>}
)",
LoweringStage::LHLO);
CompileAndVerifyIr(
/*hlo_text_filename=*/tensorflow::io::JoinPath(
"tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests",
"const.hlo"),
LoweringStage::LHLO);
}
TEST_F(LhloGenTest, BrokenAdd) {
CompileAndVerifyErrors(
R"(
HloModule Add
ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] {
%x = f32[2,2,2]{2,1,0} parameter(0)
%y = f32[2,2,2]{2,1,0} parameter(1)
ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y)
})",
R"(CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: xla_lhlo.add; failed for testing: std.return])",
/*hlo_text_filename=*/
tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service",
"mlir_gpu", "tests", "broken_add.hlo"),
LoweringStage::LHLO);
}
TEST_F(LhloGenTest, Add) {
CompileAndVerifyIr(R"(
HloModule Add
ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
})",
R"(
;CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
;CHECK: "xla_lhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
;CHECK: }
)");
CompileAndVerifyIr(
/*hlo_text_filename=*/tensorflow::io::JoinPath(
"tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests",
"add.hlo"));
}
TEST_F(LhloGenTest, Compare) {
CompileAndVerifyIr(R"(
HloModule Compare
ENTRY %Compare (x: f32[2,2], y: f32[2,2]) -> pred[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %compare = pred[2,2]{1,0} compare(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y), direction=EQ
})",
R"(
;CHECK: func @compare(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[PRED:.*]]: [[PRED_TYPE:.*]]) {
;CHECK: "xla_lhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]])
;CHECK: {comparison_direction = "EQ"} : ([[TYPE]], [[TYPE]], [[PRED_TYPE]]) -> ()
;CHECK: }
)");
CompileAndVerifyIr(
/*hlo_text_filename=*/tensorflow::io::JoinPath(
"tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests",
"compare.hlo"));
}
TEST_F(LhloGenTest, Copy) {
CompileAndVerifyIr(R"(
HloModule Copy
ENTRY %Copy (x: f32[2,4]) -> f32[2,4] {
%x = f32[2,4] parameter(0)
ROOT %copy = f32[2,4] copy(f32[2,4] %x)
})",
R"(
;CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) {
;CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> ()
)");
CompileAndVerifyIr(
/*hlo_text_filename=*/tensorflow::io::JoinPath(
"tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests",
"copy.hlo"));
}
TEST_F(LhloGenTest, Select) {
CompileAndVerifyIr(R"(
HloModule Select
ENTRY %Select (p: pred[2,2], x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%p = pred[2,2]{1,0} parameter(0)
%x = f32[2,2]{1,0} parameter(1)
%y = f32[2,2]{1,0} parameter(2)
ROOT %select = f32[2,2]{1,0} select(pred[2,2]{1,0} %p, f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
})",
R"(
;CHECK: func @select(%[[PRED:.*]]: [[PRED_TYPE:.*]], %[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
;CHECK: "xla_lhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> ()
;CHECK: }
)");
CompileAndVerifyIr(
/*hlo_text_filename=*/tensorflow::io::JoinPath(
"tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests",
"select.hlo"));
}
TEST_F(LhloGenTest, Exp) {
CompileAndVerifyIr(R"(
HloModule Exp
ENTRY %Exp (x: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
ROOT %exp = f32[2,2]{1,0} exponential(f32[2,2]{1,0} %x)
})",
R"(
;CHECK: func @exponential(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
;CHECK: "xla_lhlo.exp"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
;CHECK: }
)");
CompileAndVerifyIr(
/*hlo_text_filename=*/tensorflow::io::JoinPath(
"tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests",
"exp.hlo"));
}
TEST_F(LhloGenTest, Log) {
CompileAndVerifyIr(R"(
HloModule Log
ENTRY %Log (x: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x)
})",
R"(
;CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
;CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
;CHECK: }
)");
CompileAndVerifyIr(
/*hlo_text_filename=*/tensorflow::io::JoinPath(
"tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests",
"log.hlo"));
}
TEST_F(LhloGenTest, AddInGPUDialect) {
CompileAndVerifyIr(R"(
HloModule Add
ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
})",
R"(
;CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
;CHECK: "gpu.launch_func"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[ARG2]]
;CHECK: }
;CHECK: func @add_kernel(%[[ARG0]]: [[TYPE]], %[[ARG1]]: [[TYPE]], %[[ARG2]]: [[TYPE]]
;CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]]
;CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]]
;CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]]
;CHECK: %[[VAL1:.*]] = load %{{.*\[}}[[INDEX:.*]]]
;CHECK: %[[VAL2:.*]] = load %{{.*\[}}[[INDEX]]]
;CHECK: %[[RES:.*]] = addf %[[VAL1]], %[[VAL2]]
;CHECK: store %[[RES]], %{{.*\[}}[[INDEX]]]
)",
LoweringStage::GPU);
CompileAndVerifyIr(
/*hlo_text_filename=*/
tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service",
"mlir_gpu", "tests", "add_in_gpu_dialect.hlo"),
LoweringStage::GPU);
}
// This test verifies that the kernel signature is amended correctly. The actual
// body of the generated function does not matter, it is already checked at the
// GPU level above.
TEST_F(LhloGenTest, AddAsKernel) {
CompileAndVerifyIr(R"(
HloModule Add
ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
})",
R"(
;CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm<.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]
;
; Check that relevant sizes and strides are emitted.
;
;CHECK: %[[CAST0:.*]] = llvm.bitcast %[[ARG0:.*]] : !llvm<"i8*"> to !llvm<"float*">
;CHECK: %[[SIZE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;CHECK: %[[SIZE01:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;CHECK: %[[STRIDE01:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
;CHECK: %[[STRIDE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG1:.*]] : !llvm<"i8*"> to !llvm<"float*">
;CHECK: %[[SIZE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;CHECK: %[[SIZE11:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;CHECK: %[[STRIDE11:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
;CHECK: %[[STRIDE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;CHECK: %[[CAST2:.*]] = llvm.bitcast %[[ARG2:.*]] : !llvm<"i8*"> to !llvm<"float*">
;CHECK: %[[SIZE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;CHECK: %[[SIZE21:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;CHECK: %[[STRIDE21:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
;CHECK: %[[STRIDE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
;
; Check that the emitted sizes and strides, as well the pointers to HLO buffers,
; are inserted into the memref descriptors.
;
;CHECK: %[[DESC0:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC01:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC0]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC02:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC01]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC03:.*]] = llvm.insertvalue %{{.*}}, %[[DESC02]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC04:.*]] = llvm.insertvalue %[[SIZE00]], %[[DESC03]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC05:.*]] = llvm.insertvalue %[[STRIDE00]], %[[DESC04]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC06:.*]] = llvm.insertvalue %[[SIZE01]], %[[DESC05]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE01]], %[[DESC06]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC1:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC11]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC13:.*]] = llvm.insertvalue %{{.*}}, %[[DESC12]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC14:.*]] = llvm.insertvalue %[[SIZE10]], %[[DESC13]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC15:.*]] = llvm.insertvalue %[[STRIDE10]], %[[DESC14]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC16:.*]] = llvm.insertvalue %[[SIZE11]], %[[DESC15]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE11]], %[[DESC16]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC2:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC21:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC2]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC22:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC21]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC23:.*]] = llvm.insertvalue %{{.*}}, %[[DESC22]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC24:.*]] = llvm.insertvalue %[[SIZE20]], %[[DESC23]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC25:.*]] = llvm.insertvalue %[[STRIDE20]], %[[DESC24]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %[[DESC26:.*]] = llvm.insertvalue %[[SIZE21]], %[[DESC25]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
;CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE21]], %[[DESC26]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
)",
LoweringStage::KERNEL);
CompileAndVerifyIr(
tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service",
"mlir_gpu", "tests", "add_as_kernel.hlo"),
LoweringStage::KERNEL);
}
// TODO(b/149302060) Reenable once fusion is fixed.
TEST_F(LhloGenTest, DISABLED_AddMultiply) {
  // Verifies the LHLO fusion lowering of add followed by multiply. The HLO
  // module and its FileCheck pattern both live in add_multiply.hlo; the
  // inline copy of module+pattern was removed to avoid verifying twice.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "add_multiply.hlo"));
}
// TODO(b/149302060) Reenable once fusion is fixed.
TEST_F(LhloGenTest, DISABLED_IotaAddMultiply) {
  // Verifies that iota+add+multiply lowers to a single fused GPU kernel with
  // no intermediate store. The HLO module and FileCheck pattern live in
  // iota_add_multiply.hlo; the redundant inline copy was removed.
  CompileAndVerifyIr(
      tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service",
                               "mlir_gpu", "tests", "iota_add_multiply.hlo"),
      LoweringStage::GPU);
}
TEST_F(LhloGenTest, AddMultiplyGPU) {
  // Verifies the GPU-stage lowering of the fused add+multiply kernel
  // (subviews per thread index, then load/addf/mulf/store). The HLO module
  // and FileCheck pattern live in add_multiply_gpu.hlo; the redundant inline
  // copy was removed so the module is only compiled and checked once.
  CompileAndVerifyIr(
      tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service",
                               "mlir_gpu", "tests", "add_multiply_gpu.hlo"),
      LoweringStage::GPU);
}
// TODO(b/137624192): Reenable once we can fuse reductions.
TEST_F(LhloGenTest, DISABLED_FusedReduce) {
  // Verifies LHLO lowering of a kInput fusion wrapping a row reduction.
  // The HLO module and FileCheck pattern live in fused_reduce.hlo; the
  // redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "fused_reduce.hlo"));
}
TEST_F(LhloGenTest, Broadcast) {
  // Verifies lowering of HLO broadcast to "xla_lhlo.broadcast_in_dim".
  // Module and FileCheck pattern live in broadcast.hlo; the redundant
  // inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "broadcast.hlo"));
}
TEST_F(LhloGenTest, Iota) {
  // Verifies lowering of HLO iota to "xla_lhlo.iota" with the expected
  // iota_dimension attribute. Module and FileCheck pattern live in iota.hlo;
  // the redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "iota.hlo"));
}
TEST_F(LhloGenTest, AddReduce) {
  // Verifies lowering of an HLO reduce (with an add reducer computation) to
  // "xla_lhlo.reduce" with a tensor_load/add/tensor_store body region.
  // Module and FileCheck pattern live in add_reduce.hlo; the redundant
  // inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "add_reduce.hlo"));
}
TEST_F(LhloGenTest, Abs) {
  // Verifies lowering of HLO abs to "xla_lhlo.abs". Module and FileCheck
  // pattern live in abs.hlo; the redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "abs.hlo"));
}
TEST_F(LhloGenTest, Ceil) {
  // Verifies lowering of HLO ceil to "xla_lhlo.ceil". Module and FileCheck
  // pattern live in ceil.hlo; the redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "ceil.hlo"));
}
TEST_F(LhloGenTest, Cos) {
  // Verifies lowering of HLO cosine to "xla_lhlo.cos". Module and FileCheck
  // pattern live in cos.hlo; the redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "cos.hlo"));
}
TEST_F(LhloGenTest, Neg) {
  // Verifies lowering of HLO negate to "xla_lhlo.neg". Module and FileCheck
  // pattern live in neg.hlo; the redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "neg.hlo"));
}
TEST_F(LhloGenTest, Rem) {
  // Verifies lowering of HLO remainder to "xla_lhlo.remainder". Module and
  // FileCheck pattern live in rem.hlo; the redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "rem.hlo"));
}
TEST_F(LhloGenTest, Rsqrt) {
  // Verifies lowering of HLO rsqrt to "xla_lhlo.rsqrt". Module and FileCheck
  // pattern live in rsqrt.hlo; the redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "rsqrt.hlo"));
}
TEST_F(LhloGenTest, Sign) {
  // Verifies lowering of HLO sign to "xla_lhlo.sign". Module and FileCheck
  // pattern live in sign.hlo. Fixed: this test previously loaded rsqrt.hlo
  // by mistake, so it re-ran the Rsqrt check instead of exercising sign.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "sign.hlo"));
}
TEST_F(LhloGenTest, Tanh) {
  // Verifies lowering of HLO tanh to "xla_lhlo.tanh". Module and FileCheck
  // pattern live in tanh.hlo; the redundant inline copy was removed.
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "tanh.hlo"));
}
} // namespace mlir_gpu

View File

@ -0,0 +1,9 @@
HloModule Neg
ENTRY %Neg (val: f32[2,2]) -> f32[2,2] {
%val = f32[2,2]{1,0} parameter(0)
ROOT %neg = f32[2,2]{1,0} negate(f32[2,2]{1,0} %val)
}
// CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.neg"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,10 @@
HloModule Rem
ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
%y = f32[2,2]{1,0} parameter(1)
ROOT %rem = f32[2,2]{1,0} remainder(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
}
// CHECK: func @remainder(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,10 @@
HloModule Rsqrt
ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
ROOT %rsqrt = f32[2,2]{1,0} rsqrt(f32[2,2]{1,0} %x)
}
// CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,13 @@
HloModule Select
ENTRY %Select (p: pred[2,2], x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
%p = pred[2,2]{1,0} parameter(0)
%x = f32[2,2]{1,0} parameter(1)
%y = f32[2,2]{1,0} parameter(2)
ROOT %select = f32[2,2]{1,0} select(pred[2,2]{1,0} %p, f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
}
// CHECK: func @select(%[[PRED:.*]]: [[PRED_TYPE:.*]], %[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,9 @@
HloModule Sign
ENTRY %Sign (val: f32[2,2]) -> f32[2,2] {
%val = f32[2,2]{1,0} parameter(0)
ROOT %sign = f32[2,2]{1,0} sign(f32[2,2]{1,0} %val)
}
// CHECK: func @sign(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -0,0 +1,9 @@
HloModule Tanh
ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] {
%val = f32[2,2]{1,0} parameter(0)
ROOT %tanh = f32[2,2]{1,0} tanh(f32[2,2]{1,0} %val)
}
// CHECK: func @tanh(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
// CHECK: "xla_lhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK: }

View File

@ -30,24 +30,27 @@ namespace xla {
StatusOr<bool> RunFileCheck(const std::string& input,
                            absl::string_view pattern) {
  // Write `pattern` to a temporary file so FileCheck can read it, then
  // delegate to the pattern-file variant which invokes the binary.
  //
  // Fixed: the variable was declared twice (`string` and `std::string`),
  // which does not compile; the stale unqualified declaration and an unused
  // `using tensorflow::io::JoinPath;` were removed.
  std::string pattern_path;
  auto env = tensorflow::Env::Default();
  if (!env->LocalTempFilename(&pattern_path)) {
    return tensorflow::errors::Internal("couldn't get a pattern file name");
  }
  TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, pattern_path, pattern));
  return RunFileCheckWithPatternFile(input, pattern_path);
}
StatusOr<bool> RunFileCheckWithPatternFile(const std::string& input,
const std::string& pattern_file) {
// Invoke FileCheck to check whether input matches `pattern`.
string file_check_path = tensorflow::GetDataDependencyFilepath(
JoinPath("external", "llvm-project", "llvm", "FileCheck"));
std::string file_check_path = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("external", "llvm-project", "llvm", "FileCheck"));
tensorflow::SubProcess file_check_process;
file_check_process.SetProgram(
file_check_path,
{file_check_path, "-v", "-dump-input=fail", pattern_path});
{file_check_path, "-v", "-dump-input=fail", pattern_file});
file_check_process.SetChannelAction(tensorflow::CHAN_STDIN,
tensorflow::ACTION_PIPE);
file_check_process.SetChannelAction(tensorflow::CHAN_STDERR,
@ -56,7 +59,7 @@ StatusOr<bool> RunFileCheck(const std::string& input,
return tensorflow::errors::Internal("couldn't start FileCheck");
}
string standard_error;
std::string standard_error;
int exit_status = file_check_process.Communicate(
/*stdin_input=*/&input, /*stdout_output=*/nullptr,
/*stderr_output=*/&standard_error);
@ -64,6 +67,7 @@ StatusOr<bool> RunFileCheck(const std::string& input,
// FileCheck returns 0 when the inputs match. If matching failed, log
// the error message generated by FileCheck and the inputs.
bool succeeded = (exit_status == 0);
auto env = tensorflow::Env::Default();
if (!succeeded) {
LOG(WARNING) << "Tried to execute FileCheck at " << file_check_path;
if (!env->FileExists(file_check_path).ok()) {
@ -71,8 +75,6 @@ StatusOr<bool> RunFileCheck(const std::string& input,
}
LOG(WARNING) << "FileCheck error:\n" << standard_error;
LOG(WARNING) << "FileCheck pattern was:";
XLA_LOG_LINES(tensorflow::WARNING, pattern);
} else if (!standard_error.empty()) {
LOG(INFO) << "FileCheck stderr:";
XLA_LOG_LINES(tensorflow::INFO, standard_error);

View File

@ -26,7 +26,14 @@ namespace xla {
// Runs FileCheck with the given pattern over given input string. Provided that
// FileCheck can execute, returns true if and only if FileCheck succeeded in
// matching the input.
// Fixed: the stale duplicate declaration using unqualified `string` was
// removed; only the `std::string` form is kept.
StatusOr<bool> RunFileCheck(const std::string& input,
                            absl::string_view pattern);
// Runs FileCheck with the given pattern file over given input string. Provided
// that FileCheck can execute, returns true if and only if FileCheck succeeded
// in matching the input.
StatusOr<bool> RunFileCheckWithPatternFile(const std::string& input,
const std::string& pattern_file);
} // namespace xla