Fix a test on Maxwell GPUs.

This commit is contained in:
Frederic Bastien 2020-09-14 09:54:12 -07:00
parent 5dba0941db
commit 62a3d533af

View File

@@ -336,21 +336,33 @@ ENTRY %cluster {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
ParseAndReturnVerifiedModule(hlo_text));
const se::DeviceDescription& device_description =
backend().default_stream_executor()->GetDeviceDescription();
int cc_major = 0, cc_minor = 0;
device_description.cuda_compute_capability(&cc_major, &cc_minor);
string expected;
if (cc_major < 6) {
// We do not vectorize for GPU before Pascal.
expected = "CHECK-NOT: ld.global.nc.v2.f32";
} else {
expected = R"(
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
)";
}
CompileAndOptionallyVerifyPtx(std::move(optimized_module),
R"(
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
)");
expected);
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}