From 50960383761e45183634327f72e48677d82138a7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Sep 2020 11:12:28 -0700 Subject: [PATCH] HLO creation utils API change to allow passing batch_group_count. PiperOrigin-RevId: 329955093 Change-Id: I69a3616842b7f2a9ed177ea2bd407351acbac552 --- tensorflow/compiler/xla/service/algebraic_simplifier.cc | 8 ++++---- .../compiler/xla/service/convolution_group_converter.cc | 6 ++++-- tensorflow/compiler/xla/service/hlo_creation_utils.cc | 9 +++++---- tensorflow/compiler/xla/service/hlo_creation_utils.h | 3 ++- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 3d49700c1b5..6bbde42bad9 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -5260,10 +5260,10 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( if (!reverse_dimensions.empty()) { TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions)); } - TF_ASSIGN_OR_RETURN( - HloInstruction * new_convolution, - MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, swapped_window, - swapped_dnums, precision_config)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution, + MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, + /*batch_group_count=*/1, swapped_window, + swapped_dnums, precision_config)); convolution->SetupDerivedInstruction(new_convolution); TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution)); diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index 323bf44dcd3..f5506b894fd 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -300,7 +300,8 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { window_dim->set_window_dilation(1); HloInstruction* new_convolution = MakeConvolveHlo(activation, filter, convolution->feature_group_count(), - window, dim_numbers, convolution->precision_config()) + /*batch_group_count=*/1, window, dim_numbers, + convolution->precision_config()) .ValueOrDie(); convolution->SetupDerivedInstruction(new_convolution); TF_CHECK_OK(computation_->ReplaceInstruction( @@ -649,7 +650,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { window_dim->set_window_reversal(false); window_dim->set_window_dilation(1); HloInstruction* new_convolution = - MakeConvolveHlo(activation, filter, 1, window, dim_numbers, + MakeConvolveHlo(activation, filter, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dim_numbers, convolution->precision_config()) .ValueOrDie(); convolution->SetupDerivedInstruction(new_convolution); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 4ba67888409..4aeeb6d27ac 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -92,16 +92,17 @@ StatusOr MakeSliceHlo(HloInstruction* operand, StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), feature_group_count, 1, - window, dimension_numbers)); + lhs->shape(), rhs->shape(), feature_group_count, + batch_group_count, window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, feature_group_count, 1, window, + convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window, dimension_numbers, precision_config)); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 2b17ae3d967..53eeeffb858 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -61,7 +61,8 @@ StatusOr MakeSliceHlo(HloInstruction* operand, // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing