Fix bug in cudnn_fused_conv_rewriter.

The scale for the convolution result and for the side input must be scalar constants.
Before this change, any constant would be accepted.
This commit is contained in:
Bas Aarts 2019-06-25 16:06:27 -07:00
parent d5544d3d7d
commit 1c14da0bc1

View File

@ -44,7 +44,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
using match::AddAnyOrder;
using match::AnyOf;
using match::Broadcast;
using match::Constant;
using match::ConstantScalar;
using match::GetTupleElement;
using match::Maximum;
using match::MultiplyAnyOrder;
@ -59,7 +59,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
HloInstruction* relu_input;
// Match max(0, relu_input).
auto zero_pattern = Broadcast(match::ConstantScalar(0));
auto zero_pattern = Broadcast(ConstantScalar(0));
if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
!Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
return absl::nullopt;
@ -78,14 +78,14 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
const auto conv_pattern = [&] {
auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
auto conv_pattern = GetTupleElement(
&gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
return AnyOf<HloInstruction>(
MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
}();
const auto side_input_pattern = [&] {
auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
// If bias is already matched, match arbitrary additional input as side
// input. Note this may force a cheap operation (e.g. broadcast) to be
// materialized into a large buffer, as large as the output buffer.