changes to address code review feedback
This commit is contained in:
parent
11b85f7473
commit
3e4a3d5c83
@ -46,7 +46,7 @@ GpuCodegenTest::CreateNewVerifiedModuleWithFTZ(bool ftz) {
|
|||||||
ShapeUtil::ByteSizeOfElements);
|
ShapeUtil::ByteSizeOfElements);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GpuCodegenTest::CompileAndVerifyPtx(
|
void GpuCodegenTest::CompileAndOptionallyVerifyPtx(
|
||||||
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern) {
|
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern) {
|
||||||
std::unique_ptr<Executable> executable =
|
std::unique_ptr<Executable> executable =
|
||||||
std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
|
std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
|
||||||
@ -55,11 +55,11 @@ void GpuCodegenTest::CompileAndVerifyPtx(
|
|||||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
// On the ROCM platform the "ptx" string is not populated for the compiled
|
||||||
// executable, and hence the "ptx_str" will be empty. So disabling the
|
// executable, and hence the "ptx_str" will be empty. So disabling the
|
||||||
// pattern check on the ROCm platform
|
// pattern check on the ROCm platform
|
||||||
#if !defined(TENSORFLOW_USE_ROCM)
|
if (!is_built_with_rocm_) {
|
||||||
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
|
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
|
||||||
ASSERT_TRUE(filecheck_result.ok());
|
ASSERT_TRUE(filecheck_result.ok());
|
||||||
EXPECT_TRUE(filecheck_result.ValueOrDie());
|
EXPECT_TRUE(filecheck_result.ValueOrDie());
|
||||||
#endif
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -39,8 +39,11 @@ class GpuCodegenTest : public LlvmIrGenTestBase {
|
|||||||
|
|
||||||
// Compiles the given HLO module to PTX and verifies the PTX matches the given
|
// Compiles the given HLO module to PTX and verifies the PTX matches the given
|
||||||
// FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html).
|
// FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html).
|
||||||
void CompileAndVerifyPtx(std::unique_ptr<VerifiedHloModule> hlo_module,
|
// The "VerifyPtx" part only happens on the CUDA platform,
|
||||||
absl::string_view pattern);
|
// and hence the "Optionally" in function name.
|
||||||
|
// For ROCm platform this routine will only do the "Compile" part.
|
||||||
|
void CompileAndOptionallyVerifyPtx(
|
||||||
|
std::unique_ptr<VerifiedHloModule> hlo_module, absl::string_view pattern);
|
||||||
|
|
||||||
bool is_built_with_rocm_;
|
bool is_built_with_rocm_;
|
||||||
};
|
};
|
||||||
|
@ -76,25 +76,15 @@ class GpuFtzDisabledTest : public GpuFtzTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise.
|
// Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise.
|
||||||
//
|
|
||||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
|
||||||
// executable, and hence the call to CompileAdnVerifyPtx does not do the
|
|
||||||
// "VerifyPtx" part, it merely compiles the executable
|
|
||||||
//
|
|
||||||
TEST_F(GpuFtzEnabledTest, MultiplyFtz) {
|
TEST_F(GpuFtzEnabledTest, MultiplyFtz) {
|
||||||
CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
|
CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
|
||||||
CHECK-NOT: mul.rn.f32
|
CHECK-NOT: mul.rn.f32
|
||||||
CHECK: mul.rn.ftz.f32
|
CHECK: mul.rn.ftz.f32
|
||||||
CHECK-NOT: mul.rn.f32
|
CHECK-NOT: mul.rn.f32
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
//
|
|
||||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
|
||||||
// executable, and hence the call to CompileAdnVerifyPtx does not do the
|
|
||||||
// "VerifyPtx" part, it merely compiles the executable
|
|
||||||
//
|
|
||||||
TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
|
TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
|
||||||
CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
|
CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"(
|
||||||
CHECK-NOT: mul.rn.ftz.f32
|
CHECK-NOT: mul.rn.ftz.f32
|
||||||
CHECK: mul.rn.f32
|
CHECK: mul.rn.f32
|
||||||
CHECK-NOT: mul.rn.ftz.f32
|
CHECK-NOT: mul.rn.ftz.f32
|
||||||
@ -106,13 +96,8 @@ TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
|
|||||||
// calls to ex2.approx. When ftz is on, we get two calls to the ftz version;
|
// calls to ex2.approx. When ftz is on, we get two calls to the ftz version;
|
||||||
// when ftz is off, we get one call to the ftz version and one call to the
|
// when ftz is off, we get one call to the ftz version and one call to the
|
||||||
// regular version.
|
// regular version.
|
||||||
//
|
|
||||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
|
||||||
// executable, and hence the call to CompileAdnVerifyPtx does not do the
|
|
||||||
// "VerifyPtx" part, it merely compiles the executable
|
|
||||||
//
|
|
||||||
TEST_F(GpuFtzEnabledTest, ExpFtz) {
|
TEST_F(GpuFtzEnabledTest, ExpFtz) {
|
||||||
CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
|
CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
|
||||||
CHECK-NOT: ex2.approx.f32
|
CHECK-NOT: ex2.approx.f32
|
||||||
CHECK: ex2.approx.ftz.f32
|
CHECK: ex2.approx.ftz.f32
|
||||||
CHECK-NOT: ex2.approx.f32
|
CHECK-NOT: ex2.approx.f32
|
||||||
@ -122,13 +107,8 @@ TEST_F(GpuFtzEnabledTest, ExpFtz) {
|
|||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// On the ROCM platform the "ptx" string is not populated for the compiled
|
|
||||||
// executable, and hence the call to CompileAdnVerifyPtx does not do the
|
|
||||||
// "VerifyPtx" part, it merely compiles the executable
|
|
||||||
//
|
|
||||||
TEST_F(GpuFtzDisabledTest, ExpFtz) {
|
TEST_F(GpuFtzDisabledTest, ExpFtz) {
|
||||||
CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
|
CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"(
|
||||||
CHECK-NOT: ex2.approx.f32
|
CHECK-NOT: ex2.approx.f32
|
||||||
CHECK-DAG: ex2.approx.ftz.f32
|
CHECK-DAG: ex2.approx.ftz.f32
|
||||||
CHECK-DAG: ex2.approx.f32
|
CHECK-DAG: ex2.approx.f32
|
||||||
|
@ -108,21 +108,15 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) {
|
|||||||
|
|
||||||
// In the IR generated for AMDGPUs, we do not seem to have the
|
// In the IR generated for AMDGPUs, we do not seem to have the
|
||||||
// the addrspace(1) attribute for the lines being checked by the following
|
// the addrspace(1) attribute for the lines being checked by the following
|
||||||
// patterns still need to investigate why that is the case, and whether or not
|
// patterns.
|
||||||
// it is ok
|
// need to investigate why that is the case, and whether or not it is ok
|
||||||
auto expected_ir = is_built_with_rocm_ ? R"(
|
CompileAndVerifyIr(std::move(module),
|
||||||
|
R"(
|
||||||
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
|
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
|
||||||
; CHECK: %[[bitcast:.*]] = bitcast i8* %[[alloc:.*]] to float*
|
; CHECK: %[[bitcast:.*]] = bitcast i8{{( addrspace\(1\))?}}* %[[alloc:.*]] to float{{( addrspace\(1\))?}}*
|
||||||
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
|
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
|
||||||
; CHECK: getelementptr inbounds float, float* %[[bitcast]], i64 %[[idx1]]
|
; CHECK: getelementptr inbounds float, float{{( addrspace\(1\))?}}* %[[bitcast]], i64 %[[idx1]]
|
||||||
)"
|
)",
|
||||||
: R"(
|
|
||||||
; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14
|
|
||||||
; CHECK: %[[bitcast:.*]] = bitcast i8 addrspace(1)* %[[alloc:.*]] to float addrspace(1)*
|
|
||||||
; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64
|
|
||||||
; CHECK: getelementptr inbounds float, float addrspace(1)* %[[bitcast]], i64 %[[idx1]]
|
|
||||||
)";
|
|
||||||
CompileAndVerifyIr(std::move(module), expected_ir,
|
|
||||||
/*match_optimized_ir=*/true);
|
/*match_optimized_ir=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) {
|
|||||||
auto hlo_module = CreateNewVerifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(std::move(computation));
|
hlo_module->AddEntryComputation(std::move(computation));
|
||||||
|
|
||||||
CompileAndVerifyPtx(std::move(hlo_module), R"(
|
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
|
||||||
CHECK-NOT: ld.global.f32
|
CHECK-NOT: ld.global.f32
|
||||||
CHECK: ld.global.nc.f32
|
CHECK: ld.global.nc.f32
|
||||||
)");
|
)");
|
||||||
@ -86,7 +86,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) {
|
|||||||
auto hlo_module = CreateNewVerifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(std::move(computation));
|
hlo_module->AddEntryComputation(std::move(computation));
|
||||||
|
|
||||||
CompileAndVerifyPtx(std::move(hlo_module), R"(
|
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
|
||||||
CHECK: {
|
CHECK: {
|
||||||
CHECK-NOT: ld.global.f32
|
CHECK-NOT: ld.global.f32
|
||||||
CHECK: ld.global.nc.f32
|
CHECK: ld.global.nc.f32
|
||||||
@ -143,7 +143,7 @@ TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) {
|
|||||||
std::unique_ptr<HloComputation> computation = builder.Build();
|
std::unique_ptr<HloComputation> computation = builder.Build();
|
||||||
hlo_module->AddEntryComputation(std::move(computation));
|
hlo_module->AddEntryComputation(std::move(computation));
|
||||||
|
|
||||||
CompileAndVerifyPtx(std::move(hlo_module), R"(
|
CompileAndOptionallyVerifyPtx(std::move(hlo_module), R"(
|
||||||
CHECK-LABEL: .entry sin
|
CHECK-LABEL: .entry sin
|
||||||
CHECK: {
|
CHECK: {
|
||||||
CHECK-NOT: ld.global.nc.f32
|
CHECK-NOT: ld.global.nc.f32
|
||||||
|
Loading…
Reference in New Issue
Block a user