687 lines
28 KiB
C++
687 lines
28 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
|
|
|
|
namespace xla {
|
|
|
|
namespace {
|
|
|
|
// ConvolutionVisitor traverses the HLO computation and rewrites Convolution
|
|
// operations with feature_group_count > 1 into convolutions with
|
|
// feature_group_count = 1.
|
|
class ConvolutionVisitor : public DfsHloVisitorWithDefault {
|
|
public:
|
|
// Default visitor action is to do nothing and return OK.
|
|
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
|
|
return Status::OK();
|
|
}
|
|
|
|
Status HandleConvolution(HloInstruction* convolution) override;
|
|
|
|
Status HandleBatchGroupCount(HloInstruction* convolution);
|
|
|
|
// Runs the visitor on a computation.
|
|
static bool Run(HloComputation* computation,
|
|
std::function<bool(HloInstruction*)> is_cost_viable,
|
|
bool convert_batch_groups_only, bool filter_expansion);
|
|
|
|
// Returns whether any convolution ops were rewritten.
|
|
const bool changed() const { return changed_; }
|
|
|
|
~ConvolutionVisitor() override = default;
|
|
|
|
private:
|
|
explicit ConvolutionVisitor(
|
|
HloComputation* computation,
|
|
std::function<bool(HloInstruction*)> is_cost_viable,
|
|
bool convert_batch_groups_only, bool filter_expansion)
|
|
: computation_(computation),
|
|
filter_expansion_(filter_expansion),
|
|
convert_batch_groups_only_(convert_batch_groups_only),
|
|
is_cost_viable_(is_cost_viable) {}
|
|
|
|
// Current HloComputation instance the ConvolutionVisitor is traversing.
|
|
HloComputation* computation_;
|
|
|
|
// Whether rewrite has occurred.
|
|
bool changed_ = false;
|
|
|
|
// Whether filter expansion is required.
|
|
bool filter_expansion_;
|
|
|
|
// Decides whether to convert batch groups or feature groups.
|
|
bool convert_batch_groups_only_;
|
|
|
|
// std::function<std::vector<LloValue*>(int64, int64)> chunk_fetcher
|
|
std::function<bool(HloInstruction*)> is_cost_viable_;
|
|
};
|
|
|
|
// Builds a visitor for `computation`, runs it over the computation's
// instructions and reports whether any convolution was rewritten. A traversal
// error is fatal (TF_CHECK_OK).
bool ConvolutionVisitor::Run(
    HloComputation* computation,
    std::function<bool(HloInstruction*)> is_cost_viable,
    bool convert_batch_groups_only, bool filter_expansion) {
  // `is_cost_viable` is already a by-value copy; hand it over rather than
  // copying it a second time.
  ConvolutionVisitor visitor(computation, std::move(is_cost_viable),
                             convert_batch_groups_only, filter_expansion);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed();
}
|
|
|
|
// Returns a copy of `shape` in which dimension `input_feature_dim` is
// multiplied by `group_count`; every other dimension is left as is. `shape`
// must have at least two dimensions (CHECK-enforced).
Shape ExpandedFilterShape(const Shape& shape, int64 group_count,
                          int64 input_feature_dim) {
  CHECK_GE(shape.dimensions_size(), 2);
  const int64 widened = group_count * shape.dimensions(input_feature_dim);
  Shape result(shape);
  result.set_dimensions(input_feature_dim, widened);
  return result;
}
|
|
|
|
// Returns a vector with 'group_count' many groups, where the i-th group
|
|
// consists of 'group_size' times the value i.
|
|
std::vector<int32> GetMaskIds(int64 group_size, int64 group_count) {
|
|
std::vector<int32> values;
|
|
for (int i = 0; i < group_count; ++i) {
|
|
for (int j = 0; j < group_size; ++j) {
|
|
values.push_back(i);
|
|
}
|
|
}
|
|
return values;
|
|
}
|
|
|
|
// Create a mask for grouped convolution that will make a normal convolution
|
|
// produce the same results as a grouped convolution. For a [2, 1, 6]
|
|
// filter this returns a [2, 3, 6] mask
|
|
// 1 1 0 0 0 0
|
|
// 0 0 1 1 0 0
|
|
// 0 0 0 0 1 1
|
|
//
|
|
// 1 1 0 0 0 0
|
|
// 0 0 1 1 0 0
|
|
// 0 0 0 0 1 1
|
|
//
|
|
// The first step is to create a rank 1 constant:
|
|
// 0 1 2
|
|
//
|
|
// This is broadcasted to
|
|
// 0 0 0 0 0 0
|
|
// 1 1 1 1 1 1
|
|
// 2 2 2 2 2 2
|
|
//
|
|
// 0 0 0 0 0 0
|
|
// 1 1 1 1 1 1
|
|
// 2 2 2 2 2 2
|
|
//
|
|
// Then we create another rank 1 constant
|
|
// 0 0 1 1 2 2
|
|
//
|
|
// This is broadcasted to
|
|
// 0 0 1 1 2 2
|
|
// 0 0 1 1 2 2
|
|
// 0 0 1 1 2 2
|
|
//
|
|
// 0 0 1 1 2 2
|
|
// 0 0 1 1 2 2
|
|
// 0 0 1 1 2 2
|
|
//
|
|
// Finally we use the Eq op of these two broadcasted constants and get the
|
|
// desired mask.
|
|
HloInstruction* GetExpandedFilterMask(
|
|
const Shape& filter_shape, int64 kernel_input_feature_dim,
|
|
int64 kernel_output_feature_dim, int64 group_count,
|
|
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
|
|
add_instruction) {
|
|
Shape expanded_filter_shape =
|
|
ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim);
|
|
Shape mask_shape = ShapeUtil::MakeShape(
|
|
S32, AsInt64Slice(expanded_filter_shape.dimensions()));
|
|
int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim);
|
|
int64 group_size = filter_shape.dimensions(kernel_input_feature_dim);
|
|
|
|
// Create a 'input_feature' sized linspace and 'output_feature' sized linspace
|
|
// that will be broadcasted into perpendicular dimensions and compared.
|
|
const std::vector<int32> input_feature_filter_mask =
|
|
GetMaskIds(group_size, group_count);
|
|
const std::vector<int32> output_feature_filter_mask =
|
|
GetMaskIds(output_feature / group_count, group_count);
|
|
auto mask1 = add_instruction(HloInstruction::CreateConstant(
|
|
LiteralUtil::CreateR1<int32>(input_feature_filter_mask)));
|
|
auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast(
|
|
mask_shape, mask1, {kernel_input_feature_dim}));
|
|
auto mask2 = add_instruction(HloInstruction::CreateConstant(
|
|
LiteralUtil::CreateR1<int32>(output_feature_filter_mask)));
|
|
auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast(
|
|
mask_shape, mask2, {kernel_output_feature_dim}));
|
|
|
|
// Compare the broadcasted output feature linspace to the input feature
|
|
// linspace to create a diagonal predicate.
|
|
Shape predicate_shape = ShapeUtil::MakeShape(
|
|
PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
|
|
return add_instruction(HloInstruction::CreateCompare(
|
|
predicate_shape, broadcasted_mask1, broadcasted_mask2,
|
|
ComparisonDirection::kEq));
|
|
}
|
|
|
|
// This function handles batch_group_counts which are relevant only for
|
|
// depthwise backprop filter convolutions.
|
|
Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
|
|
auto dim_numbers = convolution->convolution_dimension_numbers();
|
|
auto activation = convolution->mutable_operand(0);
|
|
auto filter = convolution->mutable_operand(1);
|
|
int64 batch_group_count = convolution->batch_group_count();
|
|
|
|
if (batch_group_count == 1) {
|
|
return Status::OK();
|
|
}
|
|
|
|
VLOG(2) << "Dealing with batch_group_count " << batch_group_count
|
|
<< " for convolution " << convolution->ToString() << "\n";
|
|
|
|
auto add = [&](std::unique_ptr<HloInstruction> inst) {
|
|
return computation_->AddInstruction(std::move(inst));
|
|
};
|
|
|
|
int64 input_batch_dimension = dim_numbers.input_batch_dimension();
|
|
const int64 input_feature_dimension = dim_numbers.input_feature_dimension();
|
|
|
|
int64 output_batch_dimension = dim_numbers.output_batch_dimension();
|
|
int64 output_feature_dimension = dim_numbers.output_feature_dimension();
|
|
|
|
const int64 kernel_input_feature_dimension =
|
|
dim_numbers.kernel_input_feature_dimension();
|
|
const int64 kernel_output_feature_dimension =
|
|
dim_numbers.kernel_output_feature_dimension();
|
|
|
|
const int64 input_batch =
|
|
activation->shape().dimensions(input_batch_dimension);
|
|
const int64 output_feature =
|
|
filter->shape().dimensions(kernel_output_feature_dimension);
|
|
|
|
if (output_feature != batch_group_count || input_batch != batch_group_count) {
|
|
// Insert a spatial dimension to the activation before the input batch
|
|
// dimension to represent the batch group.
|
|
std::vector<int64> input_sizes(activation->shape().dimensions().begin(),
|
|
activation->shape().dimensions().end());
|
|
input_sizes[input_batch_dimension] /= batch_group_count;
|
|
input_sizes.insert(input_sizes.begin() + input_batch_dimension,
|
|
batch_group_count);
|
|
activation = MakeReshapeHlo(input_sizes, activation).ValueOrDie();
|
|
for (auto& d : *dim_numbers.mutable_input_spatial_dimensions()) {
|
|
if (d > input_batch_dimension) {
|
|
++d;
|
|
}
|
|
}
|
|
dim_numbers.add_input_spatial_dimensions(input_batch_dimension);
|
|
dim_numbers.set_input_batch_dimension(input_batch_dimension + 1);
|
|
if (input_feature_dimension > input_batch_dimension) {
|
|
dim_numbers.set_input_feature_dimension(input_feature_dimension + 1);
|
|
}
|
|
|
|
// Insert a spatial dimension to the kernel before the output feature
|
|
// dimension to represent the batch group.
|
|
std::vector<int64> kernel_sizes(filter->shape().dimensions().begin(),
|
|
filter->shape().dimensions().end());
|
|
kernel_sizes[kernel_output_feature_dimension] /= batch_group_count;
|
|
kernel_sizes.insert(kernel_sizes.begin() + kernel_output_feature_dimension,
|
|
batch_group_count);
|
|
filter = MakeReshapeHlo(kernel_sizes, filter).ValueOrDie();
|
|
for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
|
|
if (d > kernel_output_feature_dimension) {
|
|
++d;
|
|
}
|
|
}
|
|
dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension);
|
|
dim_numbers.set_kernel_output_feature_dimension(
|
|
kernel_output_feature_dimension + 1);
|
|
if (kernel_input_feature_dimension > kernel_output_feature_dimension) {
|
|
dim_numbers.set_kernel_input_feature_dimension(
|
|
kernel_input_feature_dimension + 1);
|
|
}
|
|
|
|
// Insert a spatial dimension to the output before the output feature
|
|
// dimension to represent the batch group.
|
|
for (auto& d : *dim_numbers.mutable_output_spatial_dimensions()) {
|
|
if (d > output_feature_dimension) {
|
|
++d;
|
|
}
|
|
}
|
|
dim_numbers.add_output_spatial_dimensions(output_feature_dimension);
|
|
dim_numbers.set_output_feature_dimension(output_feature_dimension + 1);
|
|
if (output_batch_dimension > output_feature_dimension) {
|
|
dim_numbers.set_output_batch_dimension(output_batch_dimension + 1);
|
|
}
|
|
|
|
// To represent a batch group count of 3 you can slide a 3 wide window
|
|
// [X Y Z]
|
|
// across [A 0 0 B 0 0 C] with stride 2 to produce
|
|
// [AX+0Y+0Z 0X+BY+0Z 0X+0Y+CZ] -> [AX BY CZ] which will behave the same as
|
|
// a batch group count.
|
|
Window window = convolution->window();
|
|
auto window_dim = window.add_dimensions();
|
|
window_dim->set_base_dilation(batch_group_count);
|
|
window_dim->set_size(batch_group_count);
|
|
window_dim->set_stride(batch_group_count - 1);
|
|
window_dim->set_padding_low(0);
|
|
window_dim->set_padding_high(0);
|
|
window_dim->set_window_reversal(false);
|
|
window_dim->set_window_dilation(1);
|
|
HloInstruction* new_convolution =
|
|
MakeConvolveHlo(
|
|
activation, filter, convolution->feature_group_count(),
|
|
/*batch_group_count=*/1, window, dim_numbers,
|
|
convolution->precision_config(),
|
|
/*preferred_element_type=*/convolution->shape().element_type())
|
|
.ValueOrDie();
|
|
convolution->SetupDerivedInstruction(new_convolution);
|
|
TF_CHECK_OK(computation_->ReplaceInstruction(
|
|
convolution,
|
|
MakeReshapeHlo(convolution->shape(), new_convolution).ValueOrDie()));
|
|
changed_ = true;
|
|
return Status::OK();
|
|
}
|
|
|
|
VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
|
|
const bool cost_too_high = !is_cost_viable_(convolution);
|
|
if (cost_too_high || filter_expansion_) {
|
|
// We first obtain the expanded the filter (which is the convolution
|
|
// output). The batch dimension is the expanded one (which originally
|
|
// represents kernel input feature dimension). We mask the filter to zero
|
|
// out the expanded regions. Next we reduce the filter in the batch
|
|
// dimension to obtain the original filter size.
|
|
|
|
HloInstruction* filter_mask =
|
|
GetExpandedFilterMask(convolution->shape(), output_batch_dimension,
|
|
output_feature_dimension, batch_group_count, add);
|
|
auto expanded_filter_shape = ExpandedFilterShape(
|
|
convolution->shape(), batch_group_count, output_batch_dimension);
|
|
|
|
VLOG(2) << "output_batch_dimension " << output_batch_dimension;
|
|
VLOG(2) << "New output shape of convolution "
|
|
<< expanded_filter_shape.ToString();
|
|
|
|
auto new_convolution = add(HloInstruction::CreateConvolve(
|
|
expanded_filter_shape, activation, filter,
|
|
/*feature_group_count=*/1, /*batch_group_count=*/1,
|
|
convolution->window(), dim_numbers, convolution->precision_config()));
|
|
|
|
VLOG(2) << "Expanded convolution " << new_convolution->ToString();
|
|
|
|
auto zero = add(HloInstruction::CreateConstant(
|
|
LiteralUtil::Zero(expanded_filter_shape.element_type())));
|
|
auto zero_filter =
|
|
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
|
|
|
|
auto new_filter = add(HloInstruction::CreateTernary(
|
|
expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution,
|
|
zero_filter));
|
|
|
|
PrimitiveType reduce_type = new_filter->shape().element_type();
|
|
auto reduce_window_shape = new_convolution->shape();
|
|
reduce_window_shape.set_dimensions(output_batch_dimension, 1);
|
|
|
|
// Ensure that data input to reduce window uses at least 32 bits.
|
|
if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) {
|
|
reduce_type = F32;
|
|
reduce_window_shape.set_element_type(F32);
|
|
Shape convert_shape = new_filter->shape();
|
|
convert_shape.set_element_type(F32);
|
|
new_filter =
|
|
add(HloInstruction::CreateConvert(convert_shape, new_filter));
|
|
}
|
|
|
|
auto zero_literal = LiteralUtil::Zero(reduce_type);
|
|
auto zero_scalar =
|
|
add(HloInstruction::CreateConstant(std::move(zero_literal)));
|
|
|
|
auto reduce_function = [&]() -> HloComputation* {
|
|
HloComputation::Builder b("add_computation");
|
|
Shape shape = ShapeUtil::MakeShape(reduce_type, {});
|
|
auto lhs =
|
|
b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
|
|
auto rhs =
|
|
b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
|
|
auto scalar_op = b.AddInstruction(
|
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
|
|
return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
|
|
};
|
|
|
|
// Create the reduce window.
|
|
Window window;
|
|
for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) {
|
|
auto* dim = window.add_dimensions();
|
|
dim->set_padding_low(0);
|
|
dim->set_padding_high(0);
|
|
dim->set_window_dilation(1);
|
|
dim->set_base_dilation(1);
|
|
if (i == output_batch_dimension) {
|
|
dim->set_stride(batch_group_count);
|
|
dim->set_size(batch_group_count);
|
|
} else {
|
|
dim->set_stride(1);
|
|
dim->set_size(1);
|
|
}
|
|
}
|
|
auto reduce_window = add(HloInstruction::CreateReduceWindow(
|
|
reduce_window_shape, new_filter, zero_scalar, window,
|
|
reduce_function()));
|
|
|
|
Shape convert_back_shape = reduce_window->shape();
|
|
convert_back_shape.set_element_type(activation->shape().element_type());
|
|
|
|
// Convert reduced data back to the original data type.
|
|
auto reduce_window_converted =
|
|
HloInstruction::CreateConvert(convert_back_shape, reduce_window);
|
|
|
|
TF_CHECK_OK(computation_->ReplaceWithNewInstruction(
|
|
convolution, std::move(reduce_window_converted)));
|
|
changed_ = true;
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
// Rewrites a convolution with feature_group_count > 1. When the visitor was
// configured with convert_batch_groups_only, it defers to
// HandleBatchGroupCount instead.
//
// Paths, in order:
//  1. One input feature per group (group_size == 1):
//     a. Depthwise separable (group_count equals the kernel output feature
//        size) and filter expansion not forced: left unchanged.
//     b. is_cost_viable_ returns false, or filter expansion is forced: the
//        filter is broadcast along its input feature dimension, masked to the
//        per-group blocks and used in an ungrouped convolution.
//     c. Otherwise: the kernel output feature dimension is split into
//        [group_count, depthwise_multiplier]. The multiplier is emulated with
//        an extra reversed-window spatial dimension, keeping
//        feature_group_count = group_count.
//  2. Otherwise: the feature group is moved into a new spatial dimension of
//     the activation, kernel and output. An ungrouped convolution with a
//     base-dilated, strided window over that dimension is emitted.
Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
  if (convert_batch_groups_only_) {
    return HandleBatchGroupCount(convolution);
  }

  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation_->AddInstruction(std::move(inst));
  };

  int64 group_count = convolution->feature_group_count();
  if (group_count == 1) {
    return Status::OK();
  }

  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();
  auto filter = convolution->mutable_operand(1);
  int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension();
  int64 group_size = filter->shape().dimensions(kernel_input_feature_dim);
  int64 kernel_output_feature_dim =
      dim_numbers.kernel_output_feature_dimension();

  // If the code generator handles depthwise separable convolutions inherently,
  // then no filter expansion is needed. This exit happens before changed_ is
  // written and before any instruction is added. changed_ may already be true
  // from an earlier convolution in this computation and must not be cleared.
  const bool depthwise_separable =
      group_size == 1 &&
      group_count == filter->shape().dimensions(kernel_output_feature_dim);
  if (depthwise_separable && !filter_expansion_) {
    return Status::OK();
  }

  // Every path below rewrites the convolution.
  changed_ = true;

  if (group_size == 1) {
    // Evaluate the predicate once; it is used both for logging and the
    // decision.
    const bool cost_viable = is_cost_viable_(convolution);
    VLOG(2) << "is_cost_viable_ " << cost_viable;
    // We want to repeat 'filter' in the 'input_feature_dim' dimension
    // 'group_count' times, then zero out everything outside each group's
    // block with the mask.
    if (!cost_viable || filter_expansion_) {
      auto expanded_filter_shape = ExpandedFilterShape(
          filter->shape(), group_count, kernel_input_feature_dim);
      // The mask is only needed on this path, so it is built here to avoid
      // adding dead instructions on the others.
      HloInstruction* filter_mask =
          GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim,
                                kernel_output_feature_dim, group_count, add);
      Shape reshaped_filter_shape =
          ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
      auto reshaped_filter =
          add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
      std::vector<int64> broadcast_dims;
      for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) {
        if (i == kernel_input_feature_dim) {
          continue;
        }
        broadcast_dims.push_back(i);
      }
      HloInstruction* expanded_filter = add(HloInstruction::CreateBroadcast(
          expanded_filter_shape, reshaped_filter, broadcast_dims));

      auto zero = add(HloInstruction::CreateConstant(
          LiteralUtil::Zero(expanded_filter_shape.element_type())));
      auto zero_filter =
          add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
      auto new_filter = add(HloInstruction::CreateTernary(
          expanded_filter_shape, HloOpcode::kSelect, filter_mask,
          expanded_filter, zero_filter));

      auto new_convolution = HloInstruction::CreateConvolve(
          convolution->shape(), convolution->mutable_operand(0), new_filter,
          /*feature_group_count=*/1, /*batch_group_count=*/1,
          convolution->window(), dim_numbers, convolution->precision_config());
      return computation_->ReplaceWithNewInstruction(
          convolution, std::move(new_convolution));
    }
    // Add a spatial dimension to emulate a larger output feature dimension,
    // so the new convolution keeps feature_group_count = group_count with a
    // single output feature per group.
    std::vector<int64> new_filter_dimension;
    new_filter_dimension.reserve(filter->shape().rank() + 1);
    const int64 depthwise_multiplier =
        filter->shape().dimensions(kernel_output_feature_dim) / group_count;
    // Split the kernel output feature dimension into group count and
    // depthwise multiplier.
    for (int64 i = 0; i < filter->shape().rank(); ++i) {
      if (i == kernel_output_feature_dim) {
        new_filter_dimension.push_back(group_count);
        new_filter_dimension.push_back(depthwise_multiplier);
      } else {
        new_filter_dimension.push_back(filter->shape().dimensions(i));
      }
    }
    if (kernel_input_feature_dim > kernel_output_feature_dim) {
      dim_numbers.set_kernel_input_feature_dimension(kernel_input_feature_dim +
                                                     1);
    }
    for (auto& dim : *dim_numbers.mutable_kernel_spatial_dimensions()) {
      if (dim > kernel_output_feature_dim) {
        ++dim;
      }
    }
    dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dim + 1);
    HloInstruction* new_filter =
        computation_->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(filter->shape().element_type(),
                                 new_filter_dimension),
            filter));

    auto new_activation_shape = convolution->operand(0)->shape();
    dim_numbers.add_input_spatial_dimensions(new_activation_shape.rank());

    // Create an activation spatial dimension of size 1 with a reversed window
    // and high and low padding equal to depthwise_multiplier - 1. This
    // emulates a larger output feature dimension with an extra spatial
    // dimension.
    ShapeUtil::AppendMajorDimension(1, &new_activation_shape);
    HloInstruction* new_activation =
        computation_->AddInstruction(HloInstruction::CreateReshape(
            new_activation_shape, convolution->mutable_operand(0)));
    auto new_window = convolution->window();
    auto new_dim = new_window.add_dimensions();
    new_dim->set_size(depthwise_multiplier);
    new_dim->set_window_reversal(true);
    new_dim->set_padding_low(depthwise_multiplier - 1);
    new_dim->set_padding_high(depthwise_multiplier - 1);
    new_dim->set_stride(1);
    new_dim->set_window_dilation(1);
    new_dim->set_base_dilation(1);

    // Split the output feature dimension into an output feature of size
    // group_count and the depthwise multiplier as an output spatial dimension.
    std::vector<int64> new_output_dimension;
    new_output_dimension.reserve(convolution->shape().rank() + 1);
    for (int64 i = 0; i < convolution->shape().rank(); ++i) {
      if (i == dim_numbers.output_feature_dimension()) {
        new_output_dimension.push_back(group_count);
        new_output_dimension.push_back(depthwise_multiplier);
      } else {
        new_output_dimension.push_back(convolution->shape().dimensions(i));
      }
    }
    if (dim_numbers.output_batch_dimension() >
        dim_numbers.output_feature_dimension()) {
      dim_numbers.set_output_batch_dimension(
          dim_numbers.output_batch_dimension() + 1);
    }
    for (auto& dim : *dim_numbers.mutable_output_spatial_dimensions()) {
      if (dim > dim_numbers.output_feature_dimension()) {
        ++dim;
      }
    }
    dim_numbers.add_output_spatial_dimensions(
        dim_numbers.output_feature_dimension() + 1);
    auto new_convolution_output_shape = ShapeUtil::MakeShape(
        convolution->shape().element_type(), new_output_dimension);
    HloInstruction* new_convolution =
        computation_->AddInstruction(HloInstruction::CreateConvolve(
            new_convolution_output_shape, new_activation, new_filter,
            /*feature_group_count=*/group_count, /*batch_group_count=*/1,
            new_window, dim_numbers, convolution->precision_config()));
    return computation_->ReplaceWithNewInstruction(
        convolution,
        HloInstruction::CreateReshape(convolution->shape(), new_convolution));
  }

  // Implement general grouped convolution using an extra spatial dimension to
  // represent the feature group count.
  //
  // Insert a spatial dimension to the input before the input feature
  // dimension to represent the feature group.
  HloInstruction* activation = convolution->mutable_operand(0);
  std::vector<int64> input_sizes(activation->shape().dimensions().begin(),
                                 activation->shape().dimensions().end());
  const int64 input_feature_dimension = dim_numbers.input_feature_dimension();
  input_sizes[input_feature_dimension] /= group_count;
  input_sizes.insert(input_sizes.begin() + input_feature_dimension,
                     group_count);
  activation = MakeReshapeHlo(input_sizes, activation).ValueOrDie();
  for (auto& d : *dim_numbers.mutable_input_spatial_dimensions()) {
    if (d > input_feature_dimension) {
      ++d;
    }
  }
  dim_numbers.add_input_spatial_dimensions(input_feature_dimension);
  dim_numbers.set_input_feature_dimension(input_feature_dimension + 1);
  if (dim_numbers.input_batch_dimension() > input_feature_dimension) {
    dim_numbers.set_input_batch_dimension(dim_numbers.input_batch_dimension() +
                                          1);
  }

  // Insert a spatial dimension to the kernel before the output feature
  // dimension to represent the feature group.
  std::vector<int64> kernel_sizes(filter->shape().dimensions().begin(),
                                  filter->shape().dimensions().end());
  const int64 kernel_output_feature_dimension =
      dim_numbers.kernel_output_feature_dimension();
  kernel_sizes[kernel_output_feature_dimension] /= group_count;
  kernel_sizes.insert(kernel_sizes.begin() + kernel_output_feature_dimension,
                      group_count);
  filter = MakeReshapeHlo(kernel_sizes, filter).ValueOrDie();
  for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
    if (d > kernel_output_feature_dimension) {
      ++d;
    }
  }
  dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension);
  dim_numbers.set_kernel_output_feature_dimension(
      kernel_output_feature_dimension + 1);
  if (dim_numbers.kernel_input_feature_dimension() >
      kernel_output_feature_dimension) {
    dim_numbers.set_kernel_input_feature_dimension(
        dim_numbers.kernel_input_feature_dimension() + 1);
  }

  // Insert a spatial dimension to the output before the output feature
  // dimension to represent the feature group.
  const int64 output_feature_dimension = dim_numbers.output_feature_dimension();
  for (auto& d : *dim_numbers.mutable_output_spatial_dimensions()) {
    if (d > output_feature_dimension) {
      ++d;
    }
  }
  dim_numbers.add_output_spatial_dimensions(output_feature_dimension);
  dim_numbers.set_output_feature_dimension(output_feature_dimension + 1);
  if (dim_numbers.output_batch_dimension() > output_feature_dimension) {
    dim_numbers.set_output_batch_dimension(
        dim_numbers.output_batch_dimension() + 1);
  }

  // To represent a feature group count of 3 you can slide a 3 wide window
  // [X Y Z]
  // across [A 0 0 B 0 0 C] with stride 2 to produce
  // [AX+0Y+0Z 0X+BY+0Z 0X+0Y+CZ] -> [AX BY CZ] which will behave the same as
  // a feature group count.
  Window window = convolution->window();
  auto window_dim = window.add_dimensions();
  window_dim->set_base_dilation(group_count);
  window_dim->set_size(group_count);
  window_dim->set_stride(group_count - 1);
  window_dim->set_padding_low(0);
  window_dim->set_padding_high(0);
  window_dim->set_window_reversal(false);
  window_dim->set_window_dilation(1);
  HloInstruction* new_convolution =
      MakeConvolveHlo(
          activation, filter, /*feature_group_count=*/1,
          /*batch_group_count=*/1, window, dim_numbers,
          convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type())
          .ValueOrDie();
  convolution->SetupDerivedInstruction(new_convolution);
  return computation_->ReplaceInstruction(
      convolution,
      MakeReshapeHlo(convolution->shape(), new_convolution).ValueOrDie());
}
|
|
|
|
} // namespace
|
|
|
|
// Applies ConvolutionVisitor to each computation returned by
// MakeNonfusionComputations() and reports whether any of them was rewritten.
// The module is dumped at VLOG level 2 before and after the pass.
StatusOr<bool> ConvolutionGroupConverter::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString());

  bool any_rewritten = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    // Always visit the computation; only then fold in its result.
    const bool rewritten = ConvolutionVisitor::Run(
        computation, is_cost_viable_, convert_batch_groups_only_,
        filter_expansion_);
    any_rewritten = any_rewritten || rewritten;
  }

  XLA_VLOG_LINES(
      2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString());
  return any_rewritten;
}
|
|
|
|
} // namespace xla
|