changes to address code review feedback
This commit is contained in:
parent
11b85f7473
commit
3e4a3d5c83
@ -46,7 +46,7 @@ GpuCodegenTest::CreateNewVerifiedModuleWithFTZ(bool ftz) {
|
||||
ShapeUtil::ByteSizeOfElements);
|
||||
}
|
||||
|
||||
void GpuCodegenTest::CompileAndVerifyPtx(
|
||||
void GpuCodegenTest::CompileAndOptionallyVerifyPtx(
|
||||
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern) {
|
||||
std::unique_ptr<Executable> executable =
|
||||
std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
|
||||
@ -55,11 +55,11 @@ void GpuCodegenTest::CompileAndVerifyPtx(
|
||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
||||
// executable, and hence the "ptx_str" will be empty. So disabling the
|
||||
// pattern check on the ROCm platform
|
||||
#if !defined(TENSORFLOW_USE_ROCM)
|
||||
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
|
||||
ASSERT_TRUE(filecheck_result.ok());
|
||||
EXPECT_TRUE(filecheck_result.ValueOrDie());
|
||||
#endif
|
||||
if (!is_built_with_rocm_) {
|
||||
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
|
||||
ASSERT_TRUE(filecheck_result.ok());
|
||||
EXPECT_TRUE(filecheck_result.ValueOrDie());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -39,8 +39,11 @@ class GpuCodegenTest : public LlvmIrGenTestBase {
|
||||
|
||||
// Compiles the given HLO module to PTX and verifies the PTX matches the given
|
||||
// FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html).
|
||||
void CompileAndVerifyPtx(std::unique_ptr<VerifiedHloModule> hlo_module,
|
||||
absl::string_view pattern);
|
||||
// The "VerifyPtx" part only happens on the CUDA platform,
|
||||
// and hence the "Optionally" in function name.
|
||||
// For ROCm platform this routine will only do the "Compile" part.
|
||||
void CompileAndOptionallyVerifyPtx(
|
||||
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern);
|
||||
|
||||
bool is_built_with_rocm_;
|
||||
};
|
||||
|
@ -76,25 +76,15 @@ class GpuFtzDisabledTest : public GpuFtzTest {
|
||||
};
|
||||
|
||||
// Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise.
|
||||
//
|
||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
||||
// executable, and hence the call to CompileAdnVerifyPtx does not do the
|
||||
// "VerifyPtx" part, it merely compiles the executable
|
||||
//
|
||||
TEST_F(GpuFtzEnabledTest, MultiplyFtz) {
|
||||
CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
|
||||
CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
|
||||
CHECK-NOT: mul.rn.f32
|
||||
CHECK: mul.rn.ftz.f32
|
||||
CHECK-NOT: mul.rn.f32
|
||||
)");
|
||||
}
|
||||
//
|
||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
||||
// executable, and hence the call to CompileAdnVerifyPtx does not do the
|
||||
// "VerifyPtx" part, it merely compiles the executable
|
||||
//
|
||||
TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
|
||||
CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
|
||||
CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
|
||||
CHECK-NOT: mul.rn.ftz.f32
|
||||
CHECK: mul.rn.f32
|
||||
CHECK-NOT: mul.rn.ftz.f32
|
||||
@ -106,13 +96,8 @@ TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
|
||||
// calls to ex2.approx. When ftz is on, we get two calls to the ftz version;
|
||||
// when ftz is off, we get one call to the ftz version and one call to the
|
||||
// regular version.
|
||||
//
|
||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
||||
// executable, and hence the call to CompileAdnVerifyPtx does not do the
|
||||
// "VerifyPtx" part, it merely compiles the executable
|
||||
//
|
||||
TEST_F(GpuFtzEnabledTest, ExpFtz) {
|
||||
CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
|
||||
CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
|
||||
CHECK-NOT: ex2.approx.f32
|
||||
CHECK: ex2.approx.ftz.f32
|
||||
CHECK-NOT: ex2.approx.f32
|
||||
@ -122,13 +107,8 @@ TEST_F(GpuFtzEnabledTest, ExpFtz) {
|
||||
)");
|
||||
}
|
||||
|
||||
//
|
||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
||||
// executable, and hence the call to CompileAdnVerifyPtx does not do the
|
||||
// "VerifyPtx" part, it merely compiles the executable
|
||||
//
|
||||
TEST_F(GpuFtzDisabledTest, ExpFtz) {
|
||||
CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
|
||||
CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
|
||||
CHECK-NOT: ex2.approx.f32
|
||||
CHECK-DAG: ex2.approx.ftz.f32
|
||||
CHECK-DAG: ex2.approx.f32
|
||||
|
@ -108,21 +108,15 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) {
|
||||
|
||||
// In the IR generated for AMDGPUs, we do not seem to have the
|
||||
// the addrspace(1) attribute for the lines being checked by the following
|
||||
// patterns still need to investigate why that is the case, and whether or not
|
||||
// it is ok
|
||||
auto expected_ir = is_built_with_rocm_ ? R"(
|
||||
// patterns.
|
||||
// need to investigate why that is the case, and whether or not it is ok
|
||||
CompileAndVerifyIr(std::move(module),
|
||||
R"(
|
||||
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
|
||||
; CHECK: %[[bitcast:.*]] = bitcast i8* %[[alloc:.*]] to float*
|
||||
; CHECK: %[[bitcast:.*]] = bitcast i8{{( addrspace\(1\))?}}* %[[alloc:.*]] to float{{( addrspace\(1\))?}}*
|
||||
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
|
||||
; CHECK: getelementptr inbounds float, float* %[[bitcast]], i64 %[[idx1]]
|
||||
)"
|
||||
: R"(
|
||||
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
|
||||
; CHECK: %[[bitcast:.*]] = bitcast i8 addrspace(1)* %[[alloc:.*]] to float addrspace(1)*
|
||||
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
|
||||
; CHECK: getelementptr inbounds float, float addrspace(1)* %[[bitcast]], i64 %[[idx1]]
|
||||
)";
|
||||
CompileAndVerifyIr(std::move(module), expected_ir,
|
||||
; CHECK: getelementptr inbounds float, float{{( addrspace\(1\))?}}* %[[bitcast]], i64 %[[idx1]]
|
||||
)",
|
||||
/*match_optimized_ir=*/true);
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) {
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
CompileAndVerifyPtx(std::move(hlo_module), R"(
|
||||
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
|
||||
CHECK-NOT: ld.global.f32
|
||||
CHECK: ld.global.nc.f32
|
||||
)");
|
||||
@ -86,7 +86,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) {
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
CompileAndVerifyPtx(std::move(hlo_module), R"(
|
||||
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
|
||||
CHECK: {
|
||||
CHECK-NOT: ld.global.f32
|
||||
CHECK: ld.global.nc.f32
|
||||
@ -143,7 +143,7 @@ TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) {
|
||||
std::unique_ptr<HloComputation> computation = builder.Build();
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
CompileAndVerifyPtx(std::move(hlo_module), R"(
|
||||
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
|
||||
CHECK-LABEL: .entry sin
|
||||
CHECK: {
|
||||
CHECK-NOT: ld.global.nc.f32
|
||||
|
Loading…
Reference in New Issue
Block a user