[XLA:GPU] Limit the number of shmem tiles XLA:GPU will use for 021 transposes.

There's a limit to how much shared memory we can use.

PiperOrigin-RevId: 205465441
parent f4f37efdc9
commit 5e876a8c25
@@ -3243,6 +3243,40 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
    return false;
  }

  // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
  // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb
  // shared memory per SM. (This is increased to 96kb in Volta, but we don't
  // use this, in part because it eats into our L1 cache space.)
  //
  // For correctness we need to ensure that we don't make more than 48kb worth
  // of shmem tiles per block. And for performance, we'd probably like to use
  // significantly less, so that we can fit more than one block at a time on a
  // GPU core.
  //
  // We say without benchmarks that we want at least 3 blocks/core,
  // corresponding to 3 shmem tiles per block if the elements are 32 bits
  // wide. We choose which params get the shmem transpose treatment
  // arbitrarily; it's not clear if there's a Right Choice.
  //
  // This is only sound if tiled transposes are the only place where we use
  // shared memory in fusions. If in the future other fusible ops use shared
  // memory, we'll have to adjust this heuristic.
  constexpr int kMinBlocksPerCore = 3;
  constexpr int64 kShmemPerCore = 48 * 1024;
  int64 shmem_used = 0;
  for (int64 i = 0; i < params_012.size(); ++i) {
    const HloInstruction* operand = hlo->operand(params_012[i]);
    shmem_used +=
        32 * 33 *
        ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());

    if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
      // Erase this element and everything after it from params_012.
      params_012.resize(i);
      break;
    }
  }

  VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
  thunk_sequence_->emplace_back(
      BuildKernelThunk(hlo, /*implements_whole_instruction=*/true));
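As a sanity check on the arithmetic in the comment above, here is a minimal standalone sketch, not part of the commit: it assumes only the numbers quoted in the diff (32x33-element tiles, 48 KiB of shared memory per SM, and kMinBlocksPerCore = 3) and works out how many tiles a single block can use while still letting three blocks share an SM. The main() driver and the element sizes iterated over are illustrative choices, not XLA APIs.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kMinBlocksPerCore = 3;          // from the diff above
  constexpr int64_t kShmemPerCore = 48 * 1024;  // 48 KiB per SM (pre-Volta)
  constexpr int64_t kTileElems = 32 * 33;       // extra column is the usual bank-conflict padding

  // Shared memory available to one block if kMinBlocksPerCore blocks must fit.
  const int64_t per_block_budget = kShmemPerCore / kMinBlocksPerCore;  // 16384 bytes

  const int elem_sizes[] = {2, 4, 8};  // e.g. f16, f32, f64
  for (int elem_bytes : elem_sizes) {
    const int64_t tile_bytes = kTileElems * elem_bytes;
    std::printf("elem=%dB  tile=%lldB  tiles/block=%lld\n", elem_bytes,
                static_cast<long long>(tile_bytes),
                static_cast<long long>(per_block_budget / tile_bytes));
  }
  // For 4-byte elements: 16384 / 4224 = 3, matching "3 shmem tiles per block
  // if the elements are 32 bits wide" above.
  return 0;
}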
@@ -799,6 +799,46 @@ ENTRY main {
      *result));
}

class FusionClientLibraryTest : public ClientLibraryTestBase {};

XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
  // On the GPU backend, it's possible to have too many transposes within one
  // fusion, causing the kernel to run out of shared memory and thus not
  // compile. We want to check that doesn't happen.
  //
  // To do this, we create a computation that computes
  //
  //   P0 + P0*P1*P1 + P0*P2*P2 ...
  //
  // where even parameters have layout 1 and odd parameters have layout 2.
  //
  // Our goal is to tempt the backend into creating one giant multi-output
  // fusion for the whole computation, including the transposes. Currently
  // multi-output fusion only fuses fusions, so each of the terms in the sum
  // needs to be a fusion itself, thus the contortions above.
  constexpr int kNumParams = 25;
  XlaBuilder b("ManyLayoutTransformations");

  // This test produces values that overflow int32, which is UB, so use
  // uint32, where overflow is OK.
  Array2D<uint32> arr(32, 32);
  arr.FillUnique();
  std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
      LayoutUtil::MakeLayout({0, 1}));

  std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
      LayoutUtil::MakeLayout({1, 0}));

  XlaOp p0 = AddParam(*l1, &b);
  XlaOp sum = p0;
  for (int i = 1; i < kNumParams; ++i) {
    auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
    sum = sum + p0 * pN * pN;
  }

  ComputeAndCompare(&b, {});
}

void BM_ParallelFusion(int num_iters) {
  // Simple element-wise computation to benchmark parallel task partitioning.
  tensorflow::testing::StopTiming();
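For intuition on why this test can blow the budget without the cap above, here is a rough standalone sketch, not part of the commit: it assumes, pessimistically, that every one of the test's 25 uint32 parameters would get its own 32x33 shared-memory tile (in practice only the parameters whose layout needs transposing would), and compares that demand with the 48 KiB per-SM limit.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kNumParams = 25;               // from the test above
  constexpr int64_t kTileBytes = 32 * 33 * 4;  // one tile of uint32 elements
  constexpr int64_t kShmemPerCore = 48 * 1024;

  const int64_t demand = kNumParams * kTileBytes;  // 105,600 bytes
  std::printf("unbounded tile demand: %lld B, per-SM limit: %lld B\n",
              static_cast<long long>(demand),
              static_cast<long long>(kShmemPerCore));
  // Even half the parameters (~12 tiles, ~50 KiB) would exceed the limit, so
  // without the cap in CheckAndEmitHloWithTile021 the fused kernel could fail
  // to compile; the test just checks that it compiles and gets the right
  // answer.
  return 0;
}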