STT-tensorflow/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
Marcello Maggioni f7c7fbd40b [XLA] Extend hlo_rematerialization pass to support rematerialization of tuple producing instrs.
Allow rematerialization of tuple producing instructions by extending the process we use to
rematerialize bitcasts to also handle get-tuple-element'ed buffers that are not nested.
This also allows rematerializing through tuples.

PiperOrigin-RevId: 352691189
Change-Id: Ia1a7674c7e32f1c53253cd5b674abce99f87d509
2021-01-19 17:45:46 -08:00

1044 lines
43 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
#include <memory>
#include <string>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_rematerialization_test_utils.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
// Test fixture for HloRematerialization. The helpers that build
// rematerializable computations come from RematerializationTestBase.
class HloRematerializationTest : public RematerializationTestBase {
 protected:
  // Verifies `module`, schedules it with the default memory scheduler, and
  // then runs the pre-fusion HloRematerialization pass in
  // recompute-and-compress mode. Returns whether the pass changed the module.
  StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
                                         HloModule* module,
                                         int64 min_remat_size = 0) {
    TF_EXPECT_OK(verifier().Run(module).status());

    auto buffer_size_fn = [](const BufferValue& buffer) {
      return ByteSizeOf(buffer.shape());
    };
    HloMemoryScheduler memory_scheduler(
        buffer_size_fn,
        ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler));
    TF_EXPECT_OK(memory_scheduler.Run(module).status());

    HloRematerialization rematerialization(
        ByteSizeOf, memory_limit_bytes,
        /*sizes=*/nullptr,
        HloRematerialization::RematerializationPass::kPreFusion,
        /*block_size_limit=*/1, /*compact_shape_function=*/nullptr,
        HloRematerialization::RematerializationMode::kRecomputeAndCompress,
        min_remat_size);
    return rematerialization.Run(module);
  }
};
// Rematerializes a single computation built by MakeRematerializableComputation
// and checks where the recomputed broadcast lands in the schedule.
TEST_F(HloRematerializationTest, SingleComputation) {
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeRematerializableComputation());

  // Walk down from the root slice to the concat and to the original broadcast,
  // which is the instruction expected to be rematerialized.
  const HloInstruction* root_slice = computation->root_instruction();
  ASSERT_THAT(root_slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
  const HloInstruction* concat = root_slice->operand(0);
  const HloInstruction* original_bcast = concat->operand(0);

  // The computation needs 16KB without rematerialization and only 12KB with
  // it, so a 14KB limit forces rematerialization.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/14 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // The root is untouched, but the concat now reads a rematerialized broadcast.
  EXPECT_EQ(computation->root_instruction(), root_slice);
  const HloInstruction* remat_bcast = concat->operand(0);
  EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(original_bcast)));

  // The concat is second-to-last in the schedule, and the rematerialized
  // broadcast is scheduled immediately before it.
  const auto& schedule_order =
      module->schedule().sequence(computation).instructions();
  const int64 num_instructions = computation->instruction_count();
  EXPECT_EQ(schedule_order[num_instructions - 2], concat);
  EXPECT_EQ(schedule_order[num_instructions - 3], remat_bcast);
}
// Test that nothing is rematerialized when no instruction in the computation
// is large enough to be worth rematerializing under the requested minimum
// rematerialization size.
TEST_F(HloRematerializationTest, SingleComputationNoWorthRemat) {
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeRematerializableComputation());

  // Sanity-check the graph before running the pass.
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, op::Slice(op::Concatenate(op::Broadcast(_), _)));

  // With a 14KiB minimum remat size, no instruction qualifies, so the module
  // must be left unchanged.
  const int64 min_remat_bytes = 14 * 1024;
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      RunHloRematerialization(/*memory_limit_bytes=*/14 * 1024, module.get(),
                              /*min_remat_size=*/min_remat_bytes));
  EXPECT_FALSE(changed);
}
// Runs rematerialization on the computation built by
// MakeRematerializableComputation with a memory limit (20KB) generous enough
// that no instruction needs to be rematerialized.
TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeRematerializableComputation());

  const int64 expected_instruction_count = 8;
  EXPECT_EQ(computation->instruction_count(), expected_instruction_count);

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/20 * 1024, module.get()));

  // The pass must report no change and must not add any instruction.
  EXPECT_FALSE(changed);
  EXPECT_EQ(computation->instruction_count(), expected_instruction_count);
}
// Rematerialization of an entry computation that runs another computation in a
// while loop. Both the entry and the while body could shed memory through
// rematerialization, but the limit only requires one of them to do so. The
// entry computation should be the one chosen, since rematerializing inside the
// loop body would presumably be more expensive.
TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
  auto module = CreateNewVerifiedModule();

  // Loop condition: takes a vec1 parameter and yields a constant `true`.
  HloComputation* while_cond = [&]() -> HloComputation* {
    auto cond_builder = HloComputation::Builder(TestName() + ".cond");
    cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    cond_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
    return module->AddEmbeddedComputation(cond_builder.Build());
  }();

  HloComputation* body_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".body"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/body_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(body_computation->instruction_count(), 8);

  // Peak module memory is 18KB: 16KB inside the body plus 2KB live in the
  // entry computation at the while. A 17KB limit is just low enough to force
  // rematerialization in the entry computation.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/17 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Exactly one instruction is added, and only to the entry computation.
  EXPECT_EQ(entry_computation->instruction_count(), 8);
  EXPECT_EQ(body_computation->instruction_count(), 8);
}
// Rematerialization of an entry computation that runs another computation in a
// while loop, with a limit tight enough that both the entry computation and
// the while body must rematerialize instructions.
TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
  auto module = CreateNewVerifiedModule();

  // Loop condition: takes a vec1 parameter and yields a constant `true`.
  HloComputation* while_cond = [&]() -> HloComputation* {
    auto cond_builder = HloComputation::Builder(TestName() + ".cond");
    cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    cond_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
    return module->AddEmbeddedComputation(cond_builder.Build());
  }();

  HloComputation* body_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".body"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/body_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(body_computation->instruction_count(), 8);

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/15 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Each computation gains rematerialized instructions.
  EXPECT_EQ(entry_computation->instruction_count(), 9);
  EXPECT_EQ(body_computation->instruction_count(), 9);
}
// Rematerialization of a doubly nested computation (entry -> while -> while).
// Every level of the nest is expected to rematerialize instructions.
TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
  auto module = CreateNewVerifiedModule();

  // Loop condition: takes a vec1 parameter and yields a constant `true`. A
  // clone is used for the outer loop so that each while has its own condition.
  HloComputation* while_cond = [&]() -> HloComputation* {
    auto cond_builder = HloComputation::Builder(TestName() + ".cond");
    cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
    cond_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
    return module->AddEmbeddedComputation(cond_builder.Build());
  }();
  HloComputation* while_cond_copy =
      module->AddEmbeddedComputation(while_cond->Clone());

  // Build from the innermost computation outwards.
  HloComputation* inner_computation = module->AddEmbeddedComputation(
      MakeRematerializableComputation(/*suffix=*/".inner"));
  HloComputation* middle_computation =
      module->AddEmbeddedComputation(MakeRematerializableWhileComputation(
          while_cond, /*while_body=*/inner_computation,
          /*suffix=*/".middle"));
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeRematerializableWhileComputation(
          while_cond_copy, /*while_body=*/middle_computation));

  EXPECT_EQ(entry_computation->instruction_count(), 7);
  EXPECT_EQ(middle_computation->instruction_count(), 7);
  EXPECT_EQ(inner_computation->instruction_count(), 8);

  // Peak usage with maximal rematerialization everywhere is about 12K, so use
  // a slightly larger limit.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/13 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Every level gains rematerialized instructions.
  EXPECT_EQ(entry_computation->instruction_count(), 9);
  EXPECT_EQ(middle_computation->instruction_count(), 9);
  EXPECT_EQ(inner_computation->instruction_count(), 9);
}
TEST_F(HloRematerializationTest, RngNotRematerialized) {
  // Test that a single rng is not rematerialized:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] rng = rng(param)
  //   F32[1024] tanh = tanh(rng)
  //   F32[1024] exp = exp(rng)
  //   F32[1024] add_0 = add(rng, tanh)              // LIVE: add_0 + rng +
  //                                                 //       tanh + exp
  //
  //   F32[1024] add_1 = add(rng, add(exp, add_0))   // LIVE: add_1 + add_0 +
  //                                                 //       rng + tanh + exp
  //
  //   F32[1024] add_2 = add(rng, add(tanh, add_1))  // LIVE: add_2 + add_1 +
  //                                                 //       rng + tanh + exp
  //
  // The pass is still expected to rematerialize other instructions, but the
  // rng itself must not be duplicated (presumably because recomputing it would
  // not reproduce the same values).
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto rng = builder.AddInstruction(HloInstruction::CreateRng(
      vec1024_shape_, RandomDistribution::RNG_UNIFORM, {param, param}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng));
  auto add_0 = builder.AddInstruction(
      HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, rng,
      builder.AddInstruction(HloInstruction::CreateBinary(
          vec1024_shape_, HloOpcode::kAdd, exp, add_0))));
  builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, rng,
      builder.AddInstruction(HloInstruction::CreateBinary(
          vec1024_shape_, HloOpcode::kAdd, tanh, add_1))));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  // Counts the kRng instructions in `computation`.
  auto count_rngs = [](const HloComputation* computation) {
    int64 rng_count = 0;
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kRng) {
        ++rng_count;
      }
    }
    return rng_count;
  };
  // Before rematerialization there should be a single rng in the graph.
  ASSERT_EQ(count_rngs(entry_computation), 1);
  const int64 original_instruction_count =
      entry_computation->instruction_count();
  // The limit is four f32[1024] buffers (16KB). This is below the five
  // simultaneously live f32[1024] buffers shown in the liveness annotations
  // above, so the pass must rematerialize something.
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      RunHloRematerialization(
          /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get()));
  EXPECT_TRUE(changed);
  // The rng should not have been rematerialized.
  EXPECT_EQ(count_rngs(entry_computation), 1);
  // There should have been rematerialization of other instructions.
  EXPECT_GT(entry_computation->instruction_count(), original_instruction_count);
}
TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
  // Test that a single instruction is rematerialized several times. Module:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] %bcast = broadcast(%param)
  //   F32[1024] %add_1 = add(%bcast, %bcast)
  //   F32[1024] %call_1 = call(Subcomputation, {%add_1})
  //   F32[1024] %add_2 = add(%bcast, %call_1)
  //   F32[1024] %call_2 = call(SubComputation, {%add_2})
  //   F32[1024] %add_3 = add(%bcast, %call_2)
  //   F32[1024] %call_3 = call(Subcomputation, {%add_3})
  //   F32[1024] %add_4 = add(%bcast, %call_3)
  //
  // Subcomputation:
  //   F32[1024] %param = {...}
  //   F32[2048] %concat = concat({%param, %param})
  //   F32[1024] %slice = slice(%concat)
  //
  // The value %bcast is live across each call of Subcomputation (which
  // requires 8KB) even though the value is not used in the calls.
  // Rematerializing %bcast after each call removes one 4KB buffer from the
  // peak (roughly 24KB down to 20KB).
  auto module = CreateNewVerifiedModule();
  HloComputation* subcomputation = nullptr;
  {
    auto builder = HloComputation::Builder(TestName() + ".subcomputation");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
        /*dimension=*/0));
    builder.AddInstruction(HloInstruction::CreateSlice(
        vec1024_shape_, concat, /*start_indices=*/{0},
        /*limit_indices=*/{1024}, /*strides=*/{1}));
    subcomputation = module->AddEmbeddedComputation(builder.Build());
  }
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
  auto call_1 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
  auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
  auto call_2 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation));
  auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_2));
  auto call_3 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation));
  auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_3));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  // Counts the kBroadcast instructions in `computation`.
  auto count_broadcasts = [](const HloComputation* computation) {
    int64 bcast_count = 0;
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kBroadcast) {
        bcast_count++;
      }
    }
    return bcast_count;
  };

  // Before rematerialization there should be a single broadcast instruction in
  // the graph.
  EXPECT_EQ(count_broadcasts(entry_computation), 1);
  EXPECT_EQ(entry_computation->instruction_count(), 9);
  EXPECT_EQ(add_2->operand(0), bcast);
  EXPECT_EQ(add_3->operand(0), bcast);
  EXPECT_EQ(add_4->operand(0), bcast);

  // Pick a memory limit somewhere between 24KB (initial peak memory including
  // parameter and output) and 20KB (peak memory possible with
  // rematerialization).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/22 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // The broadcast should have been rematerialized 3 times.
  EXPECT_EQ(count_broadcasts(entry_computation), 4);
  EXPECT_EQ(entry_computation->instruction_count(), 12);

  // The operands of add_2, add_3, and add_4 should all be rematerialized
  // broadcasts.
  EXPECT_NE(add_2->operand(0), bcast);
  EXPECT_THAT(add_2->operand(0), op::Broadcast(param));
  EXPECT_NE(add_3->operand(0), bcast);
  EXPECT_THAT(add_3->operand(0), op::Broadcast(param));
  EXPECT_NE(add_4->operand(0), bcast);
  EXPECT_THAT(add_4->operand(0), op::Broadcast(param));
}
// Copies must not be rematerialized. Even under a very low (1KB) memory limit,
// where the pass does change the module, exactly one kCopy must remain.
TEST_F(HloRematerializationTest, CopyNotRematerialized) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  // Appends `negate(operand)` to the builder and returns it.
  auto make_negate = [&](HloInstruction* operand) -> HloInstruction* {
    return builder.AddInstruction(
        HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate,
                                    operand));
  };

  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kCopy, param));
  // Two independent two-deep negate chains, both rooted at the copy.
  HloInstruction* chain_a = make_negate(make_negate(copy));
  HloInstruction* chain_b = make_negate(make_negate(copy));
  builder.AddInstruction(HloInstruction::CreateTuple({chain_a, chain_b}));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/1 * 1024, module.get()));

  int64 num_copies = 0;
  for (const HloInstruction* instruction : entry_computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kCopy) {
      ++num_copies;
    }
  }
  EXPECT_TRUE(changed);
  EXPECT_EQ(num_copies, 1);
}
// Parameterized on whether the broadcast under test also has an indirect use
// (through a tuple and get-tuple-element) by the final negate.
class IndirectUseTest : public HloRematerializationTest,
                        public ::testing::WithParamInterface<bool> {};

