Fix tests with HLO bugs. PiperOrigin-RevId: 275481999 Change-Id: I803e5f455de4fe92369601d22a26fd657f524331
452 lines
18 KiB
C++
452 lines
18 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
|
|
|
#include <memory>
|
|
|
|
#include "absl/memory/memory.h"
|
|
#include "tensorflow/compiler/xla/layout_util.h"
|
|
#include "tensorflow/compiler/xla/literal_util.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
|
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
#include "tensorflow/core/platform/types.h"
|
|
|
|
namespace xla {
|
|
namespace {
|
|
|
|
class HloDceTest : public HloTestBase {
|
|
protected:
|
|
HloDceTest() {}
|
|
|
|
// Returns whether the given instruction exists in the given computation.
|
|
bool HasInstruction(const HloComputation& computation,
|
|
const HloInstruction* instruction) {
|
|
return absl::c_linear_search(computation.instructions(), instruction);
|
|
}
|
|
};
|
|
|
|
TEST_F(HloDceTest, NoDeadCode) {
|
|
// Verify that no dead code is removed from a computation with no dead code.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
auto constant1 = builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
|
auto constant2 = builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
|
|
builder.AddInstruction(HloInstruction::CreateBinary(
|
|
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
|
|
|
|
auto module = CreateNewVerifiedModule();
|
|
auto computation = module->AddEntryComputation(builder.Build());
|
|
|
|
EXPECT_EQ(3, computation->instruction_count());
|
|
|
|
HloDCE dce;
|
|
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
EXPECT_EQ(3, computation->instruction_count());
|
|
}
|
|
|
|
TEST_F(HloDceTest, InstructionsWithSideEffect) {
|
|
// Verify that side-effect instructions (Send in this test) are not removed.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
auto constant = builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
|
auto token = builder.AddInstruction(HloInstruction::CreateToken());
|
|
auto send = builder.AddInstruction(
|
|
HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
|
|
builder.AddInstruction(HloInstruction::CreateSendDone(send));
|
|
builder.AddInstruction(HloInstruction::CreateTuple({}));
|
|
|
|
auto module = CreateNewVerifiedModule();
|
|
auto computation = module->AddEntryComputation(builder.Build());
|
|
|
|
EXPECT_EQ(5, computation->instruction_count());
|
|
|
|
HloDCE dce;
|
|
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
EXPECT_EQ(5, computation->instruction_count());
|
|
}
|
|
|
|
TEST_F(HloDceTest, CustomCallInstructionsWithSideEffect) {
|
|
// Verify that custom call instruction with side-effect is not removed.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
auto instr = Cast<HloCustomCallInstruction>(builder.AddInstruction(
|
|
HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
|
|
/*operands=*/{},
|
|
/*custom_call_target=*/"foo")));
|
|
instr->set_custom_call_has_side_effect(true);
|
|
builder.AddInstruction(HloInstruction::CreateTuple({}));
|
|
|
|
auto module = CreateNewVerifiedModule();
|
|
module->AddEntryComputation(builder.Build());
|
|
|
|
HloDCE dce;
|
|
TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&dce, module.get()));
|
|
EXPECT_FALSE(result);
|
|
}
|
|
|
|
TEST_F(HloDceTest, CustomCallInstructionsWithoutSideEffect) {
|
|
// Verify that custom call instruction without side-effect is removed.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
builder.AddInstruction(
|
|
HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
|
|
/*operands=*/{},
|
|
/*custom_call_target=*/"foo"));
|
|
builder.AddInstruction(HloInstruction::CreateTuple({}));
|
|
|
|
auto module = CreateNewVerifiedModule();
|
|
module->AddEntryComputation(builder.Build());
|
|
|
|
HloDCE dce;
|
|
TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&dce, module.get()));
|
|
EXPECT_TRUE(result);
|
|
}
|
|
|
|
TEST_F(HloDceTest, DeadParameters) {
|
|
// Verify that dead parameters are not removed, but use of the dead parameters
|
|
// are.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
auto live_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
|
0, ShapeUtil::MakeShape(F32, {}), "live_param"));
|
|
auto dead_param1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
|
1, ShapeUtil::MakeShape(F32, {}), "dead_param1"));
|
|
builder.AddInstruction(HloInstruction::CreateParameter(
|
|
2, ShapeUtil::MakeShape(F32, {}), "dead_param2"));
|
|
|
|
// This is a dead negate instruction.
|
|
builder.AddInstruction(HloInstruction::CreateUnary(
|
|
dead_param1->shape(), HloOpcode::kNegate, dead_param1));
|
|
|
|
// This negate is not dead because it is the root.
|
|
builder.AddInstruction(HloInstruction::CreateUnary(
|
|
live_param->shape(), HloOpcode::kNegate, live_param));
|
|
|
|
auto module = CreateNewVerifiedModule();
|
|
auto computation = module->AddEntryComputation(builder.Build());
|
|
|
|
EXPECT_EQ(5, computation->instruction_count());
|
|
EXPECT_EQ(1, dead_param1->user_count());
|
|
|
|
HloDCE dce;
|
|
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
EXPECT_EQ(4, computation->instruction_count());
|
|
EXPECT_EQ(0, dead_param1->user_count());
|
|
}
|
|
|
|
TEST_F(HloDceTest, ControlDependencies) {
|
|
// Verify that instructions with control dependencies are not removed.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
auto constant1 = builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
|
auto constant2 = builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
|
|
|
|
// Create two dead instructions: a negate and an add.
|
|
auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary(
|
|
constant1->shape(), HloOpcode::kNegate, constant1));
|
|
auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary(
|
|
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
|
|
|
|
// Create the same two instructions again, but these will have a control
|
|
// dependency added.
|
|
auto dead_negate_with_control_dep =
|
|
builder.AddInstruction(HloInstruction::CreateUnary(
|
|
constant1->shape(), HloOpcode::kNegate, constant1));
|
|
auto dead_add_with_control_dep =
|
|
builder.AddInstruction(HloInstruction::CreateBinary(
|
|
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
|
|
|
|
// Create a root so the previously added instruction is dead.
|
|
builder.AddInstruction(HloInstruction::CreateBinary(
|
|
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
|
|
|
|
auto module = CreateNewVerifiedModule();
|
|
auto computation = module->AddEntryComputation(builder.Build());
|
|
|
|
// Add a control dependency between two instructions.
|
|
TF_ASSERT_OK(dead_negate_with_control_dep->AddControlDependencyTo(
|
|
dead_add_with_control_dep));
|
|
|
|
EXPECT_EQ(7, computation->instruction_count());
|
|
EXPECT_TRUE(HasInstruction(*computation, dead_negate));
|
|
EXPECT_TRUE(HasInstruction(*computation, dead_add));
|
|
EXPECT_TRUE(HasInstruction(*computation, dead_negate_with_control_dep));
|
|
EXPECT_TRUE(HasInstruction(*computation, dead_add_with_control_dep));
|
|
|
|
HloDCE dce;
|
|
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
EXPECT_EQ(5, computation->instruction_count());
|
|
EXPECT_FALSE(HasInstruction(*computation, dead_negate));
|
|
EXPECT_FALSE(HasInstruction(*computation, dead_add));
|
|
EXPECT_TRUE(HasInstruction(*computation, dead_negate_with_control_dep));
|
|
EXPECT_TRUE(HasInstruction(*computation, dead_add_with_control_dep));
|
|
}
|
|
|
|
// Tests that a dead call instruction is removed.
|
|
TEST_F(HloDceTest, DeadInstructionWithCalledComputation) {
|
|
auto module = CreateNewVerifiedModule();
|
|
Shape shape = ShapeUtil::MakeShape(F32, {});
|
|
|
|
// Called computation for the call instruction.
|
|
auto callee_builder = HloComputation::Builder(TestName() + "-callee");
|
|
{
|
|
auto param = callee_builder.AddInstruction(
|
|
HloInstruction::CreateParameter(0, shape, "param"));
|
|
callee_builder.AddInstruction(
|
|
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
|
|
}
|
|
auto called_computation =
|
|
module->AddEmbeddedComputation(callee_builder.Build());
|
|
|
|
// Entry computation with a call instruction.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
auto param = builder.AddInstruction(
|
|
HloInstruction::CreateParameter(0, shape, "param"));
|
|
auto dead_call = builder.AddInstruction(
|
|
HloInstruction::CreateCall(shape, {param}, called_computation));
|
|
builder.AddInstruction(
|
|
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
|
|
auto computation = module->AddEntryComputation(builder.Build());
|
|
|
|
EXPECT_EQ(3, computation->instruction_count());
|
|
EXPECT_EQ(2, param->user_count());
|
|
EXPECT_EQ(0, dead_call->user_count());
|
|
EXPECT_TRUE(HasInstruction(*computation, dead_call));
|
|
|
|
HloDCE dce;
|
|
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
EXPECT_EQ(2, computation->instruction_count());
|
|
EXPECT_EQ(1, param->user_count());
|
|
EXPECT_FALSE(HasInstruction(*computation, dead_call));
|
|
}
|
|
|
|
// Tests that a while instruction with an infeed (effectul instruction) in its
|
|
// body is not removed, even its user count is 0.
|
|
TEST_F(HloDceTest, CalledComputationWithSideEffect) {
|
|
auto module = CreateNewVerifiedModule();
|
|
Shape shape = ShapeUtil::MakeShape(F32, {});
|
|
|
|
// Condition computation of a while instruction.
|
|
auto cond_builder = HloComputation::Builder(TestName() + "-cond");
|
|
{
|
|
auto param = cond_builder.AddInstruction(
|
|
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
|
auto constant = cond_builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
|
cond_builder.AddInstruction(
|
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
|
|
constant, ComparisonDirection::kLt));
|
|
}
|
|
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
|
|
|
// Body computation of a while instruction.
|
|
auto body_builder = HloComputation::Builder(TestName() + "-body");
|
|
{
|
|
auto param = body_builder.AddInstruction(
|
|
HloInstruction::CreateParameter(0, shape, "param"));
|
|
auto token = body_builder.AddInstruction(HloInstruction::CreateToken());
|
|
auto infeed = body_builder.AddInstruction(
|
|
HloInstruction::CreateInfeed(shape, token, ""));
|
|
auto infeed_data = body_builder.AddInstruction(
|
|
HloInstruction::CreateGetTupleElement(shape, infeed, 0));
|
|
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
|
shape, HloOpcode::kAdd, param, infeed_data));
|
|
}
|
|
auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
|
|
|
|
// Entry computation with a while instruction and a negate on the parameter.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
auto param = builder.AddInstruction(
|
|
HloInstruction::CreateParameter(0, shape, "param"));
|
|
auto live_while = builder.AddInstruction(HloInstruction::CreateWhile(
|
|
shape, cond_computation, body_computation, param));
|
|
builder.AddInstruction(
|
|
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
|
|
auto computation = module->AddEntryComputation(builder.Build());
|
|
|
|
// Check the while instruction is not removed even if its user count is 0.
|
|
EXPECT_EQ(3, computation->instruction_count());
|
|
EXPECT_EQ(2, param->user_count());
|
|
EXPECT_EQ(0, live_while->user_count());
|
|
EXPECT_TRUE(HasInstruction(*computation, live_while));
|
|
|
|
HloDCE dce;
|
|
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
EXPECT_EQ(3, computation->instruction_count());
|
|
EXPECT_EQ(2, param->user_count());
|
|
EXPECT_EQ(0, live_while->user_count());
|
|
EXPECT_TRUE(HasInstruction(*computation, live_while));
|
|
}
|
|
|
|
// Tests that a nested call instruction with a side effect is not removed.
|
|
TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
|
|
auto module = CreateNewVerifiedModule();
|
|
Shape shape = ShapeUtil::MakeShape(F32, {});
|
|
|
|
// Nested called computation with a side effect.
|
|
auto nested_callee_builder =
|
|
HloComputation::Builder(TestName() + "-nested_callee");
|
|
{
|
|
auto param = nested_callee_builder.AddInstruction(
|
|
HloInstruction::CreateParameter(0, shape, "param"));
|
|
auto token =
|
|
nested_callee_builder.AddInstruction(HloInstruction::CreateToken());
|
|
nested_callee_builder.AddInstruction(
|
|
HloInstruction::CreateOutfeed(shape, param, token, ""));
|
|
}
|
|
auto nested_called_computation =
|
|
module->AddEmbeddedComputation(nested_callee_builder.Build());
|
|
|
|
// Outer called computation that calls the nested computation.
|
|
auto callee_builder = HloComputation::Builder(TestName() + "-callee");
|
|
{
|
|
auto param = callee_builder.AddInstruction(
|
|
HloInstruction::CreateParameter(0, shape, "param"));
|
|
callee_builder.AddInstruction(HloInstruction::CreateCall(
|
|
ShapeUtil::MakeTokenShape(), {param}, nested_called_computation));
|
|
}
|
|
auto called_computation =
|
|
module->AddEmbeddedComputation(callee_builder.Build());
|
|
|
|
// Entry computation with a call instruction.
|
|
auto builder = HloComputation::Builder(TestName());
|
|
auto param = builder.AddInstruction(
|
|
HloInstruction::CreateParameter(0, shape, "param"));
|
|
auto live_call = builder.AddInstruction(HloInstruction::CreateCall(
|
|
ShapeUtil::MakeTokenShape(), {param}, called_computation));
|
|
auto computation = module->AddEntryComputation(builder.Build());
|
|
|
|
EXPECT_EQ(2, computation->instruction_count());
|
|
EXPECT_EQ(1, param->user_count());
|
|
EXPECT_EQ(0, live_call->user_count());
|
|
EXPECT_TRUE(HasInstruction(*computation, live_call));
|
|
|
|
HloDCE dce;
|
|
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
EXPECT_EQ(2, computation->instruction_count());
|
|
EXPECT_EQ(1, param->user_count());
|
|
EXPECT_EQ(0, live_call->user_count());
|
|
EXPECT_TRUE(HasInstruction(*computation, live_call));
|
|
}
|
|
|
|
TEST_F(HloDceTest, RemoveDeadSubcomputation) {
|
|
auto module = CreateNewVerifiedModule();
|
|
HloComputation::Builder builder(TestName());
|
|
|
|
HloComputation::Builder subcomp_builder("reduction_subcomp");
|
|
{
|
|
auto* param0 =
|
|
subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
|
|
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0"));
|
|
auto* param1 =
|
|
subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
|
|
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1"));
|
|
subcomp_builder.AddInstruction(HloInstruction::CreateBinary(
|
|
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1));
|
|
}
|
|
auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build());
|
|
|
|
// Create a dead reduce instruction.
|
|
builder.AddInstruction(HloInstruction::CreateReduce(
|
|
ShapeUtil::MakeShape(F32, {1}),
|
|
builder.AddInstruction(HloInstruction::CreateParameter(
|
|
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")),
|
|
builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
|
|
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
|
|
|
|
// Add another instruction as the root of the computation.
|
|
builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
|
|
|
|
module->AddEntryComputation(builder.Build());
|
|
EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
|
|
|
|
HloDCE dce;
|
|
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
// We should have DCE'ed the reduction computation along with the reduction
|
|
// instruction.
|
|
EXPECT_EQ(module->MakeComputationPostOrder().size(), 1);
|
|
}
|
|
|
|
TEST_F(HloDceTest, KeepUsedSubcomputation) {
|
|
auto module = CreateNewVerifiedModule();
|
|
HloComputation::Builder builder(TestName());
|
|
|
|
HloComputation::Builder subcomp_builder("reduction_subcomp");
|
|
{
|
|
auto* param0 =
|
|
subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
|
|
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0"));
|
|
auto* param1 =
|
|
subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
|
|
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1"));
|
|
subcomp_builder.AddInstruction(HloInstruction::CreateBinary(
|
|
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1));
|
|
}
|
|
auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build());
|
|
|
|
// Create a dead reduce instruction.
|
|
builder.AddInstruction(HloInstruction::CreateReduce(
|
|
ShapeUtil::MakeShape(F32, {}),
|
|
builder.AddInstruction(HloInstruction::CreateParameter(
|
|
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")),
|
|
builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
|
|
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
|
|
|
|
// Add another instruction as the root of the computation that also uses
|
|
// reduce_subcomp.
|
|
builder.AddInstruction(HloInstruction::CreateReduce(
|
|
ShapeUtil::MakeShape(F32, {}),
|
|
builder.AddInstruction(HloInstruction::CreateParameter(
|
|
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")),
|
|
builder.AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
|
|
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
|
|
|
|
module->AddEntryComputation(builder.Build());
|
|
EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
|
|
|
|
HloDCE dce;
|
|
EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
|
|
|
|
// We shouldn't have DCE'ed reduce_subcomp, even though we removed one of
|
|
// its users.
|
|
EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
|
|
}
|
|
|
|
} // namespace
|
|
} // namespace xla
|