Fix bug in cuddn_fused_conv_rewriter.
The scale for the convolution result and for the side input must be scalar constants. Before this change, any constant would be accepted.
This commit is contained in:
parent
d5544d3d7d
commit
1c14da0bc1
@ -44,7 +44,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
|
||||
using match::AddAnyOrder;
|
||||
using match::AnyOf;
|
||||
using match::Broadcast;
|
||||
using match::Constant;
|
||||
using match::ConstantScalar;
|
||||
using match::GetTupleElement;
|
||||
using match::Maximum;
|
||||
using match::MultiplyAnyOrder;
|
||||
@ -59,7 +59,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
|
||||
HloInstruction* relu_input;
|
||||
|
||||
// Match max(0, relu_input).
|
||||
auto zero_pattern = Broadcast(match::ConstantScalar(0));
|
||||
auto zero_pattern = Broadcast(ConstantScalar(0));
|
||||
if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
|
||||
!Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
|
||||
return absl::nullopt;
|
||||
@ -78,14 +78,14 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
|
||||
|
||||
const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
|
||||
const auto conv_pattern = [&] {
|
||||
auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
|
||||
auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
|
||||
auto conv_pattern = GetTupleElement(
|
||||
>e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
|
||||
return AnyOf<HloInstruction>(
|
||||
MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
|
||||
}();
|
||||
const auto side_input_pattern = [&] {
|
||||
auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
|
||||
auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
|
||||
// If bias is already matched, match arbitrary additional input as side
|
||||
// input. Note this may force a cheap operation (e.g. broadcast) to be
|
||||
// materialized into a large buffer, as large as the output buffer.
|
||||
|
Loading…
Reference in New Issue
Block a user