Fix bug in cuddn_fused_conv_rewriter.
The scale for the convolution result and for the side input must be scalar constants. Before this change, any constant would be accepted.
This commit is contained in:
		
							parent
							
								
									d5544d3d7d
								
							
						
					
					
						commit
						1c14da0bc1
					
				| @ -44,7 +44,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) { | ||||
|   using match::AddAnyOrder; | ||||
|   using match::AnyOf; | ||||
|   using match::Broadcast; | ||||
|   using match::Constant; | ||||
|   using match::ConstantScalar; | ||||
|   using match::GetTupleElement; | ||||
|   using match::Maximum; | ||||
|   using match::MultiplyAnyOrder; | ||||
| @ -59,7 +59,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) { | ||||
|   HloInstruction* relu_input; | ||||
| 
 | ||||
|   // Match max(0, relu_input).
 | ||||
|   auto zero_pattern = Broadcast(match::ConstantScalar(0)); | ||||
|   auto zero_pattern = Broadcast(ConstantScalar(0)); | ||||
|   if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && | ||||
|       !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { | ||||
|     return absl::nullopt; | ||||
| @ -78,14 +78,14 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) { | ||||
| 
 | ||||
|   const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); | ||||
|   const auto conv_pattern = [&] { | ||||
|     auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); | ||||
|     auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr)); | ||||
|     auto conv_pattern = GetTupleElement( | ||||
|         >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); | ||||
|     return AnyOf<HloInstruction>( | ||||
|         MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); | ||||
|   }(); | ||||
|   const auto side_input_pattern = [&] { | ||||
|     auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); | ||||
|     auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr)); | ||||
|     // If bias is already matched, match arbitrary additional input as side
 | ||||
|     // input. Note this may force a cheap operation (e.g. broadcast) to be
 | ||||
|     // materialized into a large buffer, as large as the output buffer.
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user