TEST_P(IndirectUseTest, IndirectUseRematerialized) {
  // Test that a rematerializable instruction is rematerialized whether or not
  // it has an indirect use.
  // Module:
  //
  // Entry computation:
  //   F32[] %param = {...}
  //   F32[1024] %bcast = broadcast(%param)
  //   F32[1024] %add_1 = add(%bcast, %bcast)
  //   F32[1024] %call = call(Subcomputation, {%add_1})
  //   F32[1024] %add_2 = add(%bcast, %call)
  //   {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2)
  //   F32[1024] %gte = GetTupleElement(%tuple, 0)
  //   F32[1024] %negate = negate(%gte)
  //
  // Subcomputation:
  //   F32[1024] %param = {...}
  //   F32[2048] %concat = concat({%param, %param})
  //   F32[1024] %slice = slice(%concat)
  //
  // The value %bcast is live across the call, and rematerializing %bcast
  // across that point would reduce peak memory use by 4KB.
  //
  // The parameter selects the GetTupleElement index. If the index is 0, the
  // %negate operand aliases %bcast (i.e. %bcast is used indirectly by
  // %negate); otherwise the %negate operand aliases %add_2.
  const bool indirectly_used = GetParam();
  auto module = CreateNewVerifiedModule();
  HloComputation* subcomputation = nullptr;
  {
    auto builder = HloComputation::Builder(TestName() + ".subcomputation");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
        ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
        /*dimension=*/0));
    builder.AddInstruction(HloInstruction::CreateSlice(
        vec1024_shape_, concat, /*start_indices=*/{0},
        /*limit_indices=*/{1024}, /*strides=*/{1}));
    subcomputation = module->AddEmbeddedComputation(builder.Build());
  }
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
  auto call_1 = builder.AddInstruction(
      HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
  auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
      vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2}));
  auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      vec1024_shape_, tuple, indirectly_used ? 0 : 1));
  builder.AddInstruction(
      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry_computation->instruction_count(), 8);

  // Pick a memory limit somewhere between 24KB (initial peak memory
  // including parameter and output) and 20KB (peak memory possible with
  // rematerialization).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/22 * 1024, module.get()));

  // Rematerialization is expected in both variants: the pass can now
  // rematerialize values reached through non-nested get-tuple-elements, so an
  // indirect use no longer blocks it.
  EXPECT_TRUE(changed);
  // NOTE(review): the indirect-use variant ends with far fewer instructions
  // (3 vs. 9). Presumably the gte(tuple(...)) is resolved to the broadcast
  // and the instructions that become dead are removed. Confirm against the
  // pass.
  EXPECT_EQ(entry_computation->instruction_count(), indirectly_used ? 3 : 9);
}

INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest,
                         ::testing::Values(true, false));
// Test fixture that runs HloRematerialization in compress-only mode, using a
// padded size function and a layout-swapping compaction function so that
// compression has something to gain.
class CompressingRematerializationTest : public RematerializationTestBase {
 protected:
  // A special shape size function, which pads the most minor dimension to 64.
  static int64 ShapeSizePadMinorTo64(const Shape& shape) {
    if (shape.IsTuple()) {
      // Size of a tuple is 4 bytes.
      return 4;
    }
    Shape descending_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape);
    int64 size =
        ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type());
    for (int64 i = 0; i < descending_shape.rank(); ++i) {
      int64 dim = descending_shape.dimensions(i);
      if (i == descending_shape.rank() - 1) {
        dim = RoundUpToNearest<int64>(dim, 64);
      }
      size *= dim;
    }
    return size;
  }

  // Swaps the layout of the two most-minor dimensions if the second-minor
  // dimension is bigger than the most-minor dimension.
  //
  // Shapes that are not arrays of rank >= 2 have no second-minor dimension and
  // are returned unchanged. Without this guard, minor_to_major()[1] would be
  // read out of bounds.
  static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
    if (!shape.IsArray() || shape.rank() < 2) {
      return shape;
    }
    Shape result = shape;
    Layout layout = result.layout();
    int64 most_minor_index = layout.minor_to_major()[0];
    int64 second_minor_index = layout.minor_to_major()[1];
    int64 most_minor = result.dimensions(most_minor_index);
    int64 second_minor = result.dimensions(second_minor_index);
    if (most_minor < second_minor) {
      Layout new_layout = layout;
      new_layout.set_minor_to_major(0, second_minor_index);
      new_layout.set_minor_to_major(1, most_minor_index);
      *result.mutable_layout() = new_layout;
    }
    return result;
  }

  // Verifies `module` and runs the pre-fusion HloRematerialization pass in
  // compress-only mode. The module is expected to be scheduled already. This
  // helper does not run a scheduler.
  StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
                                         HloModule* module,
                                         int64 min_remat_size = 0) {
    TF_EXPECT_OK(verifier().Run(module).status());
    HloRematerialization remat(
        ShapeSizePadMinorTo64, memory_limit_bytes,
        /*sizes=*/nullptr,
        HloRematerialization::RematerializationPass::kPreFusion,
        /*block_size_limit=*/1, ChooseCompactLayoutForShape,
        HloRematerialization::RematerializationMode::kCompressOnly,
        min_remat_size);
    return remat.Run(module);
  }
};
// Test that rematerialization only considers buffers whose size reaches the
// requested minimum rematerialization size.
TEST_F(CompressingRematerializationTest, OnlyRematBigBuffer) {
  const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_float {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%constant = f32[] constant(0)
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%broadcast.1 = f32[10,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
ROOT %add.2 = f32[] add(f32[] %add, f32[] %reduce.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Only rematerialize buffers of shape f32[64,2]. Buffers of shape f32[10,2]
  // are ignored. ShapeSizePadMinorTo64 pads the minor dimension to 64, so
  // f32[64,2]{1,0} counts as 64*64*4 = 16KB while f32[10,2]{1,0} counts as
  // 10*64*4 = 2.5KB. Only the former reaches the 10KB minimum.
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      RunHloRematerialization(/*memory_limit_bytes=*/30 * 1024, module.get(),
                              /*min_remat_size=*/10 * 1024));
  EXPECT_TRUE(changed);
  HloInstruction* broadcast =
      module->entry_computation()->GetInstructionWithName("broadcast.0");
  HloInstruction* broadcast_2 =
      module->entry_computation()->GetInstructionWithName("broadcast.1");
  HloInstruction* reduce =
      module->entry_computation()->GetInstructionWithName("reduce.1");
  HloInstruction* reduce_2 =
      module->entry_computation()->GetInstructionWithName("reduce.2");
  // The large buffer is routed through a pair of copies. The small one is used
  // directly.
  EXPECT_THAT(reduce,
              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
  EXPECT_THAT(reduce_2, op::Reduce(broadcast_2, op::Constant()));
}
// Compress-only rematerialization of a single instruction. The late use of
// %broadcast.0 by %reduce.1 should end up reading it through two copies.
TEST_F(CompressingRematerializationTest, SingleRemat) {
  const string hlo_text = R"(
HloModule fusion, is_scheduled=true
%add_float {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%constant = f32[] constant(0)
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/30 * 1024, module.get()));
  EXPECT_TRUE(changed);

  HloComputation* entry = module->entry_computation();
  HloInstruction* original_broadcast =
      entry->GetInstructionWithName("broadcast.0");
  HloInstruction* late_reduce = entry->GetInstructionWithName("reduce.1");
  EXPECT_THAT(late_reduce, op::Reduce(op::Copy(op::Copy(original_broadcast)),
                                      op::Constant()));
}
// When %broadcast.0 is compressed, all of its later users (%reduce.2 and
// %reduce.3) should read it through one shared pair of copies rather than each
// getting its own.
TEST_F(CompressingRematerializationTest, AllUsersUseSameCopy) {
  const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_float {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%constant = f32[] constant(0)
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.2 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
%reduce.3 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add.2 = f32[] add(f32[] %reduce.2, f32[] %reduce.3)
ROOT %tuple = (f32[], f32[]) tuple (f32[] add, f32[] add.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/30 * 1024, module.get()));
  EXPECT_TRUE(changed);
  HloInstruction* broadcast =
      module->entry_computation()->GetInstructionWithName("broadcast.0");
  HloInstruction* reduce_2 =
      module->entry_computation()->GetInstructionWithName("reduce.2");
  HloInstruction* reduce_3 =
      module->entry_computation()->GetInstructionWithName("reduce.3");
  EXPECT_THAT(reduce_2,
              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
  EXPECT_THAT(reduce_3,
              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
  // The matchers above only check the structure of each operand. Also check
  // that both reduces reuse the very same copy instruction, which is the
  // property this test is named after.
  EXPECT_EQ(reduce_2->operand(0), reduce_3->operand(0));
}
// Rematerialization of values through bitcasts. The broadcast that reaches the
// concatenate via a bitcast is expected to be rematerialized, together with a
// fresh bitcast, right before the concatenate.
TEST_F(HloRematerializationTest, ThroughBitcastRemat) {
  const string hlo_text = R"(
HloModule fusion, is_scheduled=true
ENTRY %mycomp (param: f32[1]) -> f32[1] {
%param = f32[1]{0} parameter(0)
%reshape = f32[] reshape(f32[1]{0} %param)
%broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
%bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
%negate = f32[1024,1]{1,0} negate(f32[1024,1]{1,0} %broadcast)
%concatenate = f32[2048,1]{1,0} concatenate(f32[1024,1]{1,0} %negate, f32[1024,1]{1,0} %negate), dimensions={0}
%slice = f32[1,1]{1,0} slice(f32[2048,1]{1,0} %concatenate), slice={[0:1], [0:1]}
%bitcast.1 = f32[1]{0} bitcast(f32[1,1]{1,0} %slice)
%concatenate.1 = f32[1025]{0} concatenate(f32[1024]{0} %bitcast, f32[1]{0} %bitcast.1), dimensions={0}
ROOT %slice.1 = f32[1]{0} slice(f32[1025]{0} %concatenate.1), slice={[0:1]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* computation = module->entry_computation();

  // Walk from the root slice to the concat and to the original broadcast
  // behind the bitcast.
  const HloInstruction* root_slice = computation->root_instruction();
  ASSERT_THAT(root_slice,
              op::Slice(op::Concatenate(op::Bitcast(op::Broadcast(_)), _)));
  const HloInstruction* concat = root_slice->operand(0);
  const HloInstruction* original_bcast = concat->operand(0)->operand(0);

  // The computation needs 16KB without rematerialization and only 12KB with
  // it, so a 14KB limit forces rematerialization.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/14 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // The root is untouched. The concat now reads a new bitcast of a
  // rematerialized broadcast.
  EXPECT_EQ(computation->root_instruction(), root_slice);
  const HloInstruction* remat_bitcast = concat->operand(0);
  const HloInstruction* remat_broadcast = remat_bitcast->operand(0);
  EXPECT_THAT(remat_broadcast, op::Broadcast(::testing::Ne(original_bcast)));

  // Schedule tail: rematerialized broadcast, then its bitcast, then the concat
  // in second-to-last position.
  const auto& schedule_order =
      module->schedule().sequence(computation).instructions();
  const int64 num_instructions = computation->instruction_count();
  EXPECT_EQ(schedule_order[num_instructions - 2], concat);
  EXPECT_EQ(schedule_order[num_instructions - 3], remat_bitcast);
  EXPECT_EQ(schedule_order[num_instructions - 4], remat_broadcast);
}
// Rematerializing through bitcasts must engage the "deny list for move
// remats". With a memory limit that cannot be met, the pass has to return
// instead of getting stuck in a loop trying to lower memory further.
TEST_F(HloRematerializationTest, ThroughBitcastRematInfiniteLoop) {
  const string hlo_text = R"(
HloModule fusion, is_scheduled=true
ENTRY %mycomp (param: f32[1]) -> f32[1024] {
%param = f32[1]{0} parameter(0)
%reshape = f32[] reshape(f32[1]{0} %param)
%broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
%bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
%broadcast2 = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
%bitcast2 = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast2)
ROOT %add = f32[1024]{0} add(f32[1024]{0} %bitcast, f32[1024]{0} %bitcast2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // Keep a handle on the root add, whose operands are the bitcast broadcasts.
  const HloInstruction* root_add =
      module->entry_computation()->root_instruction();

  // A 1KB limit cannot be satisfied; reaching the assertions below is the
  // main point of this test.
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      RunHloRematerialization(/*memory_limit_bytes=*/1024, module.get()));
  ASSERT_THAT(root_add, op::Add(op::Bitcast(op::Broadcast(_)),
                                op::Bitcast(op::Broadcast(_))));
  EXPECT_TRUE(changed);
}
// Checks that a tuple-shaped fusion can be rematerialized. After running the
// pass with an 11KB limit, the root's get-tuple-element operand (%gte.2) must
// read from a fusion instruction other than the original %fus.
TEST_F(HloRematerializationTest, RematTupleShape) {
  const string hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_mul_comp {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
%x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
%y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
%add = f32[1024] add(%x, %y)
%mul = f32[1024] multiply(%x, %y)
ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%param.1 = f32[] parameter(1)
%fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
calls=%add_mul_comp
%gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
%add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
%broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
%mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
%gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
ROOT %add.2 = f32[1024]{0} add(f32[1024]{0} %mul, f32[1024]{0} %gte.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  const HloComputation* computation = module->entry_computation();
  const HloInstruction* add = computation->root_instruction();
  ASSERT_THAT(add, op::Add(op::Multiply(), op::GetTupleElement(op::Fusion())));
  // The original fusion sits behind the root's SECOND operand (%gte.2).
  // Operand 0 is %mul, whose first operand is %add, not the fusion. Capturing
  // that would make the Ne(fusion) check below vacuously true.
  const HloInstruction* fusion = add->operand(1)->operand(0);
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/11 * 1024, module.get()));
  EXPECT_TRUE(changed);
  // %gte.2 must now read from a rematerialized copy of the fusion.
  ASSERT_THAT(
      add, op::Add(op::Multiply(), op::GetTupleElement(AllOf(
                                       op::Fusion(), ::testing::Ne(fusion)))));
}
// Checks tuple-fusion rematerialization when the fusion has two late
// get-tuple-element users (%gte.2 and %gte.3). After the pass, both users
// must read from the same fusion, and it must not be the original %fus.
TEST_F(HloRematerializationTest, RematTupleShapeDoubleUse) {
  const string hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_mul_comp {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
%x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
%y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
%add = f32[1024] add(%x, %y)
%mul = f32[1024] multiply(%x, %y)
ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%param.1 = f32[] parameter(1)
%fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
calls=%add_mul_comp
%gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
%add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
%broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
%mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
%gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
%gte.3 = f32[1024]{0} get-tuple-element(%fus), index=0
%add.2 = f32[1024]{0} add(f32[1024]{0} %mul, f32[1024]{0} %gte.2)
ROOT %mul.2 = f32[1024]{0} multiply(f32[1024]{0} %add.2, f32[1024]{0} %gte.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  const HloComputation* computation = module->entry_computation();
  // The root is %mul.2 (a multiply), hence the neutral name.
  const HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, op::Multiply(op::Add(op::Multiply(),
                                         op::GetTupleElement(op::Fusion())),
                                 op::GetTupleElement(op::Fusion())));
  // The original fusion is reached via the root's second operand (%gte.3).
  // root->operand(0)->operand(0) is %mul, not the fusion. Capturing that
  // would make the Ne(fusion) checks below vacuously true.
  const HloInstruction* fusion = root->operand(1)->operand(0);
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/11 * 1024, module.get()));
  EXPECT_TRUE(changed);
  ASSERT_THAT(
      root,
      op::Multiply(
          op::Add(op::Multiply(), op::GetTupleElement(AllOf(
                                      op::Fusion(), ::testing::Ne(fusion)))),
          op::GetTupleElement(AllOf(op::Fusion(), ::testing::Ne(fusion)))));
  // Check that the rematerialized fusion is the same for both ops.
  EXPECT_EQ(root->operand(0)->operand(1)->operand(0),
            root->operand(1)->operand(0));
}
// Checks tuple-fusion rematerialization when the late get-tuple-element user
// (%gte.2) is itself consumed through a bitcast (%bc.2). After the pass, that
// bitcast must read a get-tuple-element of a fusion other than the original.
TEST_F(HloRematerializationTest, RematTupleShapeThroughBitcasts) {
  const string hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_mul_comp {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
%x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
%y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
%add = f32[1024] add(%x, %y)
%mul = f32[1024] multiply(%x, %y)
ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%param.1 = f32[] parameter(1)
%fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
calls=%add_mul_comp
%gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
%add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
%broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
%mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
%gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
%bc.1 = f32[1024,1]{0,1} bitcast(%mul)
%bc.2 = f32[1024,1]{0,1} bitcast(%gte.2)
ROOT %add.2 = f32[1024,1]{0,1} add(f32[1024,1]{0,1} %bc.1,
f32[1024,1]{0,1} %bc.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  const HloComputation* computation = module->entry_computation();
  const HloInstruction* add = computation->root_instruction();
  ASSERT_THAT(add, op::Add(op::Bitcast(op::Multiply()),
                           op::Bitcast(op::GetTupleElement(op::Fusion()))));
  // The original fusion is reached via %bc.2 -> %gte.2 -> %fus, i.e. the
  // root's SECOND operand. The first operand walks %bc.1 -> %mul -> %add,
  // which is not a fusion and would make the Ne(fusion) check vacuous.
  const HloInstruction* fusion = add->operand(1)->operand(0)->operand(0);
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/11 * 1024, module.get()));
  EXPECT_TRUE(changed);
  ASSERT_THAT(add, op::Add(op::Bitcast(op::Multiply()),
                           op::Bitcast(op::GetTupleElement(
                               AllOf(op::Fusion(), ::testing::Ne(fusion))))));
}
// Checks rematerialization through an explicit tuple. The root's first
// operand (%gte.2) reads element 0 of %tpl, which forwards %gte.1 taken from
// the fusion. After the pass, the root's first operand must be a
// get-tuple-element whose operand is a fusion that is neither the original
// %tpl nor the original %fus.
TEST_F(HloRematerializationTest, RematThroughTuple) {
  const string hlo_text = R"(
HloModule fusion, is_scheduled=true
%add_mul_comp {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
%x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
%y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
%add = f32[1024] add(%x, %y)
%mul = f32[1024] multiply(%x, %y)
ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%param.1 = f32[] parameter(1)
%fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
calls=%add_mul_comp
%gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
%gte.3 = f32[1024]{0} get-tuple-element(%fus), index=1
%add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.3)
%broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
%mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
%tpl = (f32[1024]{0}, f32[1024]{0}) tuple(%gte.1, %add)
%bc.1 = f32[1024,1]{0,1} bitcast(%mul)
%gte.2 = f32[1024]{0} get-tuple-element(%tpl), index=0
ROOT %add.2 = f32[1024]{0} add(f32[1024]{0} %gte.2, f32[1024]{0} %add)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Add(op::GetTupleElement(
                                op::Tuple(op::GetTupleElement(op::Fusion()), _)),
                            op::Add()));
  // Walk root -> %gte.2 -> %tpl -> %gte.1 -> %fus and remember the originals.
  const HloInstruction* tuple_gte = root->operand(0);
  const HloInstruction* orig_tuple = tuple_gte->operand(0);
  const HloInstruction* orig_fusion = orig_tuple->operand(0)->operand(0);
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/11 * 1024, module.get()));
  EXPECT_TRUE(changed);
  // The tuple should have been looked through: the root now reads directly
  // from a rematerialized fusion.
  ASSERT_THAT(root,
              op::Add(op::GetTupleElement(AllOf(op::Fusion(),
                                                ::testing::Ne(orig_tuple),
                                                ::testing::Ne(orig_fusion))),
                      op::Add()));
}
} // namespace
} // namespace xla