diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index d77e62a1357..1f82c062df9 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -428,10 +428,6 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
         shape, hlo, zero, dims, AddReduce_computation));
   }
 
-  // Move scalar multiply to the smallest side of convolution to
-  // reduce multiply computations.
-  Status ScalarMultiplyReduction(HloInstruction* dot);
-
   // Convenience method for replacing an instruction with a bitcast. If operand
   // is not null, then the bitcast will use the specified operand instead of the
   // operand of the instruction.
@@ -567,197 +563,6 @@ bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
   }
 }
 
-namespace {
-
-float GetConstantValue(HloInstruction* inst) {
-  switch (inst->shape().element_type()) {
-    case BF16:
-      return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
-    case F32:
-      return inst->literal().GetFirstElement<float>();
-    default:
-      LOG(FATAL) << "Unsupported data type: " << inst->shape().element_type();
-  }
-}
-
-bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
-  switch (opcode) {
-    case HloOpcode::kMultiply:
-    case HloOpcode::kTranspose:
-    case HloOpcode::kReshape:
-    case HloOpcode::kSelect:
-      return true;
-    default:
-      return false;
-  }
-}
-
-std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
-                                                      float multiplier) {
-  switch (target->shape().element_type()) {
-    case BF16:
-      return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
-          LiteralUtil::CreateR0<float>(multiplier)));
-      break;
-    case F32:
-      return HloInstruction::CreateConstant(
-          LiteralUtil::CreateR0<float>(multiplier));
-      break;
-    default:
-      LOG(FATAL) << "Unsupported data type: " << target->shape().element_type();
-  }
-}
-
-}  // namespace
-
-Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
-    HloInstruction* dot) {
-  // We only process bfloat16 and float32 for now.
-  if (dot->shape().element_type() != BF16 &&
-      dot->shape().element_type() != F32) {
-    return Status::OK();
-  }
-
-  auto lhs = dot->mutable_operand(0);
-  auto rhs = dot->mutable_operand(1);
-
-  const int64 dot_size = ShapeUtil::ElementsIn(dot->shape());
-  const int64 lhs_size = ShapeUtil::ElementsIn(lhs->shape());
-  const int64 rhs_size = ShapeUtil::ElementsIn(rhs->shape());
-
-  HloInstruction* target = nullptr;
-  // (current node, user, operand_index)
-  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64>> operands;
-  std::vector<HloInstruction*> users;
-
-  // Find which side of dot has the smallest size:
-  // operand 0, operand 1, or output.
-  if (dot_size <= std::min(lhs_size, rhs_size)) {
-    target = dot;
-    if (dot_size < lhs_size) {
-      operands.emplace_back(lhs, dot, 0);
-    }
-    if (dot_size < rhs_size) {
-      operands.emplace_back(rhs, dot, 1);
-    }
-  } else if (lhs_size <= rhs_size) {
-    target = lhs;
-    if (lhs_size < rhs_size) {
-      operands.emplace_back(rhs, dot, 1);
-    }
-    if (lhs_size < dot_size && dot->user_count() == 1) {
-      users.push_back(dot->users().front());
-    }
-  } else {
-    target = rhs;
-    if (rhs_size < lhs_size) {
-      operands.emplace_back(lhs, dot, 0);
-    }
-    if (rhs_size < dot_size && dot->user_count() == 1) {
-      users.push_back(dot->users().front());
-    }
-  }
-
-  std::vector<float> values;
-
-  // DFS to find scalar multiply ops from the operands.
-  while (!operands.empty()) {
-    auto [inst, user, index] = operands.back();
-    operands.pop_back();
-
-    // Skip the op types that are not commutative with multiply.
-    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
-      continue;
-    }
-
-    HloInstruction* operand;
-    HloInstruction* multiplier;
-    // Pattern match a scalar multiply.
-    if (Match(inst, m::MultiplyAnyOrder(
-                        m::Op(&operand),
-                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
-      CHECK_LT(index, user->operand_count());
-      CHECK_EQ(inst, user->operands()[index]);
-
-      // When found a scalar multiply, save its scalar value.
-      values.push_back(GetConstantValue(multiplier));
-      // And remove the scalar multiply op.
-      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
-      inst = operand;
-    }
-
-    // Push the operands of inst.
-    int64 i = 0;
-    for (auto* operand : inst->operands()) {
-      operands.emplace_back(operand, inst, i++);
-    }
-  }
-
-  // DFS to find scalar multiply ops from the users.
-  while (!users.empty()) {
-    auto inst = users.back();
-    users.pop_back();
-
-    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
-      continue;
-    }
-
-    HloInstruction* operand;
-    HloInstruction* multiplier;
-    if (Match(inst, m::MultiplyAnyOrder(
-                        m::Op(&operand),
-                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
-      values.push_back(GetConstantValue(multiplier));
-
-      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
-      inst = operand;
-    }
-
-    // Process the instructions with only one user.
-    // Otherwise moving scalar multiply to the operands changes the values of
-    // other users.
-    if (inst->user_count() == 1) {
-      users.push_back(inst->users().front());
-    }
-  }
-
-  if (values.empty()) {
-    return Status::OK();
-  }
-
-  changed_ = true;
-
-  // Combine all constant multipliers.
-  float multiplier = 1.0;
-  for (const float v : values) {
-    multiplier *= v;
-  }
-
-  // Create a new const scalar multiply instruction.
-  HloInstruction* new_const_inst;
-  new_const_inst =
-      computation_->AddInstruction(MakeScalarInstruction(target, multiplier));
-
-  // Broadcast the scalar multiplier.
-  HloInstruction* new_broadcast = computation_->AddInstruction(
-      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));
-  // Create a new scalar multiply instruction.
-  HloInstruction* new_multiply =
-      computation_->AddInstruction(HloInstruction::CreateBinary(
-          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
-  CHECK_EQ(new_multiply->shape(), target->shape());
-
-  // Update the dependency with the rest of the instructions.
-  if (target == lhs) {
-    return dot->ReplaceOperandWith(0, new_multiply);
-  } else if (target == rhs) {
-    return dot->ReplaceOperandWith(1, new_multiply);
-  } else {
-    CHECK_EQ(target, dot);
-    return dot->ReplaceAllUsesWith(new_multiply);
-  }
-}
-
 void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
                                                     HloInstruction* operand) {
   CHECK_EQ(1, instruction->operand_count());
@@ -5237,10 +5042,6 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
 
 Status AlgebraicSimplifierVisitor::HandleConvolution(
     HloInstruction* convolution) {
-  if (options_.enable_scalar_multiply_reduction()) {
-    TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
-  }
-
   // Zero-sized input or filter.
   if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
       ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index 9f2a3404116..9f29df3c209 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -86,17 +86,6 @@ class AlgebraicSimplifierOptions {
   }
   bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; }
 
-  // Move constant scalar multiply to one operand or output of convolutions with
-  // the smallest tensor size, to reduce the number of scalar multiply.
-  void set_enable_scalar_multiply_reduction(
-      bool enable_scalar_multiply_reduction) {
-    enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction;
-  }
-
-  bool enable_scalar_multiply_reduction() const {
-    return enable_scalar_multiply_reduction_;
-  }
-
   // If enable_window_reduce_replacement is true, the kReduceWindow instruction
   // can be optimized by replacement with simpler operations.
   void set_enable_window_reduce_to_reduce_replacement(
@@ -157,7 +146,6 @@ class AlgebraicSimplifierOptions {
   bool enable_dot_to_multiply_rewrite_{true};
   bool enable_conv_simplification_{true};
   bool enable_conv_operand_swap_{true};
-  bool enable_scalar_multiply_reduction_{false};
   bool enable_window_reduce_to_reduce_replacement_{true};
   bool enable_reduce_of_reshape_{true};
   bool replace_transpose_with_bitcast_{true};
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 90ca44714f7..034d8ec4361 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -5343,59 +5343,6 @@ ENTRY AddBroadcastZeroWithDynamicSlice {
   EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad);
 }
 
-TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) {
-  const char* hlo_string = R"(
-HloModule ConstScalarMultiply
-ENTRY ConstScalarMultiply {
-  param0 = f32[16,512,4096]{2,1,0} parameter(0)
-  constant.0 = f32[] constant(0.5)
-  broadcast.0 = f32[16,512,4096] broadcast(constant.0), dimensions={}
-  multiply.0 = f32[16,512,4096]{2,1,0} multiply(param0, broadcast.0)
-  param1 = f32[16,512,4096]{2,1,0} parameter(1)
-  multiply.1 = f32[16,512,4096]{2,1,0} multiply(multiply.0, param1)
-  param2 = f32[16,512,1024]{2,1,0} parameter(2)
-  constant.1 = f32[] constant(1.109)
-  broadcast.1 = f32[16,512,1024] broadcast(constant.1), dimensions={}
-  multiply.2 = f32[16,512,1024]{2,1,0} multiply(param2, broadcast.1)
-  ROOT convolution = f32[4096,1024,1]{1,0,2} convolution(multiply.1, multiply.2), window={size=16}, dim_labels=0fb_0io->bf0
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  AlgebraicSimplifierOptions options;
-  options.set_enable_scalar_multiply_reduction(true);
-  AlgebraicSimplifier simplifier(options);
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-  auto root = module->entry_computation()->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
-  EXPECT_THAT(root,
-              GmockMatch(m::MultiplyAnyOrder(
-                  m::Op(), m::Broadcast(m::ConstantScalar(0.5f * 1.109f)))));
-}
-
-TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReductionMultiUser) {
-  const char* hlo_string = R"(
-HloModule ConstScalarMultiply
-ENTRY ConstScalarMultiply {
-  param0 = f32[16,512,1024] parameter(0)
-  param1 = f32[4096,1024,1] parameter(1)
-  convolution = f32[16,512,4096] convolution(param0, param1), window={size=1}, dim_labels=0bf_oi0->0bf
-  constant.1 = f32[] constant(0.5)
-  broadcast.1 = f32[16,512,4096] broadcast(constant.1), dimensions={}
-  multiply.1 = f32[16,512,4096] multiply(convolution, broadcast.1)
-  param2 = f32[16,512,4096] parameter(2)
-  multiply.2 = f32[16,512,4096] multiply(convolution, param2)
-  ROOT add.1 = f32[16,512,4096] add(multiply.1, multiply.2)
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  AlgebraicSimplifierOptions options;
-  options.set_enable_scalar_multiply_reduction(true);
-  AlgebraicSimplifier simplifier(options);
-  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
-}
-
 INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
                          DotOfConcatSimplificationTest,
                          ::testing::ValuesIn(kDotOfConcatTestSpecs));