Merge pull request #30155 from bas-aarts:xla_cudnn_fused_conv_bug
PiperOrigin-RevId: 256442922
This commit is contained in:
commit
bedd98033b
@ -44,7 +44,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
|
|||||||
using match::AddAnyOrder;
|
using match::AddAnyOrder;
|
||||||
using match::AnyOf;
|
using match::AnyOf;
|
||||||
using match::Broadcast;
|
using match::Broadcast;
|
||||||
using match::Constant;
|
using match::ConstantScalar;
|
||||||
using match::GetTupleElement;
|
using match::GetTupleElement;
|
||||||
using match::Maximum;
|
using match::Maximum;
|
||||||
using match::MultiplyAnyOrder;
|
using match::MultiplyAnyOrder;
|
||||||
@ -59,7 +59,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
|
|||||||
HloInstruction* relu_input;
|
HloInstruction* relu_input;
|
||||||
|
|
||||||
// Match max(0, relu_input).
|
// Match max(0, relu_input).
|
||||||
auto zero_pattern = Broadcast(match::ConstantScalar(0));
|
auto zero_pattern = Broadcast(ConstantScalar(0));
|
||||||
if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
|
if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
|
||||||
!Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
|
!Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
@ -78,14 +78,14 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
|
|||||||
|
|
||||||
const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
|
const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
|
||||||
const auto conv_pattern = [&] {
|
const auto conv_pattern = [&] {
|
||||||
auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
|
auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
|
||||||
auto conv_pattern = GetTupleElement(
|
auto conv_pattern = GetTupleElement(
|
||||||
&gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
|
&gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
|
||||||
return AnyOf<HloInstruction>(
|
return AnyOf<HloInstruction>(
|
||||||
MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
|
MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
|
||||||
}();
|
}();
|
||||||
const auto side_input_pattern = [&] {
|
const auto side_input_pattern = [&] {
|
||||||
auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
|
auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
|
||||||
// If bias is already matched, match arbitrary additional input as side
|
// If bias is already matched, match arbitrary additional input as side
|
||||||
// input. Note this may force a cheap operation (e.g. broadcast) to be
|
// input. Note this may force a cheap operation (e.g. broadcast) to be
|
||||||
// materialized into a large buffer, as large as the output buffer.
|
// materialized into a large buffer, as large as the output buffer.
|
||||||
|
Loading…
Reference in New Issue
Block a user