HLO creation utils API change to allow passing batch_group_count.
PiperOrigin-RevId: 329955093 Change-Id: I69a3616842b7f2a9ed177ea2bd407351acbac552
This commit is contained in:
parent
a13c9b3d62
commit
5096038376
@ -5260,10 +5260,10 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
|
||||
if (!reverse_dimensions.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * new_convolution,
|
||||
MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, swapped_window,
|
||||
swapped_dnums, precision_config));
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution,
|
||||
MakeConvolveHlo(kernel, input, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, swapped_window,
|
||||
swapped_dnums, precision_config));
|
||||
|
||||
convolution->SetupDerivedInstruction(new_convolution);
|
||||
TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));
|
||||
|
||||
@ -300,7 +300,8 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
|
||||
window_dim->set_window_dilation(1);
|
||||
HloInstruction* new_convolution =
|
||||
MakeConvolveHlo(activation, filter, convolution->feature_group_count(),
|
||||
window, dim_numbers, convolution->precision_config())
|
||||
/*batch_group_count=*/1, window, dim_numbers,
|
||||
convolution->precision_config())
|
||||
.ValueOrDie();
|
||||
convolution->SetupDerivedInstruction(new_convolution);
|
||||
TF_CHECK_OK(computation_->ReplaceInstruction(
|
||||
@ -649,7 +650,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
||||
window_dim->set_window_reversal(false);
|
||||
window_dim->set_window_dilation(1);
|
||||
HloInstruction* new_convolution =
|
||||
MakeConvolveHlo(activation, filter, 1, window, dim_numbers,
|
||||
MakeConvolveHlo(activation, filter, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, window, dim_numbers,
|
||||
convolution->precision_config())
|
||||
.ValueOrDie();
|
||||
convolution->SetupDerivedInstruction(new_convolution);
|
||||
|
||||
@ -92,16 +92,17 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
|
||||
|
||||
StatusOr<HloInstruction*> MakeConvolveHlo(
|
||||
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
|
||||
const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 batch_group_count, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
const PrecisionConfig& precision_config) {
|
||||
HloComputation* computation = lhs->parent();
|
||||
CHECK_EQ(computation, rhs->parent());
|
||||
TF_ASSIGN_OR_RETURN(Shape convolve_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
lhs->shape(), rhs->shape(), feature_group_count, 1,
|
||||
window, dimension_numbers));
|
||||
lhs->shape(), rhs->shape(), feature_group_count,
|
||||
batch_group_count, window, dimension_numbers));
|
||||
return computation->AddInstruction(HloInstruction::CreateConvolve(
|
||||
convolve_shape, lhs, rhs, feature_group_count, 1, window,
|
||||
convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
|
||||
dimension_numbers, precision_config));
|
||||
}
|
||||
|
||||
|
||||
@ -61,7 +61,8 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
|
||||
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
|
||||
StatusOr<HloInstruction*> MakeConvolveHlo(
|
||||
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
|
||||
const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 batch_group_count, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
const PrecisionConfig& precision_config);
|
||||
|
||||
// Creates a transpose HLO instruction and adds it to the computation containing
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user