Fix bug in cudnn_fused_conv_rewriter.
The scale for the convolution result and for the side input must be scalar constants. Before this change, any constant would be accepted.
This commit is contained in:
		
							parent
							
								
									d5544d3d7d
								
							
						
					
					
						commit
						1c14da0bc1
					
				@ -44,7 +44,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
 | 
			
		||||
  using match::AddAnyOrder;
 | 
			
		||||
  using match::AnyOf;
 | 
			
		||||
  using match::Broadcast;
 | 
			
		||||
  using match::Constant;
 | 
			
		||||
  using match::ConstantScalar;
 | 
			
		||||
  using match::GetTupleElement;
 | 
			
		||||
  using match::Maximum;
 | 
			
		||||
  using match::MultiplyAnyOrder;
 | 
			
		||||
@ -59,7 +59,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
 | 
			
		||||
  HloInstruction* relu_input;
 | 
			
		||||
 | 
			
		||||
  // Match max(0, relu_input).
 | 
			
		||||
  auto zero_pattern = Broadcast(match::ConstantScalar(0));
 | 
			
		||||
  auto zero_pattern = Broadcast(ConstantScalar(0));
 | 
			
		||||
  if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
 | 
			
		||||
      !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
 | 
			
		||||
    return absl::nullopt;
 | 
			
		||||
@ -78,14 +78,14 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
 | 
			
		||||
 | 
			
		||||
  const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
 | 
			
		||||
  const auto conv_pattern = [&] {
 | 
			
		||||
    auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
 | 
			
		||||
    auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
 | 
			
		||||
    auto conv_pattern = GetTupleElement(
 | 
			
		||||
        &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
 | 
			
		||||
    return AnyOf<HloInstruction>(
 | 
			
		||||
        MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
 | 
			
		||||
  }();
 | 
			
		||||
  const auto side_input_pattern = [&] {
 | 
			
		||||
    auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
 | 
			
		||||
    auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
 | 
			
		||||
    // If bias is already matched, match arbitrary additional input as side
 | 
			
		||||
    // input. Note this may force a cheap operation (e.g. broadcast) to be
 | 
			
		||||
    // materialized into a large buffer, as large as the output buffer.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user