Move scalar multiply to the smaller side of convolution.

PiperOrigin-RevId: 324283914
Change-Id: I40f5f8cbf47e4c60997ed03bbff114f5f17519b4
This commit is contained in:
Amit Patankar 2020-07-31 14:21:44 -07:00 committed by TensorFlower Gardener
parent 781ff0196c
commit 934b4b6a35
3 changed files with 0 additions and 264 deletions

View File

@ -428,10 +428,6 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
shape, hlo, zero, dims, AddReduce_computation));
}
// Move a constant scalar multiply to whichever side of a convolution
// (operand 0, operand 1, or output) has the smallest tensor size, to
// reduce the number of scalar multiply computations.
Status ScalarMultiplyReduction(HloInstruction* dot);
// Convenience method for replacing an instruction with a bitcast. If operand
// is not null, then the bitcast will use the specified operand instead of the
// operand of the instruction.
@ -567,197 +563,6 @@ bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
}
}
namespace {
// Reads the first element of a BF16 or F32 constant instruction as a float.
// Any other element type is a fatal error.
float GetConstantValue(HloInstruction* inst) {
  const auto element_type = inst->shape().element_type();
  if (element_type == BF16) {
    return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
  }
  if (element_type == F32) {
    return inst->literal().GetFirstElement<float>();
  }
  LOG(FATAL) << "Unsupported data type: " << element_type;
}
// Returns true for opcodes across which a scalar multiply can be hoisted,
// i.e. op(scalar * x) == scalar * op(x).
bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
  return opcode == HloOpcode::kMultiply || opcode == HloOpcode::kTranspose ||
         opcode == HloOpcode::kReshape || opcode == HloOpcode::kSelect;
}
// Builds an R0 (scalar) constant instruction holding `multiplier`, using the
// same element type as `target` (BF16 or F32). Any other element type is a
// fatal error.
//
// Fix: removed the unreachable `break;` statements that followed each
// `return` in the original switch — dead code flagged by most linters.
std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
                                                      float multiplier) {
  switch (target->shape().element_type()) {
    case BF16:
      // Round the f32 value down to bfloat16 so the constant's type matches
      // the target's element type.
      return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
          LiteralUtil::CreateR0<float>(multiplier)));
    case F32:
      return HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(multiplier));
    default:
      LOG(FATAL) << "Unsupported data type: " << target->shape().element_type();
  }
}
} // namespace
// Moves constant scalar multiplies found on the operands/users of a
// convolution onto whichever of {operand 0, operand 1, output} has the
// fewest elements, so the elementwise multiply touches the least data.
// Despite the parameter name, `dot` is the convolution instruction passed
// from HandleConvolution. Sets changed_ only when a rewrite happened;
// always returns OK unless an HLO replacement fails.
Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
    HloInstruction* dot) {
  // We only process bfloat16 and float32 for now.
  if (dot->shape().element_type() != BF16 &&
      dot->shape().element_type() != F32) {
    return Status::OK();
  }

  auto lhs = dot->mutable_operand(0);
  auto rhs = dot->mutable_operand(1);

  // Element counts decide which side is cheapest to multiply.
  const int64 dot_size = ShapeUtil::ElementsIn(dot->shape());
  const int64 lhs_size = ShapeUtil::ElementsIn(lhs->shape());
  const int64 rhs_size = ShapeUtil::ElementsIn(rhs->shape());

  // The instruction that will receive the combined scalar multiply.
  HloInstruction* target = nullptr;
  // DFS worklist over operands: (current node, user, operand_index).
  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64>> operands;
  // DFS worklist over users of the convolution's output.
  std::vector<HloInstruction*> users;

  // Find which side of dot has the smallest size:
  // operand 0, operand 1, or output.
  // Only strictly-larger sides are searched for multiplies to strip; a side
  // equal in size to the target gains nothing from moving the multiply.
  if (dot_size <= std::min(lhs_size, rhs_size)) {
    target = dot;
    if (dot_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (dot_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
  } else if (lhs_size <= rhs_size) {
    target = lhs;
    if (lhs_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
    // Only follow the output side when the convolution has a single user;
    // otherwise moving the multiply would change what other users see.
    if (lhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  } else {
    target = rhs;
    if (rhs_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (rhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  }

  // Scalar constants stripped from the graph, to be combined later.
  std::vector<float> values;

  // DFS to find scalar multiply ops from the operands.
  while (!operands.empty()) {
    auto [inst, user, index] = operands.back();
    operands.pop_back();

    // Skip the op types that are not commutative with multiply.
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    // Pattern match a scalar multiply: multiply(x, broadcast(const scalar)).
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      CHECK_LT(index, user->operand_count());
      CHECK_EQ(inst, user->operands()[index]);

      // When found a scalar multiply, save its scalar value.
      values.push_back(GetConstantValue(multiplier));
      // And remove the scalar multiply op.
      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
      // Continue the search below the multiply we just removed.
      inst = operand;
    }

    // Push the operands of inst.
    int64 i = 0;
    for (auto* operand : inst->operands()) {
      operands.emplace_back(operand, inst, i++);
    }
  }

  // DFS to find scalar multiply ops from the users.
  while (!users.empty()) {
    auto inst = users.back();
    users.pop_back();

    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      values.push_back(GetConstantValue(multiplier));

      // Splice the multiply out of the user chain.
      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
      inst = operand;
    }

    // Process the instructions with only one user.
    // Otherwise moving scalar multiply to the operands changes the values of
    // other users.
    if (inst->user_count() == 1) {
      users.push_back(inst->users().front());
    }
  }

  // Nothing was stripped: leave the graph untouched.
  if (values.empty()) {
    return Status::OK();
  }

  changed_ = true;

  // Combine all constant multipliers.
  float multiplier = 1.0;
  for (const float v : values) {
    multiplier *= v;
  }

  // Create a new const scalar multiply instruction.
  HloInstruction* new_const_inst;
  new_const_inst =
      computation_->AddInstruction(MakeScalarInstruction(target, multiplier));

  // Broadcast the scalar multiplier.
  HloInstruction* new_broadcast = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));

  // Create a new scalar multiply instruction.
  HloInstruction* new_multiply =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
  CHECK_EQ(new_multiply->shape(), target->shape());

  // Update the dependency with the rest of the instructions.
  if (target == lhs) {
    return dot->ReplaceOperandWith(0, new_multiply);
  } else if (target == rhs) {
    return dot->ReplaceOperandWith(1, new_multiply);
  } else {
    CHECK_EQ(target, dot);
    return dot->ReplaceAllUsesWith(new_multiply);
  }
}
void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
HloInstruction* operand) {
CHECK_EQ(1, instruction->operand_count());
@ -5237,10 +5042,6 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction* convolution) {
if (options_.enable_scalar_multiply_reduction()) {
TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
}
// Zero-sized input or filter.
if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {

View File

@ -86,17 +86,6 @@ class AlgebraicSimplifierOptions {
}
// Returns whether swapping convolution operands is enabled.
bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; }
// Enables/disables moving a constant scalar multiply to the convolution
// operand or output with the smallest tensor size, reducing the number of
// scalar multiply computations performed.
void set_enable_scalar_multiply_reduction(
    bool enable_scalar_multiply_reduction) {
  enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction;
}
// Returns whether the scalar-multiply-reduction rewrite is enabled.
bool enable_scalar_multiply_reduction() const {
  return enable_scalar_multiply_reduction_;
}
// If enable_window_reduce_replacement is true, the kReduceWindow instruction
// can be optimized by replacement with simpler operations.
void set_enable_window_reduce_to_reduce_replacement(
@ -157,7 +146,6 @@ class AlgebraicSimplifierOptions {
bool enable_dot_to_multiply_rewrite_{true};
bool enable_conv_simplification_{true};
bool enable_conv_operand_swap_{true};
bool enable_scalar_multiply_reduction_{false};
bool enable_window_reduce_to_reduce_replacement_{true};
bool enable_reduce_of_reshape_{true};
bool replace_transpose_with_bitcast_{true};

View File

@ -5343,59 +5343,6 @@ ENTRY AddBroadcastZeroWithDynamicSlice {
EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad);
}
// Checks that the scalar multiplies feeding both convolution operands
// (constants 0.5 and 1.109) are stripped and merged into a single multiply
// by 0.5 * 1.109 applied to the convolution output — the smallest of the
// three tensors in this module.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) {
  const char* hlo_string = R"(
HloModule ConstScalarMultiply
ENTRY ConstScalarMultiply {
param0 = f32[16,512,4096]{2,1,0} parameter(0)
constant.0 = f32[] constant(0.5)
broadcast.0 = f32[16,512,4096] broadcast(constant.0), dimensions={}
multiply.0 = f32[16,512,4096]{2,1,0} multiply(param0, broadcast.0)
param1 = f32[16,512,4096]{2,1,0} parameter(1)
multiply.1 = f32[16,512,4096]{2,1,0} multiply(multiply.0, param1)
param2 = f32[16,512,1024]{2,1,0} parameter(2)
constant.1 = f32[] constant(1.109)
broadcast.1 = f32[16,512,1024] broadcast(constant.1), dimensions={}
multiply.2 = f32[16,512,1024]{2,1,0} multiply(param2, broadcast.1)
ROOT convolution = f32[4096,1024,1]{1,0,2} convolution(multiply.1, multiply.2), window={size=16}, dim_labels=0fb_0io->bf0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // The rewrite is off by default; enable it for this test.
  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // The root must now be a multiply of the convolution by the combined
  // scalar constant.
  auto root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
  EXPECT_THAT(root,
              GmockMatch(m::MultiplyAnyOrder(
                  m::Op(), m::Broadcast(m::ConstantScalar(0.5f * 1.109f)))));
}
// When the convolution's output has more than one user, hoisting the scalar
// multiply onto an operand would change the value seen by the other user
// (multiply.2 here), so the simplifier must leave the graph untouched and
// Run must report no change.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReductionMultiUser) {
  const char* hlo_string = R"(
HloModule ConstScalarMultiply
ENTRY ConstScalarMultiply {
param0 = f32[16,512,1024] parameter(0)
param1 = f32[4096,1024,1] parameter(1)
convolution = f32[16,512,4096] convolution(param0, param1), window={size=1}, dim_labels=0bf_oi0->0bf
constant.1 = f32[] constant(0.5)
broadcast.1 = f32[16,512,4096] broadcast(constant.1), dimensions={}
multiply.1 = f32[16,512,4096] multiply(convolution, broadcast.1)
param2 = f32[16,512,4096] parameter(2)
multiply.2 = f32[16,512,4096] multiply(convolution, param2)
ROOT add.1 = f32[16,512,4096] add(multiply.1, multiply.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  AlgebraicSimplifier simplifier(options);
  // No rewrite is expected because the convolution has two users.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
// Instantiates the parameterized DotOfConcatSimplificationTest suite once
// per entry in kDotOfConcatTestSpecs.
INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
                         DotOfConcatSimplificationTest,
                         ::testing::ValuesIn(kDotOfConcatTestSpecs));