From 1c14da0bc125b705c491bcd2a46a0dd00b74ab30 Mon Sep 17 00:00:00 2001 From: Bas Aarts Date: Tue, 25 Jun 2019 16:06:27 -0700 Subject: [PATCH] Fix bug in cudnn_fused_conv_rewriter. The scale for the convolution result and for the side input must be scalar constants. Before this change, any constant would be accepted. --- .../compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index cde65ad5745..dee257a5d97 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -44,7 +44,7 @@ absl::optional FindConvWithRelu(HloInstruction* instr) { using match::AddAnyOrder; using match::AnyOf; using match::Broadcast; - using match::Constant; + using match::ConstantScalar; using match::GetTupleElement; using match::Maximum; using match::MultiplyAnyOrder; @@ -59,7 +59,7 @@ absl::optional FindConvWithRelu(HloInstruction* instr) { HloInstruction* relu_input; // Match max(0, relu_input). 
- auto zero_pattern = Broadcast(match::ConstantScalar(0)); + auto zero_pattern = Broadcast(ConstantScalar(0)); if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { return absl::nullopt; @@ -78,14 +78,14 @@ absl::optional FindConvWithRelu(HloInstruction* instr) { const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); const auto conv_pattern = [&] { - auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); + auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr)); auto conv_pattern = GetTupleElement( &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); return AnyOf( MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); }(); const auto side_input_pattern = [&] { - auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); + auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr)); // If bias is already matched, match arbitrary additional input as side // input. Note this may force a cheap operation (e.g. broadcast) to be // materialized into a large buffer, as large as the output buffer.