From 5e876a8c25819070d78aa96595943afa207a6671 Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Fri, 20 Jul 2018 15:41:36 -0700
Subject: [PATCH] [XLA:GPU] Limit the number of shmem tiles XLA:GPU will use
 for 021 transposes.

There's a limit to how much shared memory we can use.

PiperOrigin-RevId: 205465441
---
 .../xla/service/gpu/ir_emitter_unnested.cc   | 34 ++++++++++++++++
 tensorflow/compiler/xla/tests/fusion_test.cc | 40 +++++++++++++++++++
 2 files changed, 74 insertions(+)

diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 7100c9a08ad..b3229303df6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -3243,6 +3243,40 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
     return false;
   }
 
+  // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
+  // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb
+  // shared memory per SM. (This is increased to 96kb in Volta, but we don't
+  // use this, in part because it eats into our L1 cache space.)
+  //
+  // For correctness we need to ensure that we don't make more than 48kb worth
+  // of shmem tiles per block. And for performance, we'd probably like to use
+  // significantly less, so that we can fit more than one block at a time on a
+  // gpu core.
+  //
+  // We say without benchmarks that we want at least 3 threads/block,
+  // corresponding to 3 shmem tiles if the elements are 32 bits wide. We choose
+  // which params get the shmem transpose treatment arbitrarily; it's not clear
+  // if there's a Right Choice.
+  //
+  // This is only sound if tiled transposes are the only place where we use
+  // shared memory in fusions. If in the future other fusile ops use shared
+  // memory, we'll have to adjust this heuristic.
+  constexpr int kMinBlocksPerCore = 3;
+  constexpr int64 kShmemPerCore = 48 * 1024;
+  int64 shmem_used = 0;
+  for (int64 i = 0; i < params_012.size(); ++i) {
+    const HloInstruction* operand = hlo->operand(params_012[i]);
+    shmem_used +=
+        32 * 33 *
+        ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
+
+    if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
+      // Erase this element and everything after it from params_012.
+      params_012.resize(i);
+      break;
+    }
+  }
+
   VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
   thunk_sequence_->emplace_back(
       BuildKernelThunk(hlo, /*implements_whole_instruction=*/true));
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index dc644779357..607bcdd51ee 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -799,6 +799,46 @@ ENTRY main {
                                      *result));
 }
 
+class FusionClientLibraryTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
+  // On the GPU backend, it's possible to have too many transposes within one
+  // fusion, causing the kernel to run out of shared memory and thus not
+  // compile. We want to check that doesn't happen.
+  //
+  // To do this, we create a computation that computes
+  //
+  //   P0 + P0*P1*P1 + P0*P2*P2 ...
+  //
+  // where even parameters have layout 1 and odd parameters have layout 2.
+  //
+  // Our goal is to tempt the backend into creating one giant multi-output
+  // fusion for the whole computation, including the transposes. Currently
+  // multi-output fusion only fuses fusions, so each of the terms in the sum
+  // needs to be a fusion itself, thus the contortions above.
+  constexpr int kNumParams = 25;
+  XlaBuilder b("ManyLayoutTransformations");
+
+  // This test produces values that overflow int32, which is UB, so use uint32,
+  // where overflow is OK.
+  Array2D<uint32> arr(32, 32);
+  arr.FillUnique();
+  std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+      LayoutUtil::MakeLayout({0, 1}));
+
+  std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+      LayoutUtil::MakeLayout({1, 0}));
+
+  XlaOp p0 = AddParam(*l1, &b);
+  XlaOp sum = p0;
+  for (int i = 1; i < kNumParams; ++i) {
+    auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
+    sum = sum + p0 * pN * pN;
+  }
+
+  ComputeAndCompare(&b, {});
+}
+
 void BM_ParallelFusion(int num_iters) {
   // Simple element-wise computation to benchmark parallel task partitioning.
   tensorflow::testing::StopTiming();