changes to address code review feedback

This commit is contained in:
Deven Desai 2020-01-16 02:48:17 +00:00
parent 11b85f7473
commit 3e4a3d5c83
5 changed files with 25 additions and 48 deletions

View File

@ -46,7 +46,7 @@ GpuCodegenTest::CreateNewVerifiedModuleWithFTZ(bool ftz) {
ShapeUtil::ByteSizeOfElements);
}
void GpuCodegenTest::CompileAndVerifyPtx(
void GpuCodegenTest::CompileAndOptionallyVerifyPtx(
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern) {
std::unique_ptr<Executable> executable =
std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
@ -55,11 +55,11 @@ void GpuCodegenTest::CompileAndVerifyPtx(
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the "ptx_str" will be empty. So disabling the
// pattern check on the ROCm platform
#if !defined(TENSORFLOW_USE_ROCM)
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
ASSERT_TRUE(filecheck_result.ok());
EXPECT_TRUE(filecheck_result.ValueOrDie());
#endif
if (!is_built_with_rocm_) {
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
ASSERT_TRUE(filecheck_result.ok());
EXPECT_TRUE(filecheck_result.ValueOrDie());
}
}
} // namespace gpu

View File

@ -39,8 +39,11 @@ class GpuCodegenTest : public LlvmIrGenTestBase {
// Compiles the given HLO module to PTX and verifies the PTX matches the given
// FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html).
void CompileAndVerifyPtx(std::unique_ptr<VerifiedHloModule> hlo_module,
absl::string_view pattern);
// The "VerifyPtx" part only happens on the CUDA platform,
// and hence the "Optionally" in function name.
// For ROCm platform this routine will only do the "Compile" part.
void CompileAndOptionallyVerifyPtx(
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern);
bool is_built_with_rocm_;
};

View File

@ -76,25 +76,15 @@ class GpuFtzDisabledTest : public GpuFtzTest {
};
// Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise.
//
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the call to CompileAdnVerifyPtx does not do the
// "VerifyPtx" part, it merely compiles the executable
//
TEST_F(GpuFtzEnabledTest, MultiplyFtz) {
CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
CHECK-NOT: mul.rn.f32
CHECK: mul.rn.ftz.f32
CHECK-NOT: mul.rn.f32
)");
}
//
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the call to CompileAdnVerifyPtx does not do the
// "VerifyPtx" part, it merely compiles the executable
//
TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
CHECK-NOT: mul.rn.ftz.f32
CHECK: mul.rn.f32
CHECK-NOT: mul.rn.ftz.f32
@ -106,13 +96,8 @@ TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
// calls to ex2.approx. When ftz is on, we get two calls to the ftz version;
// when ftz is off, we get one call to the ftz version and one call to the
// regular version.
//
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the call to CompileAdnVerifyPtx does not do the
// "VerifyPtx" part, it merely compiles the executable
//
TEST_F(GpuFtzEnabledTest, ExpFtz) {
CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
CHECK-NOT: ex2.approx.f32
CHECK: ex2.approx.ftz.f32
CHECK-NOT: ex2.approx.f32
@ -122,13 +107,8 @@ TEST_F(GpuFtzEnabledTest, ExpFtz) {
)");
}
//
// On the ROCM platform the "ptx" string is not populated for the compiled
// executable, and hence the call to CompileAdnVerifyPtx does not do the
// "VerifyPtx" part, it merely compiles the executable
//
TEST_F(GpuFtzDisabledTest, ExpFtz) {
CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
CHECK-NOT: ex2.approx.f32
CHECK-DAG: ex2.approx.ftz.f32
CHECK-DAG: ex2.approx.f32

View File

@ -108,21 +108,15 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) {
// In the IR generated for AMDGPUs, we do not seem to have the
// the addrspace(1) attribute for the lines being checked by the following
// patterns still need to investigate why that is the case, and whether or not
// it is ok
auto expected_ir = is_built_with_rocm_ ? R"(
// patterns.
// need to investigate why that is the case, and whether or not it is ok
CompileAndVerifyIr(std::move(module),
R"(
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
; CHECK: %[[bitcast:.*]] = bitcast i8* %[[alloc:.*]] to float*
; CHECK: %[[bitcast:.*]] = bitcast i8{{( addrspace\(1\))?}}* %[[alloc:.*]] to float{{( addrspace\(1\))?}}*
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
; CHECK: getelementptr inbounds float, float* %[[bitcast]], i64 %[[idx1]]
)"
: R"(
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
; CHECK: %[[bitcast:.*]] = bitcast i8 addrspace(1)* %[[alloc:.*]] to float addrspace(1)*
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
; CHECK: getelementptr inbounds float, float addrspace(1)* %[[bitcast]], i64 %[[idx1]]
)";
CompileAndVerifyIr(std::move(module), expected_ir,
; CHECK: getelementptr inbounds float, float{{( addrspace\(1\))?}}* %[[bitcast]], i64 %[[idx1]]
)",
/*match_optimized_ir=*/true);
}

View File

@ -56,7 +56,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) {
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(std::move(computation));
CompileAndVerifyPtx(std::move(hlo_module), R"(
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
CHECK-NOT: ld.global.f32
CHECK: ld.global.nc.f32
)");
@ -86,7 +86,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) {
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(std::move(computation));
CompileAndVerifyPtx(std::move(hlo_module), R"(
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
CHECK: {
CHECK-NOT: ld.global.f32
CHECK: ld.global.nc.f32
@ -143,7 +143,7 @@ TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) {
std::unique_ptr<HloComputation> computation = builder.Build();
hlo_module->AddEntryComputation(std::move(computation));
CompileAndVerifyPtx(std::move(hlo_module), R"(
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
CHECK-LABEL: .entry sin
CHECK: {
CHECK-NOT: ld.global.nc.f32