Move scalar multiply to the smaller side of convolution.
PiperOrigin-RevId: 324283914 Change-Id: I40f5f8cbf47e4c60997ed03bbff114f5f17519b4
This commit is contained in:
		
							parent
							
								
									781ff0196c
								
							
						
					
					
						commit
						934b4b6a35
					
				@ -428,10 +428,6 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
 | 
			
		||||
        shape, hlo, zero, dims, AddReduce_computation));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Move scalar multiply to the smallest side of convolution to
 | 
			
		||||
  // reduce multiply computations.
 | 
			
		||||
  Status ScalarMultiplyReduction(HloInstruction* dot);
 | 
			
		||||
 | 
			
		||||
  // Convenience method for replacing an instruction with a bitcast. If operand
 | 
			
		||||
  // is not null, then the bitcast will use the specified operand instead of the
 | 
			
		||||
  // operand of the instruction.
 | 
			
		||||
@ -567,197 +563,6 @@ bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
float GetConstantValue(HloInstruction* inst) {
 | 
			
		||||
  switch (inst->shape().element_type()) {
 | 
			
		||||
    case BF16:
 | 
			
		||||
      return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
 | 
			
		||||
    case F32:
 | 
			
		||||
      return inst->literal().GetFirstElement<float>();
 | 
			
		||||
    default:
 | 
			
		||||
      LOG(FATAL) << "Unsupported data type: " << inst->shape().element_type();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
 | 
			
		||||
  switch (opcode) {
 | 
			
		||||
    case HloOpcode::kMultiply:
 | 
			
		||||
    case HloOpcode::kTranspose:
 | 
			
		||||
    case HloOpcode::kReshape:
 | 
			
		||||
    case HloOpcode::kSelect:
 | 
			
		||||
      return true;
 | 
			
		||||
    default:
 | 
			
		||||
      return false;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
 | 
			
		||||
                                                      float multiplier) {
 | 
			
		||||
  switch (target->shape().element_type()) {
 | 
			
		||||
    case BF16:
 | 
			
		||||
      return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
 | 
			
		||||
          LiteralUtil::CreateR0<float>(multiplier)));
 | 
			
		||||
      break;
 | 
			
		||||
    case F32:
 | 
			
		||||
      return HloInstruction::CreateConstant(
 | 
			
		||||
          LiteralUtil::CreateR0<float>(multiplier));
 | 
			
		||||
      break;
 | 
			
		||||
    default:
 | 
			
		||||
      LOG(FATAL) << "Unsupported data type: " << target->shape().element_type();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
 | 
			
		||||
    HloInstruction* dot) {
 | 
			
		||||
  // We only process bfloat16 and float32 for now.
 | 
			
		||||
  if (dot->shape().element_type() != BF16 &&
 | 
			
		||||
      dot->shape().element_type() != F32) {
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto lhs = dot->mutable_operand(0);
 | 
			
		||||
  auto rhs = dot->mutable_operand(1);
 | 
			
		||||
 | 
			
		||||
  const int64 dot_size = ShapeUtil::ElementsIn(dot->shape());
 | 
			
		||||
  const int64 lhs_size = ShapeUtil::ElementsIn(lhs->shape());
 | 
			
		||||
  const int64 rhs_size = ShapeUtil::ElementsIn(rhs->shape());
 | 
			
		||||
 | 
			
		||||
  HloInstruction* target = nullptr;
 | 
			
		||||
  // (current node, user, operand_index)
 | 
			
		||||
  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64>> operands;
 | 
			
		||||
  std::vector<HloInstruction*> users;
 | 
			
		||||
 | 
			
		||||
  // Find which side of dot has the smallest size:
 | 
			
		||||
  // operand 0, operand 1, or output.
 | 
			
		||||
  if (dot_size <= std::min(lhs_size, rhs_size)) {
 | 
			
		||||
    target = dot;
 | 
			
		||||
    if (dot_size < lhs_size) {
 | 
			
		||||
      operands.emplace_back(lhs, dot, 0);
 | 
			
		||||
    }
 | 
			
		||||
    if (dot_size < rhs_size) {
 | 
			
		||||
      operands.emplace_back(rhs, dot, 1);
 | 
			
		||||
    }
 | 
			
		||||
  } else if (lhs_size <= rhs_size) {
 | 
			
		||||
    target = lhs;
 | 
			
		||||
    if (lhs_size < rhs_size) {
 | 
			
		||||
      operands.emplace_back(rhs, dot, 1);
 | 
			
		||||
    }
 | 
			
		||||
    if (lhs_size < dot_size && dot->user_count() == 1) {
 | 
			
		||||
      users.push_back(dot->users().front());
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    target = rhs;
 | 
			
		||||
    if (rhs_size < lhs_size) {
 | 
			
		||||
      operands.emplace_back(lhs, dot, 0);
 | 
			
		||||
    }
 | 
			
		||||
    if (rhs_size < dot_size && dot->user_count() == 1) {
 | 
			
		||||
      users.push_back(dot->users().front());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::vector<float> values;
 | 
			
		||||
 | 
			
		||||
  // DFS to find scalar multiply ops from the operands.
 | 
			
		||||
  while (!operands.empty()) {
 | 
			
		||||
    auto [inst, user, index] = operands.back();
 | 
			
		||||
    operands.pop_back();
 | 
			
		||||
 | 
			
		||||
    // Skip the op types that are not commutative with multiply.
 | 
			
		||||
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    HloInstruction* operand;
 | 
			
		||||
    HloInstruction* multiplier;
 | 
			
		||||
    // Pattern match a scalar multiply.
 | 
			
		||||
    if (Match(inst, m::MultiplyAnyOrder(
 | 
			
		||||
                        m::Op(&operand),
 | 
			
		||||
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
 | 
			
		||||
      CHECK_LT(index, user->operand_count());
 | 
			
		||||
      CHECK_EQ(inst, user->operands()[index]);
 | 
			
		||||
 | 
			
		||||
      // When found a scalar multiply, save its scalar value.
 | 
			
		||||
      values.push_back(GetConstantValue(multiplier));
 | 
			
		||||
      // And remove the scalar multiply op.
 | 
			
		||||
      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
 | 
			
		||||
      inst = operand;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Push the operands of inst.
 | 
			
		||||
    int64 i = 0;
 | 
			
		||||
    for (auto* operand : inst->operands()) {
 | 
			
		||||
      operands.emplace_back(operand, inst, i++);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // DFS to find scalar multiply ops from the users.
 | 
			
		||||
  while (!users.empty()) {
 | 
			
		||||
    auto inst = users.back();
 | 
			
		||||
    users.pop_back();
 | 
			
		||||
 | 
			
		||||
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    HloInstruction* operand;
 | 
			
		||||
    HloInstruction* multiplier;
 | 
			
		||||
    if (Match(inst, m::MultiplyAnyOrder(
 | 
			
		||||
                        m::Op(&operand),
 | 
			
		||||
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
 | 
			
		||||
      values.push_back(GetConstantValue(multiplier));
 | 
			
		||||
 | 
			
		||||
      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
 | 
			
		||||
      inst = operand;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Process the instructions with only one user.
 | 
			
		||||
    // Otherwise moving scalar multiply to the operands changes the values of
 | 
			
		||||
    // other users.
 | 
			
		||||
    if (inst->user_count() == 1) {
 | 
			
		||||
      users.push_back(inst->users().front());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (values.empty()) {
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  changed_ = true;
 | 
			
		||||
 | 
			
		||||
  // Combine all constant multipliers.
 | 
			
		||||
  float multiplier = 1.0;
 | 
			
		||||
  for (const float v : values) {
 | 
			
		||||
    multiplier *= v;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Create a new const scalar multiply instruction.
 | 
			
		||||
  HloInstruction* new_const_inst;
 | 
			
		||||
  new_const_inst =
 | 
			
		||||
      computation_->AddInstruction(MakeScalarInstruction(target, multiplier));
 | 
			
		||||
 | 
			
		||||
  // Broadcast the scalar multiplier.
 | 
			
		||||
  HloInstruction* new_broadcast = computation_->AddInstruction(
 | 
			
		||||
      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));
 | 
			
		||||
  // Create a new scalar multiply instruction.
 | 
			
		||||
  HloInstruction* new_multiply =
 | 
			
		||||
      computation_->AddInstruction(HloInstruction::CreateBinary(
 | 
			
		||||
          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
 | 
			
		||||
  CHECK_EQ(new_multiply->shape(), target->shape());
 | 
			
		||||
 | 
			
		||||
  // Update the dependency with the rest of the instructions.
 | 
			
		||||
  if (target == lhs) {
 | 
			
		||||
    return dot->ReplaceOperandWith(0, new_multiply);
 | 
			
		||||
  } else if (target == rhs) {
 | 
			
		||||
    return dot->ReplaceOperandWith(1, new_multiply);
 | 
			
		||||
  } else {
 | 
			
		||||
    CHECK_EQ(target, dot);
 | 
			
		||||
    return dot->ReplaceAllUsesWith(new_multiply);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
 | 
			
		||||
                                                    HloInstruction* operand) {
 | 
			
		||||
  CHECK_EQ(1, instruction->operand_count());
 | 
			
		||||
@ -5237,10 +5042,6 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
 | 
			
		||||
 | 
			
		||||
Status AlgebraicSimplifierVisitor::HandleConvolution(
 | 
			
		||||
    HloInstruction* convolution) {
 | 
			
		||||
  if (options_.enable_scalar_multiply_reduction()) {
 | 
			
		||||
    TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Zero-sized input or filter.
 | 
			
		||||
  if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
 | 
			
		||||
      ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
 | 
			
		||||
 | 
			
		||||
@ -86,17 +86,6 @@ class AlgebraicSimplifierOptions {
 | 
			
		||||
  }
 | 
			
		||||
  bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; }
 | 
			
		||||
 | 
			
		||||
  // Move constant scalar multiply to one operand or output of convolutions with
 | 
			
		||||
  // the smallest tensor size, to reduce the number of scalar multiply.
 | 
			
		||||
  void set_enable_scalar_multiply_reduction(
 | 
			
		||||
      bool enable_scalar_multiply_reduction) {
 | 
			
		||||
    enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool enable_scalar_multiply_reduction() const {
 | 
			
		||||
    return enable_scalar_multiply_reduction_;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // If enable_window_reduce_replacement is true, the kReduceWindow instruction
 | 
			
		||||
  // can be optimized by replacement with simpler operations.
 | 
			
		||||
  void set_enable_window_reduce_to_reduce_replacement(
 | 
			
		||||
@ -157,7 +146,6 @@ class AlgebraicSimplifierOptions {
 | 
			
		||||
  bool enable_dot_to_multiply_rewrite_{true};
 | 
			
		||||
  bool enable_conv_simplification_{true};
 | 
			
		||||
  bool enable_conv_operand_swap_{true};
 | 
			
		||||
  bool enable_scalar_multiply_reduction_{false};
 | 
			
		||||
  bool enable_window_reduce_to_reduce_replacement_{true};
 | 
			
		||||
  bool enable_reduce_of_reshape_{true};
 | 
			
		||||
  bool replace_transpose_with_bitcast_{true};
 | 
			
		||||
 | 
			
		||||
@ -5343,59 +5343,6 @@ ENTRY AddBroadcastZeroWithDynamicSlice {
 | 
			
		||||
  EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) {
 | 
			
		||||
  const char* hlo_string = R"(
 | 
			
		||||
HloModule ConstScalarMultiply
 | 
			
		||||
ENTRY ConstScalarMultiply {
 | 
			
		||||
  param0 = f32[16,512,4096]{2,1,0} parameter(0)
 | 
			
		||||
  constant.0 = f32[] constant(0.5)
 | 
			
		||||
  broadcast.0 = f32[16,512,4096] broadcast(constant.0), dimensions={}
 | 
			
		||||
  multiply.0 = f32[16,512,4096]{2,1,0} multiply(param0, broadcast.0)
 | 
			
		||||
  param1 = f32[16,512,4096]{2,1,0} parameter(1)
 | 
			
		||||
  multiply.1 = f32[16,512,4096]{2,1,0} multiply(multiply.0, param1)
 | 
			
		||||
  param2 = f32[16,512,1024]{2,1,0} parameter(2)
 | 
			
		||||
  constant.1 = f32[] constant(1.109)
 | 
			
		||||
  broadcast.1 = f32[16,512,1024] broadcast(constant.1), dimensions={}
 | 
			
		||||
  multiply.2 = f32[16,512,1024]{2,1,0} multiply(param2, broadcast.1)
 | 
			
		||||
  ROOT convolution = f32[4096,1024,1]{1,0,2} convolution(multiply.1, multiply.2), window={size=16}, dim_labels=0fb_0io->bf0
 | 
			
		||||
}
 | 
			
		||||
)";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(auto module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(hlo_string));
 | 
			
		||||
  AlgebraicSimplifierOptions options;
 | 
			
		||||
  options.set_enable_scalar_multiply_reduction(true);
 | 
			
		||||
  AlgebraicSimplifier simplifier(options);
 | 
			
		||||
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 | 
			
		||||
  auto root = module->entry_computation()->root_instruction();
 | 
			
		||||
  EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
 | 
			
		||||
  EXPECT_THAT(root,
 | 
			
		||||
              GmockMatch(m::MultiplyAnyOrder(
 | 
			
		||||
                  m::Op(), m::Broadcast(m::ConstantScalar(0.5f * 1.109f)))));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReductionMultiUser) {
 | 
			
		||||
  const char* hlo_string = R"(
 | 
			
		||||
HloModule ConstScalarMultiply
 | 
			
		||||
ENTRY ConstScalarMultiply {
 | 
			
		||||
  param0 = f32[16,512,1024] parameter(0)
 | 
			
		||||
  param1 = f32[4096,1024,1] parameter(1)
 | 
			
		||||
  convolution = f32[16,512,4096] convolution(param0, param1), window={size=1}, dim_labels=0bf_oi0->0bf
 | 
			
		||||
  constant.1 = f32[] constant(0.5)
 | 
			
		||||
  broadcast.1 = f32[16,512,4096] broadcast(constant.1), dimensions={}
 | 
			
		||||
  multiply.1 = f32[16,512,4096] multiply(convolution, broadcast.1)
 | 
			
		||||
  param2 = f32[16,512,4096] parameter(2)
 | 
			
		||||
  multiply.2 = f32[16,512,4096] multiply(convolution, param2)
 | 
			
		||||
  ROOT add.1 = f32[16,512,4096] add(multiply.1, multiply.2)
 | 
			
		||||
}
 | 
			
		||||
)";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(auto module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(hlo_string));
 | 
			
		||||
  AlgebraicSimplifierOptions options;
 | 
			
		||||
  options.set_enable_scalar_multiply_reduction(true);
 | 
			
		||||
  AlgebraicSimplifier simplifier(options);
 | 
			
		||||
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
 | 
			
		||||
                         DotOfConcatSimplificationTest,
 | 
			
		||||
                         ::testing::ValuesIn(kDotOfConcatTestSpecs));
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user