Merge pull request #30155 from bas-aarts:xla_cudnn_fused_conv_bug

PiperOrigin-RevId: 256442922
This commit is contained in:
TensorFlower Gardener 2019-07-03 17:48:14 -07:00
commit bedd98033b

View File

@ -44,7 +44,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
using match::AddAnyOrder; using match::AddAnyOrder;
using match::AnyOf; using match::AnyOf;
using match::Broadcast; using match::Broadcast;
using match::Constant; using match::ConstantScalar;
using match::GetTupleElement; using match::GetTupleElement;
using match::Maximum; using match::Maximum;
using match::MultiplyAnyOrder; using match::MultiplyAnyOrder;
@ -59,7 +59,7 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
HloInstruction* relu_input; HloInstruction* relu_input;
// Match max(0, relu_input). // Match max(0, relu_input).
auto zero_pattern = Broadcast(match::ConstantScalar(0)); auto zero_pattern = Broadcast(ConstantScalar(0));
if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
!Match(instr, Maximum(Op(&relu_input), zero_pattern))) { !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
return absl::nullopt; return absl::nullopt;
@ -78,14 +78,14 @@ absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
const auto conv_pattern = [&] { const auto conv_pattern = [&] {
auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
auto conv_pattern = GetTupleElement( auto conv_pattern = GetTupleElement(
&gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
return AnyOf<HloInstruction>( return AnyOf<HloInstruction>(
MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
}(); }();
const auto side_input_pattern = [&] { const auto side_input_pattern = [&] {
auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
// If bias is already matched, match arbitrary additional input as side // If bias is already matched, match arbitrary additional input as side
// input. Note this may force a cheap operation (e.g. broadcast) to be // input. Note this may force a cheap operation (e.g. broadcast) to be
// materialized into a large buffer, as large as the output buffer. // materialized into a large buffer, as large as the output buffer.