[XLA:SPMD] 1st step to refactor convolution_handler.

PiperOrigin-RevId: 331267719
Change-Id: I33ecb8ed0c8596fd11b50daa33160990ffc7ee5e
A. Unique TensorFlower 2020-09-11 19:28:59 -07:00 committed by TensorFlower Gardener
parent f1f8573343
commit 65140f3cc3
8 changed files with 229 additions and 465 deletions

tensorflow/compiler/xla/service/dot_as_convolution_util.cc

@ -49,14 +49,11 @@ bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) {
return false;
}
/* static */ absl::optional<DotGeneralAsConvolutionDimsInfo>
ParseDotGeneralFromConvolution(const HloInstruction* conv) {
/* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo(
const HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
return absl::nullopt;
}
const auto& conv_dims = conv->convolution_dimension_numbers();
DotGeneralAsConvolutionDimsInfo dims;
DotConvolutionDimsInfo dims;
dims.lhs_non_contracting_dims.push_back(
{conv_dims.input_batch_dimension(), -1,
conv_dims.output_batch_dimension(), -1});
@ -98,10 +95,10 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) {
// padding N - 1, high padding N - 1 and window reversal.
dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
} else {
return absl::nullopt;
dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
}
} else {
return absl::nullopt;
dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
}
}
@ -110,8 +107,7 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) {
StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
const HloInstruction& conv,
const DotGeneralAsConvolutionDimsInfo& dot_dnums,
const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
const auto& conv_dnums = conv.convolution_dimension_numbers();
@ -153,10 +149,9 @@ CreateShardedConvForDotGeneralConvolution(
/*batch_group_count=*/1, window, conv_dnums, conv.precision_config());
}
DotGeneralAsConvolutionDimsInfo ParseDotGeneralFromDot(
const HloInstruction* dot) {
DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
const auto& dot_dim_numbs = dot->dot_dimension_numbers();
dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo dnums;
dot_as_convolution_util::DotConvolutionDimsInfo dnums;
for (int64 i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) {
dnums.batch_dims.emplace_back();
dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i);
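The key contract change in this file: ParseDotGeneralFromConvolution returned absl::nullopt for any convolution that was not a pure dot, while ParseConvolutionDimsInfo always succeeds and records the genuinely spatial dimensions in conv_spatial_dims. A minimal caller-side sketch of the new contract (variable names are illustrative):

    auto dims = dot_as_convolution_util::ParseConvolutionDimsInfo(conv);
    if (dims.conv_spatial_dims.empty()) {
      // Every dimension maps to a dot logical dimension: take the einsum path.
    } else {
      // Real spatial dimensions remain: take the convolution path.
    }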

tensorflow/compiler/xla/service/dot_as_convolution_util.h

@ -25,8 +25,9 @@ limitations under the License.
namespace xla {
namespace dot_as_convolution_util {
// Describes the dimensions of a convolution that can be interpreted as a dot.
struct DotGeneralAsConvolutionDimsInfo {
// Describes the dimensions of a convolution that can be interpreted as a dot
// or a normal convolution.
struct DotConvolutionDimsInfo {
// The dimension numbers for the operands and output corresponding to a
// logical dimension (e.g., batch, contracting, non-contracting). If an
// operand or the output doesn't have the logical dimension, it is set to
@ -43,23 +44,22 @@ struct DotGeneralAsConvolutionDimsInfo {
std::vector<DimNums> contracting_dims;
std::vector<DimNums> lhs_non_contracting_dims;
std::vector<DimNums> rhs_non_contracting_dims;
std::vector<DimNums> conv_spatial_dims;
};
// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo if it can
// be interpreted as a dot, or absl::nullopt otherwise.
absl::optional<DotGeneralAsConvolutionDimsInfo> ParseDotGeneralFromConvolution(
const HloInstruction* conv);
// Parses a convolution and returns a DotConvolutionDimsInfo. If it can
// be interpreted as a dot, conv_spatial_dims will be empty.
DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv);
// Creates sharded convolution instruction that can be interpreted as a dot.
// This is a utility for per-op partitioners.
// - 'conv' is the original convolution instruction.
// - 'dot_dnums' is the result of ParseDotGeneralFromConvolution() for 'conv'.
// - 'dot_dnums' is the result of ParseConvolutionDimsInfo() for 'conv'.
// - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result
// convolution instruction.
StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
const HloInstruction& conv,
const DotGeneralAsConvolutionDimsInfo& dot_dnums,
const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);
// Check if a spatial dim is a parallel batch dimension.
@ -68,10 +68,9 @@ CreateShardedConvForDotGeneralConvolution(
// dilation B.
bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size);
// Returns a DotGeneralAsConvolutionDimsInfo from a kDot instruction, where all
// Returns a DotConvolutionDimsInfo from a kDot instruction, where all
// the spatial_dim values are set to -1.
DotGeneralAsConvolutionDimsInfo ParseDotGeneralFromDot(
const HloInstruction* dot);
DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot);
} // namespace dot_as_convolution_util
} // namespace xla
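For orientation, ConvSpatialDimensionIsParallel looks for the window pattern that a parallel batch dimension leaves behind when a dot is written as a convolution. A hedged sketch of one accepted pattern, assuming a parallel batch size B (a window-reversal variant is also accepted per the comments above):

    WindowDimension wd;
    const int64 B = 8;  // hypothetical batch dimension size
    wd.set_size(B);
    wd.set_stride(B - 1);
    wd.set_padding_low(B - 1);
    wd.set_padding_high(B - 1);
    wd.set_base_dilation(B);
    // For an LHS dimension of size B, this dimension should be recognized:
    //   ConvSpatialDimensionIsParallel(wd, /*lhs_size=*/B) == true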

tensorflow/compiler/xla/service/sharding_propagation.cc

@ -476,7 +476,7 @@ bool SupportSpatialPartitioning(const HloInstruction* instruction,
bool InferDotShardingFromOperands(
HloInstruction* instruction,
const dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo& dnums,
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
bool may_combine_partial_sharding) {
auto from_operand = [&](int64 operand_index) {
auto operand = instruction->operand(operand_index);
@ -543,9 +543,10 @@ bool InferDotShardingFromOperands(
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
int64 aggressiveness,
bool may_combine_partial_sharding) {
if (auto dot_dims = dot_as_convolution_util::ParseDotGeneralFromConvolution(
instruction)) {
return InferDotShardingFromOperands(instruction, *dot_dims,
auto dot_dims =
dot_as_convolution_util::ParseConvolutionDimsInfo(instruction);
if (dot_dims.conv_spatial_dims.empty()) {
return InferDotShardingFromOperands(instruction, dot_dims,
may_combine_partial_sharding);
}
const auto& dnums = instruction->convolution_dimension_numbers();
@ -1031,7 +1032,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
HloSharding InferDotOperandSharding(
const HloInstruction* instruction,
const dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo& dnums,
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
int64 operand_index, bool may_combine_partial_sharding) {
auto operand = instruction->operand(operand_index);
auto other = instruction->operand(1 - operand_index);
@ -1185,10 +1186,10 @@ absl::optional<HloSharding> GetShardingFromUser(
return HloSharding::Tile(new_tile_assignment);
}
case HloOpcode::kConvolution: {
if (auto dot_dims =
dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) {
auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user);
if (dot_dims.conv_spatial_dims.empty()) {
int64 op_idx = user.operand_index(&instruction);
return InferDotOperandSharding(&user, *dot_dims, op_idx,
return InferDotOperandSharding(&user, dot_dims, op_idx,
may_combine_partial_sharding);
}
return absl::nullopt;

tensorflow/compiler/xla/service/spmd/BUILD

@ -23,6 +23,7 @@ cc_library(
"spmd_partitioner_util.cc",
],
hdrs = [
"convolution_handler.h",
"spmd_partitioner.h",
"spmd_partitioner_util.h",
],

tensorflow/compiler/xla/service/spmd/convolution_handler.cc

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
@ -32,15 +34,8 @@ limitations under the License.
namespace xla {
namespace spmd {
namespace {
// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, int64 num_partitions,
const SpmdPartitionerOptions& options, HloInstruction* partition_id,
HloModule* module, SpmdBuilder* b);
namespace {
// Partition convolution with batch group count.
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
@ -240,95 +235,6 @@ StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
.hlo();
}
// Partition convolution where only parallel dims are tiled.
StatusOr<HloInstruction*> PartitionConvolutionWithParallelDimension(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
const auto& dnums = original_hlo->convolution_dimension_numbers();
std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
dnums.input_batch_dimension();
rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
dnums.input_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
dnums.input_spatial_dimensions(i);
}
std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
}
auto aligned_rhs_sharding =
hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
auto aligned_lhs_sharding =
hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
// Handling cases where all the partitioned dimensions are parallel
// dimensions.
int64 lhs_parallel_dim_partitions = 1;
int64 rhs_parallel_dim_partitions = 1;
std::vector<int64> parallel_spatial_dims;
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dim = dnums.input_spatial_dimensions(i);
int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
const auto& wd = conv_window.dimensions(i);
int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
parallel_spatial_dims.emplace_back(i);
lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
}
}
bool lhs_partition_dims_are_parallel =
(lhs_parallel_dim_partitions == num_partitions);
bool rhs_partition_dims_are_parallel =
(rhs_parallel_dim_partitions == num_partitions);
// If there is a parallel dim and all the partitioned dimensions are parallel
// dimensions in either LHS or RHS, simply create partitioned convolutions.
if (parallel_spatial_dims.empty() || ((!lhs_partition_dims_are_parallel) &&
(!rhs_partition_dims_are_parallel))) {
return nullptr;
}
// Reshard LHS or RHS to partition at parallel dimensions as the other
// operand.
if (lhs_partition_dims_are_parallel) {
rhs = rhs.Reshard(aligned_rhs_sharding);
} else {
lhs = lhs.Reshard(aligned_lhs_sharding);
}
// Get LHS and RHS sharded shape.
auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
// Update convolution window.
auto new_window = conv_window;
for (const auto& spatial_dim : parallel_spatial_dims) {
auto wd = new_window.mutable_dimensions(spatial_dim);
wd->set_size(lhs_shard_shape.dimensions(
dnums.input_spatial_dimensions(spatial_dim)));
wd->set_stride(std::max<int64>(1, wd->size() - 1));
wd->set_base_dilation(wd->size());
}
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(),
original_hlo->batch_group_count(), new_window, dnums));
auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
sharded_conv_shape, lhs.hlo(), rhs.hlo(),
original_hlo->feature_group_count(), original_hlo->batch_group_count(),
new_window, dnums, original_hlo->precision_config()));
sharded_conv->set_sharding(original_hlo->sharding());
return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
.Reshard(output_sharding)
.hlo();
}
// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on RHS only.
StatusOr<HloInstruction*>
@ -412,7 +318,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
auto wd = conv_window.dimensions(i);
const auto& wd = conv_window.dimensions(i);
if (wd.base_dilation() != 1 || wd.window_reversal()) {
return nullptr;
}
@ -458,7 +364,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
// Calculate the left and right halo sizes as described in the comments
// above. It calculates the halo sizes with dilation, so we apply
// CeilOfRatio({left,right}_halo_size, window_dilation).
auto wd = conv_window.dimensions(i);
const auto& wd = conv_window.dimensions(i);
int64 padding_low = wd.padding_low();
int64 padding_high = wd.padding_high();
int64 base = lhs.base_shape().dimensions(lhs_dimension);
@ -628,7 +534,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
}
Window window = conv_window;
const Window& window = conv_window;
std::vector<int64> reversed_rhs_dims;
for (int64 i = 0; i < window.dimensions_size(); ++i) {
if (window.dimensions(i).window_reversal()) {
@ -703,7 +609,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
auto wd = window.dimensions(i);
const auto& wd = window.dimensions(i);
if (wd.base_dilation() != 1) {
// TODO(wangtao): support parallel dim if it is replicate here.
return nullptr;
@ -738,7 +644,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
// Calculate the left and right halo sizes as described in the comments
// above.
auto wd = window.dimensions(i);
const auto& wd = window.dimensions(i);
int64 padding_low = wd.padding_low();
int64 padding_high = wd.padding_high();
int64 base = lhs.base_shape().dimensions(lhs_dimension);
@ -890,116 +796,6 @@ StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
shard_shape.dimensions()));
}
StatusOr<HloInstruction*> PartitionConvolutionGroupOnParallelDim(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, const ConvolutionDimsMapping& dims_mapping,
int64 num_partitions, const SpmdPartitionerOptions& options,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
std::vector<int64> lhs_dims;
std::vector<int64> rhs_dims;
std::vector<int64> output_dims;
auto lhs_sharding_dims_adjusted_to_output =
lhs.sharding().IsReplicated()
? std::vector<int64>(lhs.base_shape().rank(), 1)
: lhs.sharding().tile_assignment().dimensions();
auto rhs_sharding_dims_adjusted_to_output =
rhs.sharding().IsReplicated()
? std::vector<int64>(rhs.base_shape().rank(), 1)
: rhs.sharding().tile_assignment().dimensions();
auto output_sharding_dims_adjusted_to_lhs =
output_sharding.tile_assignment().dimensions();
bool lhs_rhs_dims_matching = true;
for (const auto& dim : dims_mapping.parallel_spatial_dims) {
lhs_dims.push_back(dim.lhs);
rhs_dims.push_back(dim.rhs);
output_dims.push_back(dim.output);
if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
lhs_rhs_dims_matching = false;
}
lhs_sharding_dims_adjusted_to_output[dim.lhs] =
output_sharding.tile_assignment().dim(dim.output);
rhs_sharding_dims_adjusted_to_output[dim.rhs] =
output_sharding.tile_assignment().dim(dim.output);
output_sharding_dims_adjusted_to_lhs[dim.output] =
lhs.sharding().tile_assignment().dim(dim.lhs);
}
auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims);
auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
if (lhs_rhs_dims_matching) {
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) >
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
} else {
lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
}
auto reshaped_output_tiling = output_sharding.tile_assignment();
reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
output_grouped = AlignGroupsWith(
GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling),
output_dims),
lhs_grouped);
} else {
auto reshaped_lhs_tiling = lhs.sharding().tile_assignment();
reshaped_lhs_tiling.Reshape(lhs_sharding_dims_adjusted_to_output);
lhs_grouped = AlignGroupsWith(
GroupShardingOnDims(HloSharding::Tile(reshaped_lhs_tiling), lhs_dims),
output_grouped);
lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
auto reshaped_rhs_tiling = rhs.sharding().tile_assignment();
reshaped_rhs_tiling.Reshape(rhs_sharding_dims_adjusted_to_output);
rhs_grouped = AlignGroupsWith(
GroupShardingOnDims(HloSharding::Tile(reshaped_rhs_tiling), rhs_dims),
output_grouped);
rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
}
// Update LHS and RHS sharding and shape.
lhs.hlo()->set_sharding(lhs_grouped.sharding);
rhs.hlo()->set_sharding(rhs_grouped.sharding);
CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding);
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
lhs.state(), lhs_grouped.device_groups, b);
auto grouped_lhs_base_shape =
GetPerGroupBaseShape(lhs_grouped, lhs.base_shape());
auto grouped_lhs_shard_shape =
MakePartitionedShape(grouped_lhs_base_shape, lhs.sharding());
// Update convolution window with the new shape
auto new_window = conv_window;
for (const auto& dim : dims_mapping.parallel_spatial_dims) {
auto wd = new_window.mutable_dimensions(dim.spatial);
wd->set_size(grouped_lhs_shard_shape.dimensions(dim.lhs));
wd->set_stride(std::max<int64>(1, wd->size() - 1));
wd->set_base_dilation(wd->size());
}
auto new_partition_id =
lhs.state().collective_ops_creator.create_partition_id(b);
TF_ASSIGN_OR_RETURN(
auto conv,
PartitionConvolution(
PartitionedHlo(lhs.hlo(), grouped_lhs_base_shape,
per_group_partitioner_state),
PartitionedHlo(rhs.hlo(),
GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
per_group_partitioner_state),
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, new_window, original_hlo,
num_partitions / output_grouped.device_groups.size(), options,
new_partition_id, module, b));
// Reset the LHS sharding to the ungrouped one.
lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped));
rhs.hlo()->set_sharding(UngroupSharding(rhs_grouped));
conv->set_sharding(UngroupSharding(output_grouped));
return PartitionedHlo(conv, output_base_shape, lhs.state())
.Reshard(output_sharding)
.hlo();
}
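The function removed above relied on the group-sharding helpers, and its role moves into the shared dot/conv recursion in dot_handler.cc. For reference, a minimal sketch of the grouping round trip it performed, assuming the GroupShardingOnDims/UngroupSharding helpers from spmd_partitioner_util.h:

    // Group the sharding along the parallel dims so each device group can be
    // partitioned independently.
    auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
    lhs.hlo()->set_sharding(lhs_grouped.sharding);
    // ... recursively partition within each group using a per-group
    // partitioner state ...
    lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped));  // restore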
// Partition convolution with only one kind of dims partitioned.
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
@ -1009,7 +805,7 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
// Case 0: Handle depthwise convolution with batch group count or
// Case 1: Handle depthwise convolution with batch group count or
// feature group count.
if (original_hlo->batch_group_count() > 1) {
TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
@ -1031,15 +827,6 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
}
}
// Case 1: Either RHS or LHS is only partitioned at parallel dimensions.
TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
PartitionConvolutionWithParallelDimension(
lhs, rhs, output_base_shape, output_sharding,
conv_window, original_hlo, num_partitions, b));
if (parallel_partitioned_conv) {
return parallel_partitioned_conv;
}
// Case 2: both RHS and LHS are tiled.
// Handling cases where both operands' shardings are aligned. We check that
// the LHS batch dimension is not partitioned because it is mapped to the
@ -1082,13 +869,15 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
return nullptr;
}
} // namespace
// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, int64 num_partitions,
const SpmdPartitionerOptions& options, HloInstruction* partition_id,
HloModule* module, SpmdBuilder* b) {
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, const SpmdPartitionerOptions& options,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
TF_ASSIGN_OR_RETURN(
@ -1100,133 +889,57 @@ StatusOr<HloInstruction*> PartitionConvolution(
return try_partitioned_conv;
}
const auto& dnums = original_hlo->convolution_dimension_numbers();
spmd::ConvolutionDimsMapping mapping;
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dim = dnums.input_spatial_dimensions(i);
int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
const auto& wd = original_hlo->window().dimensions(i);
int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
int64 output_dim = dnums.output_spatial_dimensions(i);
if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
mapping.parallel_spatial_dims.emplace_back();
mapping.parallel_spatial_dims.back().lhs = lhs_dim;
mapping.parallel_spatial_dims.back().rhs = rhs_dim;
mapping.parallel_spatial_dims.back().output = output_dim;
mapping.parallel_spatial_dims.back().spatial = i;
} else {
mapping.non_parallel_spatial_dims.emplace_back();
mapping.non_parallel_spatial_dims.back().lhs = lhs_dim;
mapping.non_parallel_spatial_dims.back().rhs = rhs_dim;
mapping.non_parallel_spatial_dims.back().output = output_dim;
mapping.non_parallel_spatial_dims.back().spatial = i;
}
}
// lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
auto get_partitions_for_dims =
[&](const HloSharding& sharding,
absl::Span<const ConvolutionDimsMapping::DimsMapping> dims,
int lhs_rhs_or_output) {
int64 partitions = 1;
if (sharding.IsTileMaximal()) {
return partitions;
}
for (const auto& dim : dims) {
if (lhs_rhs_or_output == 0) {
partitions *= sharding.tile_assignment().dim(dim.lhs);
} else if (lhs_rhs_or_output == 1) {
partitions *= sharding.tile_assignment().dim(dim.rhs);
} else {
CHECK_EQ(lhs_rhs_or_output, 2);
partitions *= sharding.tile_assignment().dim(dim.output);
}
}
return partitions;
};
const int64 lhs_parallel_spatial_partitions =
get_partitions_for_dims(lhs.sharding(), mapping.parallel_spatial_dims, 0);
const int64 rhs_parallel_spatial_partitions =
get_partitions_for_dims(rhs.sharding(), mapping.parallel_spatial_dims, 1);
const int64 output_parallel_spatial_partitions = get_partitions_for_dims(
original_hlo->sharding(), mapping.parallel_spatial_dims, 2);
// Recursively partition on different types of dimensions.
//
// Case 1: Group partitions by parallel spatial dims.
if (lhs_parallel_spatial_partitions == rhs_parallel_spatial_partitions &&
lhs_parallel_spatial_partitions == output_parallel_spatial_partitions &&
lhs_parallel_spatial_partitions > 1) {
TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
PartitionConvolutionGroupOnParallelDim(
lhs, rhs, output_base_shape, output_sharding,
conv_window, original_hlo, mapping, num_partitions,
options, partition_id, module, b));
if (try_partitioned_conv) {
return try_partitioned_conv;
}
}
return nullptr;
}
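// Worked example for the (removed) Case 1 trigger above: with
// num_partitions = 8 and a single parallel spatial dimension split 8 ways on
// LHS, RHS and output alike,
//   lhs_parallel_spatial_partitions == rhs_parallel_spatial_partitions ==
//       output_parallel_spatial_partitions == 8 > 1,
// so the convolution was handed to PartitionConvolutionGroupOnParallelDim and
// partitioned group by group along that dimension.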
} // namespace
Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo);
if (dot_dnums) {
// Use HandleDotHelper() for convs that are actually einsums.
spmd::DotGeneralDimsMapping mapping;
for (const auto& dims : dot_dnums->batch_dims) {
mapping.batch_dims.emplace_back();
mapping.batch_dims.back().lhs = dims.lhs;
mapping.batch_dims.back().rhs = dims.rhs;
mapping.batch_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->contracting_dims) {
mapping.contracting_dims.emplace_back();
mapping.contracting_dims.back().lhs = dims.lhs;
mapping.contracting_dims.back().rhs = dims.rhs;
mapping.contracting_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->lhs_non_contracting_dims) {
mapping.lhs_non_contracting_dims.emplace_back();
mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.lhs_non_contracting_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->rhs_non_contracting_dims) {
mapping.rhs_non_contracting_dims.emplace_back();
mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.rhs_non_contracting_dims.back().output = dims.output;
}
auto create_sharded_conv =
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_conv,
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
*hlo, *dot_dnums, lhs_hlo, rhs_hlo));
return b->AddInstruction(std::move(sharded_conv));
};
return HandleDotHelper(hlo, mapping, create_sharded_conv);
auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo);
spmd::DotConvDimsMapping mapping;
for (const auto& dims : dims_info.batch_dims) {
mapping.batch_dims.emplace_back();
mapping.batch_dims.back().lhs = dims.lhs;
mapping.batch_dims.back().rhs = dims.rhs;
mapping.batch_dims.back().output = dims.output;
mapping.batch_dims.back().spatial = dims.spatial_dim;
}
auto lhs = GetPartitionedHlo(hlo->operand(0));
auto rhs = GetPartitionedHlo(hlo->operand(1));
TF_ASSIGN_OR_RETURN(
auto partitioned_conv,
PartitionConvolution(lhs, rhs, hlo->shape(), hlo->sharding(),
hlo->window(), hlo, num_partitions_, options_,
partition_id_, module_, &b_));
if (partitioned_conv) {
SetPartitionedHlo(hlo, [&] { return partitioned_conv; });
return Status::OK();
for (const auto& dims : dims_info.contracting_dims) {
mapping.contracting_dims.emplace_back();
mapping.contracting_dims.back().lhs = dims.lhs;
mapping.contracting_dims.back().rhs = dims.rhs;
mapping.contracting_dims.back().output = dims.output;
mapping.contracting_dims.back().spatial = dims.spatial_dim;
}
return DefaultAction(hlo);
for (const auto& dims : dims_info.lhs_non_contracting_dims) {
mapping.lhs_non_contracting_dims.emplace_back();
mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.lhs_non_contracting_dims.back().output = dims.output;
mapping.lhs_non_contracting_dims.back().spatial = dims.spatial_dim;
}
for (const auto& dims : dims_info.rhs_non_contracting_dims) {
mapping.rhs_non_contracting_dims.emplace_back();
mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.rhs_non_contracting_dims.back().output = dims.output;
mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim;
}
for (const auto& dims : dims_info.conv_spatial_dims) {
mapping.conv_spatial_dims.emplace_back();
mapping.conv_spatial_dims.back().lhs = dims.lhs;
mapping.conv_spatial_dims.back().rhs = dims.rhs;
mapping.conv_spatial_dims.back().output = dims.output;
mapping.conv_spatial_dims.back().spatial = dims.spatial_dim;
}
auto create_sharded_conv =
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_conv,
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
*hlo, dims_info, lhs_hlo, rhs_hlo));
return b->AddInstruction(std::move(sharded_conv));
};
return HandleDotHelper(hlo, mapping, create_sharded_conv);
}
} // namespace spmd
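To make the new mapping concrete: for a canonical NHWC convolution with an HWIO kernel, ParseConvolutionDimsInfo is expected to put the batch pair into lhs_non_contracting_dims, the feature pair into contracting_dims, the kernel output feature into rhs_non_contracting_dims, and H/W into conv_spatial_dims; only spatial entries carry a spatial index. A hedged sketch of the resulting DotConvDimsMapping (dimension numbers are illustrative):

    spmd::DotConvDimsMapping mapping;
    // Batch: LHS dim 0 <-> output dim 0; no RHS dimension.
    mapping.lhs_non_contracting_dims.push_back(
        {/*lhs=*/0, /*rhs=*/-1, /*output=*/0, /*spatial=*/-1});
    // Feature: LHS dim 3 contracts with kernel input feature dim 2.
    mapping.contracting_dims.push_back(
        {/*lhs=*/3, /*rhs=*/2, /*output=*/-1, /*spatial=*/-1});
    // Kernel output feature dim 3 maps to output feature dim 3.
    mapping.rhs_non_contracting_dims.push_back(
        {/*lhs=*/-1, /*rhs=*/3, /*output=*/3, /*spatial=*/-1});
    // H and W remain spatial; 'spatial' is the window dimension index.
    mapping.conv_spatial_dims.push_back(
        {/*lhs=*/1, /*rhs=*/0, /*output=*/1, /*spatial=*/0});
    mapping.conv_spatial_dims.push_back(
        {/*lhs=*/2, /*rhs=*/1, /*output=*/2, /*spatial=*/1});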

tensorflow/compiler/xla/service/spmd/convolution_handler.h

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
namespace xla {
namespace spmd {
// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, const SpmdPartitionerOptions& options,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);
} // namespace spmd
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_
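A hedged sketch of a call site for this entry point, mirroring the invocation in dot_handler.cc below; a nullptr result means no supported partitioning strategy matched, so the caller should fall back:

    TF_ASSIGN_OR_RETURN(
        HloInstruction* partitioned_conv,
        PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
                             dims_mapping, conv_window, original_hlo,
                             num_partitions, options, partition_id, module, b));
    if (partitioned_conv != nullptr) {
      // Use the partitioned instruction.
    }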

tensorflow/compiler/xla/service/spmd/dot_handler.cc

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -36,7 +37,7 @@ namespace xla {
namespace spmd {
Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
DotGeneralDimsMapping mapping;
DotConvDimsMapping mapping;
const auto& dnums = hlo->dot_dimension_numbers();
int64 next_output_dim = 0;
for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
@ -88,8 +89,8 @@ namespace {
StatusOr<HloInstruction*> PartitionBaseCase(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
@ -98,7 +99,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
int64 output_lhs_non_contracting_partitions,
int64 output_rhs_non_contracting_partitions,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops,
bool may_reshard_without_detecting_match) {
@ -116,7 +117,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
std::vector<int64> output_to_lhs_indices(output_base_shape.rank(), -1);
std::vector<int64> output_to_rhs_indices(output_base_shape.rank(), -1);
auto populate_indices_mapping =
[&](const DotGeneralDimsMapping::DimsMapping& mapping) {
[&](const DotConvDimsMapping::DimsMapping& mapping) {
if (mapping.lhs >= 0) {
lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
lhs_to_output_indices[mapping.lhs] = mapping.output;
@ -142,6 +143,9 @@ StatusOr<HloInstruction*> PartitionBaseCase(
for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
populate_indices_mapping(mapping);
}
for (const auto& mapping : dims_mapping.conv_spatial_dims) {
populate_indices_mapping(mapping);
}
auto lhs_sharding_transposed_to_match_rhs =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices);
@ -408,7 +412,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
if (output_lhs_non_contracting_partitions == num_partitions &&
output_sharding_transposed_to_match_lhs == lhs_sharding &&
ShapeSizeInBytes(rhs.base_shape()) >=
threshold_for_windowed_einsum_mib * 1024 * 1024) {
options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (rhs_contracting_partitions == num_partitions) {
return emit_windowed_dot_general(0, 1, true, false);
}
@ -422,7 +426,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
if (output_rhs_non_contracting_partitions == num_partitions &&
output_sharding_transposed_to_match_rhs == rhs_sharding &&
ShapeSizeInBytes(lhs.base_shape()) >=
threshold_for_windowed_einsum_mib * 1024 * 1024) {
options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (lhs_contracting_partitions == num_partitions) {
return emit_windowed_dot_general(1, 0, true, false);
}
@ -572,26 +576,26 @@ StatusOr<HloInstruction*> PartitionBaseCase(
StatusOr<HloInstruction*> PartitionDot(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops);
StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
int64 lhs_contracting_partitions, int64 rhs_contracting_partitions,
int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions, int64 lhs_contracting_partitions,
int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
int64 rhs_non_contracting_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
std::vector<std::pair<HloInstruction*, HloSharding>>
@ -804,8 +808,7 @@ StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, dims_mapping,
num_partitions / output_grouped.device_groups.size(),
create_sharded_dot, module, original_hlo,
threshold_for_windowed_einsum_mib, b,
create_sharded_dot, module, original_hlo, options, b,
windowed_dot_general_loops));
dot->set_sharding(UngroupSharding(output_grouped));
return PartitionedHlo(dot, output_base_shape, lhs.state())
@ -816,17 +819,17 @@ StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
int64 matching_contracting_partitions, int64 other_contracting_partitions,
absl::Span<const DotGeneralDimsMapping::DimsMapping>
absl::Span<const DotConvDimsMapping::DimsMapping>
partitioned_non_contractin_dims,
int64 other_non_contracting_partitions,
int64 output_other_non_contracting_partitions,
const Shape& output_base_shape, const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const DotConvDimsMapping& dims_mapping, int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
std::vector<std::pair<HloInstruction*, HloSharding>>
@ -949,25 +952,24 @@ StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, dims_mapping,
num_partitions / matching_grouped.device_groups.size(),
create_sharded_dot, module, original_hlo,
threshold_for_windowed_einsum_mib, b,
create_sharded_dot, module, original_hlo, options, b,
windowed_dot_general_loops));
return dot;
}
StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
PartitionedHlo lhs, PartitionedHlo rhs,
absl::Span<const DotGeneralDimsMapping::DimsMapping>
absl::Span<const DotConvDimsMapping::DimsMapping>
partitioned_contractin_dims,
int64 output_batch_partitions, int64 output_lhs_non_contracting_partitions,
int64 output_rhs_non_contracting_partitions, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
std::vector<std::pair<HloInstruction*, HloSharding>>
@ -1090,8 +1092,8 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
inner_state),
MakePartitionedShape(output_base_shape, outer_output_tmp_sharding),
inner_output_sharding, dims_mapping, num_partitions / group_count,
create_sharded_dot, module, original_hlo,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
create_sharded_dot, module, original_hlo, options, b,
windowed_dot_general_loops));
if (!dot) {
return nullptr;
}
@ -1119,19 +1121,19 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
// in-group dot.
StatusOr<HloInstruction*> PartitionDot(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
// lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
auto get_partitions_for_dims =
[&](const HloSharding& sharding,
absl::Span<const DotGeneralDimsMapping::DimsMapping> dims,
absl::Span<const DotConvDimsMapping::DimsMapping> dims,
int lhs_rhs_or_output) {
int64 partitions = 1;
if (sharding.IsTileMaximal()) {
@ -1169,6 +1171,39 @@ StatusOr<HloInstruction*> PartitionDot(
output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
// Before we find partial matches along the dimensions, invoke base case again
// without may_reshard_without_detecting_match.
// Try to partition the purely spatially partitioned convolution first.
if (!dims_mapping.conv_spatial_dims.empty()) {
const auto& conv_dnums = original_hlo->convolution_dimension_numbers();
auto window = original_hlo->window();
// TODO(wangtao): remove this hack by passing create_sharded_conv to
// PartitionConv.
// Update the convolution window when this is a recursive call grouped on
// batch_dims.
if (original_hlo->batch_group_count() == 1 &&
original_hlo->feature_group_count() == 1 &&
!ShapeUtil::Compatible(original_hlo->shape(), output_base_shape)) {
for (const auto& dim : dims_mapping.batch_dims) {
auto wd = window.mutable_dimensions(dim.spatial);
wd->set_size(lhs.hlo()->shape().dimensions(
conv_dnums.input_spatial_dimensions(dim.spatial)));
wd->set_stride(std::max<int64>(1, wd->size() - 1));
wd->set_base_dilation(wd->size());
}
}
TF_ASSIGN_OR_RETURN(
auto partitioned_conv,
PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
dims_mapping, window, original_hlo, num_partitions,
options, lhs.state().partition_id, module, b));
if (partitioned_conv) {
return partitioned_conv;
}
}
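// Sketch of the intent of the window rewrite above: in a recursive call that
// grouped on batch_dims, each shard only sees its slice of the parallel
// spatial dimension, so the window is re-encoded with size equal to the
// per-shard dimension, stride size - 1 and base dilation equal to size, the
// same pattern that ConvSpatialDimensionIsParallel recognizes as a parallel
// batch dimension.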
TF_ASSIGN_OR_RETURN(
auto try_partitioned_dot,
PartitionBaseCase(
@ -1178,8 +1213,8 @@ StatusOr<HloInstruction*> PartitionDot(
lhs_contracting_partitions, rhs_contracting_partitions,
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops,
output_rhs_non_contracting_partitions, options, b,
windowed_dot_general_loops,
/*may_reshard_without_detecting_match=*/false));
if (try_partitioned_dot) {
return try_partitioned_dot;
@ -1198,8 +1233,8 @@ StatusOr<HloInstruction*> PartitionDot(
num_partitions, lhs_contracting_partitions,
rhs_contracting_partitions, lhs_non_contracting_partitions,
rhs_non_contracting_partitions, create_sharded_dot, module,
original_hlo, require_matching_devices_to_group,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
original_hlo, require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1239,8 +1274,8 @@ StatusOr<HloInstruction*> PartitionDot(
: output_lhs_non_contracting_partitions,
output_base_shape, output_sharding, dims_mapping, num_partitions,
create_sharded_dot, module, original_hlo,
require_matching_devices_to_group,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1248,7 +1283,7 @@ StatusOr<HloInstruction*> PartitionDot(
if (lhs_non_contracting_partitions > 1 &&
output_lhs_non_contracting_partitions > 1) {
// If part of LHS non-contracting dims match output, try them.
std::vector<DotGeneralDimsMapping::DimsMapping> matching_dims;
std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
if (lhs_partitions > 1 &&
@ -1265,9 +1300,8 @@ StatusOr<HloInstruction*> PartitionDot(
rhs_non_contracting_partitions,
output_rhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
module, original_hlo, require_matching_devices_to_group,
threshold_for_windowed_einsum_mib, b,
windowed_dot_general_loops));
module, original_hlo, require_matching_devices_to_group, options,
b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1276,7 +1310,7 @@ StatusOr<HloInstruction*> PartitionDot(
if (rhs_non_contracting_partitions > 1 &&
output_rhs_non_contracting_partitions > 1) {
// If part of RHS non-contracting dims match output, try them.
std::vector<DotGeneralDimsMapping::DimsMapping> matching_dims;
std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
int64 rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs);
if (rhs_partitions > 1 &&
@ -1293,9 +1327,8 @@ StatusOr<HloInstruction*> PartitionDot(
lhs_non_contracting_partitions,
output_lhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
module, original_hlo, require_matching_devices_to_group,
threshold_for_windowed_einsum_mib, b,
windowed_dot_general_loops));
module, original_hlo, require_matching_devices_to_group, options,
b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1312,15 +1345,15 @@ StatusOr<HloInstruction*> PartitionDot(
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
module, original_hlo, require_matching_devices_to_group,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
module, original_hlo, require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
return dot;
}
}
if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) {
// If part of contracting dims match, try them.
std::vector<DotGeneralDimsMapping::DimsMapping> matching_dims;
std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
for (const auto& dim : dims_mapping.contracting_dims) {
int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
if (lhs_partitions > 1 &&
@ -1336,9 +1369,8 @@ StatusOr<HloInstruction*> PartitionDot(
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
module, original_hlo, require_matching_devices_to_group,
threshold_for_windowed_einsum_mib, b,
windowed_dot_general_loops));
module, original_hlo, require_matching_devices_to_group, options,
b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1359,8 +1391,7 @@ StatusOr<HloInstruction*> PartitionDot(
PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
output_base_shape, grouped_output.sharding, dims_mapping,
output_sharding.NumTiles(), create_sharded_dot, module,
original_hlo, threshold_for_windowed_einsum_mib, b,
windowed_dot_general_loops));
original_hlo, options, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1377,8 +1408,8 @@ StatusOr<HloInstruction*> PartitionDot(
lhs_contracting_partitions, rhs_contracting_partitions,
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops,
output_rhs_non_contracting_partitions, options, b,
windowed_dot_general_loops,
/*may_reshard_without_detecting_match=*/true));
if (dot) {
return dot;
@ -1388,12 +1419,12 @@ StatusOr<HloInstruction*> PartitionDot(
StatusOr<HloInstruction*> PartitionDot(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
// First try partitioning without resharding the groups, then try allow
@ -1403,8 +1434,7 @@ StatusOr<HloInstruction*> PartitionDot(
auto try_partition,
PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, module, original_hlo,
require_matching_devices_to_group,
threshold_for_windowed_einsum_mib, b,
require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (try_partition) {
return try_partition;
@ -1423,7 +1453,7 @@ StatusOr<HloInstruction*> PartitionDot(
} // namespace
Status SpmdPartitioningVisitor::HandleDotHelper(
HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
auto& lhs = GetPartitionedHlo(hlo->operand(0));
@ -1431,9 +1461,8 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
TF_ASSIGN_OR_RETURN(
auto partitioned_dot,
PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
num_partitions_, create_sharded_dot, module_, hlo,
options_.threshold_for_windowed_einsum_mib, &b_,
&windowed_dot_general_loops_));
num_partitions_, create_sharded_dot, module_, hlo, options_,
&b_, &windowed_dot_general_loops_));
SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
return Status::OK();
}
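As a condensed reading guide for the recursion above, PartitionDot now tries strategies in this order:

    // 1. If dims_mapping.conv_spatial_dims is non-empty, try
    //    PartitionConvolution (with the window fix-up for batch-grouped
    //    recursive calls).
    // 2. Try PartitionBaseCase without speculative resharding.
    // 3. Try grouping: PartitionDotGroupOnBatch, then (partially) matching
    //    non-contracting dims, then contracting dims, recursing per group.
    // 4. Retry PartitionBaseCase with
    //    may_reshard_without_detecting_match=true.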

tensorflow/compiler/xla/service/spmd/spmd_partitioner.h

@ -330,27 +330,11 @@ class PartitionedHlo {
PartitioningState state_;
};
struct DotGeneralDimsMapping {
struct DotConvDimsMapping {
// The dimension numbers for the operands and output corresponding to a
// logical dimension (e.g., batch, contracting, non-contracting). If an
// operand or the output doesn't have the logical dimension, it is set to
// -1.
struct DimsMapping {
int64 lhs;
int64 rhs;
int64 output;
};
std::vector<DimsMapping> batch_dims;
std::vector<DimsMapping> contracting_dims;
std::vector<DimsMapping> lhs_non_contracting_dims;
std::vector<DimsMapping> rhs_non_contracting_dims;
};
struct ConvolutionDimsMapping {
// The dimension numbers for the operands and output corresponding to a
// logical dimension (e.g., batch, parallel, non-parallel). If an
// operand or the output doesn't have the logical dimension, it is set to
// -1.
struct DimsMapping {
int64 lhs;
int64 rhs;
@ -358,8 +342,11 @@ struct ConvolutionDimsMapping {
// input mapped to index in input_spatial_dimensions().
int64 spatial;
};
std::vector<DimsMapping> parallel_spatial_dims;
std::vector<DimsMapping> non_parallel_spatial_dims;
std::vector<DimsMapping> batch_dims;
std::vector<DimsMapping> contracting_dims;
std::vector<DimsMapping> lhs_non_contracting_dims;
std::vector<DimsMapping> rhs_non_contracting_dims;
std::vector<DimsMapping> conv_spatial_dims;
};
class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
@ -404,7 +391,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
// Implementation of dot partitioning given DotConvDimsMapping.
Status HandleDotHelper(
HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);
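A minimal sketch of how HandleDot in dot_handler.cc fills this merged struct for a plain dot, where spatial is always -1 and conv_spatial_dims stays empty:

    DotConvDimsMapping mapping;
    const auto& dnums = hlo->dot_dimension_numbers();
    for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
      mapping.batch_dims.push_back({/*lhs=*/dnums.lhs_batch_dimensions(i),
                                    /*rhs=*/dnums.rhs_batch_dimensions(i),
                                    /*output=*/i, /*spatial=*/-1});
    }
    // Contracting and non-contracting dims are filled the same way, and
    // conv_spatial_dims is simply left empty for a dot.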