HLO creation utils API change to allow passing batch_group_count.

PiperOrigin-RevId: 329955093
Change-Id: I69a3616842b7f2a9ed177ea2bd407351acbac552
This commit is contained in:
A. Unique TensorFlower 2020-09-03 11:12:28 -07:00 committed by TensorFlower Gardener
parent a13c9b3d62
commit 5096038376
4 changed files with 15 additions and 11 deletions

View File

@ -5260,10 +5260,10 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
if (!reverse_dimensions.empty()) {
TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
}
TF_ASSIGN_OR_RETURN(
HloInstruction * new_convolution,
MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, swapped_window,
swapped_dnums, precision_config));
TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution,
MakeConvolveHlo(kernel, input, /*feature_group_count=*/1,
/*batch_group_count=*/1, swapped_window,
swapped_dnums, precision_config));
convolution->SetupDerivedInstruction(new_convolution);
TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));

View File

@ -300,7 +300,8 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
window_dim->set_window_dilation(1);
HloInstruction* new_convolution =
MakeConvolveHlo(activation, filter, convolution->feature_group_count(),
window, dim_numbers, convolution->precision_config())
/*batch_group_count=*/1, window, dim_numbers,
convolution->precision_config())
.ValueOrDie();
convolution->SetupDerivedInstruction(new_convolution);
TF_CHECK_OK(computation_->ReplaceInstruction(
@ -649,7 +650,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
window_dim->set_window_reversal(false);
window_dim->set_window_dilation(1);
HloInstruction* new_convolution =
MakeConvolveHlo(activation, filter, 1, window, dim_numbers,
MakeConvolveHlo(activation, filter, /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dim_numbers,
convolution->precision_config())
.ValueOrDie();
convolution->SetupDerivedInstruction(new_convolution);

View File

@ -92,16 +92,17 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(Shape convolve_shape,
ShapeInference::InferConvolveShape(
lhs->shape(), rhs->shape(), feature_group_count, 1,
window, dimension_numbers));
lhs->shape(), rhs->shape(), feature_group_count,
batch_group_count, window, dimension_numbers));
return computation->AddInstruction(HloInstruction::CreateConvolve(
convolve_shape, lhs, rhs, feature_group_count, 1, window,
convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
dimension_numbers, precision_config));
}

View File

@ -61,7 +61,8 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config);
// Creates a transpose HLO instruction and adds it to the computation containing