changes to address code review feedback

This commit is contained in:
Deven Desai 2020-01-16 02:48:17 +00:00
parent 11b85f7473
commit 3e4a3d5c83
5 changed files with 25 additions and 48 deletions

View File

@ -46,7 +46,7 @@ GpuCodegenTest::CreateNewVerifiedModuleWithFTZ(bool ftz) {
ShapeUtil::ByteSizeOfElements); ShapeUtil::ByteSizeOfElements);
} }
void GpuCodegenTest::CompileAndVerifyPtx( void GpuCodegenTest::CompileAndOptionallyVerifyPtx(
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern) { std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern) {
std::unique_ptr<Executable> executable = std::unique_ptr<Executable> executable =
std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
@ -55,11 +55,11 @@ void GpuCodegenTest::CompileAndVerifyPtx(
// On the ROCM platform the "ptx" string is not populated for the compiled // On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the "ptx_str" will be empty. So disabling the // executable, and hence the "ptx_str" will be empty. So disabling the
// pattern check on the ROCm platform // pattern check on the ROCm platform
#if !defined(TENSORFLOW_USE_ROCM) if (!is_built_with_rocm_) {
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern); StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
ASSERT_TRUE(filecheck_result.ok()); ASSERT_TRUE(filecheck_result.ok());
EXPECT_TRUE(filecheck_result.ValueOrDie()); EXPECT_TRUE(filecheck_result.ValueOrDie());
#endif }
} }
} // namespace gpu } // namespace gpu

View File

@ -39,8 +39,11 @@ class GpuCodegenTest : public LlvmIrGenTestBase {
// Compiles the given HLO module to PTX and verifies the PTX matches the given // Compiles the given HLO module to PTX and verifies the PTX matches the given
// FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html).
void CompileAndVerifyPtx(std::unique_ptr<VerifiedHloModule> hlo_module, // The "VerifyPtx" part only happens on the CUDA platform,
absl::string_view pattern); // and hence the "Optionally" in function name.
// For ROCm platform this routine will only do the "Compile" part.
void CompileAndOptionallyVerifyPtx(
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern);
bool is_built_with_rocm_; bool is_built_with_rocm_;
}; };

View File

@ -76,25 +76,15 @@ class GpuFtzDisabledTest : public GpuFtzTest {
}; };
// Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. // Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise.
//
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the call to CompileAdnVerifyPtx does not do the
// "VerifyPtx" part, it merely compiles the executable
//
TEST_F(GpuFtzEnabledTest, MultiplyFtz) { TEST_F(GpuFtzEnabledTest, MultiplyFtz) {
CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
CHECK-NOT: mul.rn.f32 CHECK-NOT: mul.rn.f32
CHECK: mul.rn.ftz.f32 CHECK: mul.rn.ftz.f32
CHECK-NOT: mul.rn.f32 CHECK-NOT: mul.rn.f32
)"); )");
} }
//
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the call to CompileAdnVerifyPtx does not do the
// "VerifyPtx" part, it merely compiles the executable
//
TEST_F(GpuFtzDisabledTest, MultiplyFtz) { TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
CHECK-NOT: mul.rn.ftz.f32 CHECK-NOT: mul.rn.ftz.f32
CHECK: mul.rn.f32 CHECK: mul.rn.f32
CHECK-NOT: mul.rn.ftz.f32 CHECK-NOT: mul.rn.ftz.f32
@ -106,13 +96,8 @@ TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
// calls to ex2.approx. When ftz is on, we get two calls to the ftz version; // calls to ex2.approx. When ftz is on, we get two calls to the ftz version;
// when ftz is off, we get one call to the ftz version and one call to the // when ftz is off, we get one call to the ftz version and one call to the
// regular version. // regular version.
//
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the call to CompileAdnVerifyPtx does not do the
// "VerifyPtx" part, it merely compiles the executable
//
TEST_F(GpuFtzEnabledTest, ExpFtz) { TEST_F(GpuFtzEnabledTest, ExpFtz) {
CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
CHECK-NOT: ex2.approx.f32 CHECK-NOT: ex2.approx.f32
CHECK: ex2.approx.ftz.f32 CHECK: ex2.approx.ftz.f32
CHECK-NOT: ex2.approx.f32 CHECK-NOT: ex2.approx.f32
@ -122,13 +107,8 @@ TEST_F(GpuFtzEnabledTest, ExpFtz) {
)"); )");
} }
//
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the call to CompileAdnVerifyPtx does not do the
// "VerifyPtx" part, it merely compiles the executable
//
TEST_F(GpuFtzDisabledTest, ExpFtz) { TEST_F(GpuFtzDisabledTest, ExpFtz) {
CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
CHECK-NOT: ex2.approx.f32 CHECK-NOT: ex2.approx.f32
CHECK-DAG: ex2.approx.ftz.f32 CHECK-DAG: ex2.approx.ftz.f32
CHECK-DAG: ex2.approx.f32 CHECK-DAG: ex2.approx.f32

View File

@ -108,21 +108,15 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) {
// In the IR generated for AMDGPUs, we do not seem to have the // In the IR generated for AMDGPUs, we do not seem to have the
// the addrspace(1) attribute for the lines being checked by the following // the addrspace(1) attribute for the lines being checked by the following
// patterns still need to investigate why that is the case, and whether or not // patterns.
// it is ok // need to investigate why that is the case, and whether or not it is ok
auto expected_ir = is_built_with_rocm_ ? R"( CompileAndVerifyIr(std::move(module),
R"(
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14 ; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
; CHECK: %[[bitcast:.*]] = bitcast i8* %[[alloc:.*]] to float* ; CHECK: %[[bitcast:.*]] = bitcast i8{{( addrspace\(1\))?}}* %[[alloc:.*]] to float{{( addrspace\(1\))?}}*
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64 ; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
; CHECK: getelementptr inbounds float, float* %[[bitcast]], i64 %[[idx1]] ; CHECK: getelementptr inbounds float, float{{( addrspace\(1\))?}}* %[[bitcast]], i64 %[[idx1]]
)" )",
: R"(
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
; CHECK: %[[bitcast:.*]] = bitcast i8 addrspace(1)* %[[alloc:.*]] to float addrspace(1)*
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
; CHECK: getelementptr inbounds float, float addrspace(1)* %[[bitcast]], i64 %[[idx1]]
)";
CompileAndVerifyIr(std::move(module), expected_ir,
/*match_optimized_ir=*/true); /*match_optimized_ir=*/true);
} }

View File

@ -56,7 +56,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) {
auto hlo_module = CreateNewVerifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(std::move(computation)); hlo_module->AddEntryComputation(std::move(computation));
CompileAndVerifyPtx(std::move(hlo_module), R"( CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
CHECK-NOT: ld.global.f32 CHECK-NOT: ld.global.f32
CHECK: ld.global.nc.f32 CHECK: ld.global.nc.f32
)"); )");
@ -86,7 +86,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) {
auto hlo_module = CreateNewVerifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(std::move(computation)); hlo_module->AddEntryComputation(std::move(computation));
CompileAndVerifyPtx(std::move(hlo_module), R"( CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
CHECK: { CHECK: {
CHECK-NOT: ld.global.f32 CHECK-NOT: ld.global.f32
CHECK: ld.global.nc.f32 CHECK: ld.global.nc.f32
@ -143,7 +143,7 @@ TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) {
std::unique_ptr<HloComputation> computation = builder.Build(); std::unique_ptr<HloComputation> computation = builder.Build();
hlo_module->AddEntryComputation(std::move(computation)); hlo_module->AddEntryComputation(std::move(computation));
CompileAndVerifyPtx(std::move(hlo_module), R"( CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
CHECK-LABEL: .entry sin CHECK-LABEL: .entry sin
CHECK: { CHECK: {
CHECK-NOT: ld.global.nc.f32 CHECK-NOT: ld.global.nc.f32