[XLA:SPMD] 1st step to refactor convolution_handler.

PiperOrigin-RevId: 331267719
Change-Id: I33ecb8ed0c8596fd11b50daa33160990ffc7ee5e

commit 65140f3cc3 (parent f1f8573343)
tensorflow/compiler/xla/service/dot_as_convolution_util.cc
@@ -49,14 +49,11 @@ bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) {
  return false;
}

/* static */ absl::optional<DotGeneralAsConvolutionDimsInfo>
ParseDotGeneralFromConvolution(const HloInstruction* conv) {
/* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo(
    const HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
  if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
    return absl::nullopt;
  }
  const auto& conv_dims = conv->convolution_dimension_numbers();
  DotGeneralAsConvolutionDimsInfo dims;
  DotConvolutionDimsInfo dims;
  dims.lhs_non_contracting_dims.push_back(
      {conv_dims.input_batch_dimension(), -1,
       conv_dims.output_batch_dimension(), -1});
@@ -98,10 +95,10 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) {
      // padding N - 1, high padding N - 1 and window reversal.
      dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
    } else {
      return absl::nullopt;
      dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
    }
  } else {
    return absl::nullopt;
    dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
  }
}

@@ -110,8 +107,7 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) {

StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
    const HloInstruction& conv,
    const DotGeneralAsConvolutionDimsInfo& dot_dnums,
    const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
  const auto& conv_dnums = conv.convolution_dimension_numbers();
@@ -153,10 +149,9 @@ CreateShardedConvForDotGeneralConvolution(
      /*batch_group_count=*/1, window, conv_dnums, conv.precision_config());
}

DotGeneralAsConvolutionDimsInfo ParseDotGeneralFromDot(
    const HloInstruction* dot) {
DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
  const auto& dot_dim_numbs = dot->dot_dimension_numbers();
  dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo dnums;
  dot_as_convolution_util::DotConvolutionDimsInfo dnums;
  for (int64 i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) {
    dnums.batch_dims.emplace_back();
    dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i);

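A worked example of the batch-dimension entry pushed above, for illustration only. The {lhs, rhs, output, spatial_dim} field order follows the DimNums initializer in the code; the concrete dimension numbers assume a layout with input batch at dimension 0 and output batch at dimension 0, which is not taken from this commit.

  #include <cstdint>

  // Stand-in for DotConvolutionDimsInfo::DimNums with the same field meaning.
  struct DimNumsSketch {
    int64_t lhs;
    int64_t rhs;
    int64_t output;
    int64_t spatial_dim;
  };

  // The convolution's batch dimension becomes an LHS non-contracting dim of
  // the dot view: present on the LHS (dim 0) and the output (dim 0), absent
  // (-1) on the RHS and as a spatial dim.
  DimNumsSketch batch_as_lhs_non_contracting{/*lhs=*/0, /*rhs=*/-1,
                                             /*output=*/0, /*spatial_dim=*/-1};
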
tensorflow/compiler/xla/service/dot_as_convolution_util.h
@@ -25,8 +25,9 @@ limitations under the License.
namespace xla {
namespace dot_as_convolution_util {

// Describes the dimensions of a convolution that can be interpreted as a dot.
struct DotGeneralAsConvolutionDimsInfo {
// Describes the dimensions of a convolution that can be interpreted as a dot
// or a normal convolution.
struct DotConvolutionDimsInfo {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
@@ -43,23 +44,22 @@ struct DotGeneralAsConvolutionDimsInfo {
  std::vector<DimNums> contracting_dims;
  std::vector<DimNums> lhs_non_contracting_dims;
  std::vector<DimNums> rhs_non_contracting_dims;
  std::vector<DimNums> conv_spatial_dims;
};

// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo if it can
// be interpreted as a dot, or absl::nullopt otherwise.
absl::optional<DotGeneralAsConvolutionDimsInfo> ParseDotGeneralFromConvolution(
    const HloInstruction* conv);
// Parses a convolution and returns a DotConvolutionDimsInfo. If it can
// be interpreted as a dot, there is no conv_spatial_dims.
DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv);

// Creates sharded convolution instruction that can be interpreted as a dot.
// This is a utility for per-op partitioners.
// - 'conv' is the original convolution instruction.
// - 'dot_dnums' is the result of ParseDotGeneralFromConvolution() for 'conv'.
// - 'dot_dnums' is the result of ParseConvolutionDimsInfo() for 'conv'.
// - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result
//   convolution instruction.
StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(
    const HloInstruction& conv,
    const DotGeneralAsConvolutionDimsInfo& dot_dnums,
    const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);

// Check if a spatial dim is parallel batch dimension.
@@ -68,10 +68,9 @@ CreateShardedConvForDotGeneralConvolution(
// dilation B.
bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size);

// Returns a DotGeneralAsConvolutionDimsInfo from a kDot instruction, where all
// Returns a DotConvolutionDimsInfo from a kDot instruction, where all
// the spatial_dim values are set to -1.
DotGeneralAsConvolutionDimsInfo ParseDotGeneralFromDot(
    const HloInstruction* dot);
DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot);

}  // namespace dot_as_convolution_util
}  // namespace xla

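Since the parser no longer returns an optional, call sites change from checking for a present value to checking whether any convolutional spatial dimensions remain. A minimal sketch of the new pattern, assuming an XLA build environment; IsDotLikeConvolution is a hypothetical helper name, not part of this commit.

  #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"

  namespace xla {

  // Old: ParseDotGeneralFromConvolution(conv) returned
  // absl::optional<DotGeneralAsConvolutionDimsInfo>; a present value meant
  // "treat this convolution as a dot".
  // New: ParseConvolutionDimsInfo(conv) always succeeds; the convolution is
  // dot-like exactly when conv_spatial_dims is empty.
  bool IsDotLikeConvolution(const HloInstruction* conv) {
    auto dims = dot_as_convolution_util::ParseConvolutionDimsInfo(conv);
    return dims.conv_spatial_dims.empty();
  }

  }  // namespace xla
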
@ -476,7 +476,7 @@ bool SupportSpatialPartitioning(const HloInstruction* instruction,
|
||||
|
||||
bool InferDotShardingFromOperands(
|
||||
HloInstruction* instruction,
|
||||
const dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo& dnums,
|
||||
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
|
||||
bool may_combine_partial_sharding) {
|
||||
auto from_operand = [&](int64 operand_index) {
|
||||
auto operand = instruction->operand(operand_index);
|
||||
@ -543,9 +543,10 @@ bool InferDotShardingFromOperands(
|
||||
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
|
||||
int64 aggressiveness,
|
||||
bool may_combine_partial_sharding) {
|
||||
if (auto dot_dims = dot_as_convolution_util::ParseDotGeneralFromConvolution(
|
||||
instruction)) {
|
||||
return InferDotShardingFromOperands(instruction, *dot_dims,
|
||||
auto dot_dims =
|
||||
dot_as_convolution_util::ParseConvolutionDimsInfo(instruction);
|
||||
if (dot_dims.conv_spatial_dims.empty()) {
|
||||
return InferDotShardingFromOperands(instruction, dot_dims,
|
||||
may_combine_partial_sharding);
|
||||
}
|
||||
const auto& dnums = instruction->convolution_dimension_numbers();
|
||||
@ -1031,7 +1032,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
|
||||
|
||||
HloSharding InferDotOperandSharding(
|
||||
const HloInstruction* instruction,
|
||||
const dot_as_convolution_util::DotGeneralAsConvolutionDimsInfo& dnums,
|
||||
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
|
||||
int64 operand_index, bool may_combine_partial_sharding) {
|
||||
auto operand = instruction->operand(operand_index);
|
||||
auto other = instruction->operand(1 - operand_index);
|
||||
@ -1185,10 +1186,10 @@ absl::optional<HloSharding> GetShardingFromUser(
|
||||
return HloSharding::Tile(new_tile_assignment);
|
||||
}
|
||||
case HloOpcode::kConvolution: {
|
||||
if (auto dot_dims =
|
||||
dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) {
|
||||
auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user);
|
||||
if (dot_dims.conv_spatial_dims.empty()) {
|
||||
int64 op_idx = user.operand_index(&instruction);
|
||||
return InferDotOperandSharding(&user, *dot_dims, op_idx,
|
||||
return InferDotOperandSharding(&user, dot_dims, op_idx,
|
||||
may_combine_partial_sharding);
|
||||
}
|
||||
return absl::nullopt;
|
||||
|
tensorflow/compiler/xla/service/spmd/BUILD
@@ -23,6 +23,7 @@ cc_library(
        "spmd_partitioner_util.cc",
    ],
    hdrs = [
        "convolution_handler.h",
        "spmd_partitioner.h",
        "spmd_partitioner_util.h",
    ],

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
|
||||
@ -32,15 +34,8 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
namespace spmd {
|
||||
namespace {
|
||||
|
||||
// Partition convolution.
|
||||
StatusOr<HloInstruction*> PartitionConvolution(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, int64 num_partitions,
|
||||
const SpmdPartitionerOptions& options, HloInstruction* partition_id,
|
||||
HloModule* module, SpmdBuilder* b);
|
||||
namespace {
|
||||
|
||||
// Partition convolution with batch group count.
|
||||
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
|
||||
@ -240,95 +235,6 @@ StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
|
||||
.hlo();
|
||||
}
|
||||
|
||||
// Partition convolution where only parallel dims are tiled.
|
||||
StatusOr<HloInstruction*> PartitionConvolutionWithParallelDimension(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
|
||||
const auto& dnums = original_hlo->convolution_dimension_numbers();
|
||||
std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
|
||||
rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
|
||||
dnums.input_batch_dimension();
|
||||
rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
|
||||
dnums.input_feature_dimension();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
|
||||
dnums.input_spatial_dimensions(i);
|
||||
}
|
||||
std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
|
||||
for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
|
||||
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
|
||||
}
|
||||
auto aligned_rhs_sharding =
|
||||
hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
|
||||
auto aligned_lhs_sharding =
|
||||
hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
|
||||
|
||||
// Handling cases where all the partitioned dimensions are parallel
|
||||
// dimensions.
|
||||
int64 lhs_parallel_dim_partitions = 1;
|
||||
int64 rhs_parallel_dim_partitions = 1;
|
||||
std::vector<int64> parallel_spatial_dims;
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
int64 lhs_dim = dnums.input_spatial_dimensions(i);
|
||||
int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
|
||||
const auto& wd = conv_window.dimensions(i);
|
||||
int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
|
||||
if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
|
||||
parallel_spatial_dims.emplace_back(i);
|
||||
lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
|
||||
rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
|
||||
}
|
||||
}
|
||||
bool lhs_partition_dims_are_parallel =
|
||||
(lhs_parallel_dim_partitions == num_partitions);
|
||||
bool rhs_partition_dims_are_parallel =
|
||||
(rhs_parallel_dim_partitions == num_partitions);
|
||||
|
||||
// If there is a parallel dim and all the partitioned dimensions are parallel
|
||||
// dimensions in either LHS or RHS, simply create partitioned convolutions.
|
||||
if (parallel_spatial_dims.empty() || ((!lhs_partition_dims_are_parallel) &&
|
||||
(!rhs_partition_dims_are_parallel))) {
|
||||
return nullptr;
|
||||
}
|
||||
// Reshard LHS or RHS to partition at parallel dimensions as the other
|
||||
// operand.
|
||||
if (lhs_partition_dims_are_parallel) {
|
||||
rhs = rhs.Reshard(aligned_rhs_sharding);
|
||||
} else {
|
||||
lhs = lhs.Reshard(aligned_lhs_sharding);
|
||||
}
|
||||
|
||||
// Get LHS and RHS sharded shape.
|
||||
auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
|
||||
auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
|
||||
|
||||
// Update convolution window.
|
||||
auto new_window = conv_window;
|
||||
for (const auto& spatial_dim : parallel_spatial_dims) {
|
||||
auto wd = new_window.mutable_dimensions(spatial_dim);
|
||||
wd->set_size(lhs_shard_shape.dimensions(
|
||||
dnums.input_spatial_dimensions(spatial_dim)));
|
||||
wd->set_stride(std::max<int64>(1, wd->size() - 1));
|
||||
wd->set_base_dilation(wd->size());
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape sharded_conv_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(),
|
||||
original_hlo->batch_group_count(), new_window, dnums));
|
||||
auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
|
||||
sharded_conv_shape, lhs.hlo(), rhs.hlo(),
|
||||
original_hlo->feature_group_count(), original_hlo->batch_group_count(),
|
||||
new_window, dnums, original_hlo->precision_config()));
|
||||
sharded_conv->set_sharding(original_hlo->sharding());
|
||||
return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
|
||||
.Reshard(output_sharding)
|
||||
.hlo();
|
||||
}
|
||||
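For reference, the window that the update loop in the function above installs on each parallel spatial dimension has size N, stride max(1, N - 1), and base dilation N, where N is the per-shard size of the corresponding LHS dimension. A simplified, standalone model of the shape being matched; WindowDimSketch is a stand-in for xla::WindowDimension, and the real ConvSpatialDimensionIsParallel predicate also accepts a window-reversal variant not shown here.

  #include <cstdint>

  struct WindowDimSketch {
    int64_t size;
    int64_t stride;
    int64_t base_dilation;
  };

  // A spatial dim of logical size n behaves like a batch dim when the window
  // spans the whole dim (size == n), base dilation re-expands it (== n), and
  // the stride steps over the dilated gap (== n - 1).
  bool SpatialDimLooksParallel(const WindowDimSketch& wd, int64_t lhs_size) {
    return wd.size == lhs_size && wd.stride == lhs_size - 1 &&
           wd.base_dilation == lhs_size;
  }
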
|
||||
// Partition convolution when both LHS and RHS are partitioned at spatial
|
||||
// dimensions. Halo exchange will happen on RHS only.
|
||||
StatusOr<HloInstruction*>
|
||||
@ -412,7 +318,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
|
||||
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
|
||||
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
|
||||
int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
|
||||
auto wd = conv_window.dimensions(i);
|
||||
const auto& wd = conv_window.dimensions(i);
|
||||
if (wd.base_dilation() != 1 || wd.window_reversal()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -458,7 +364,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
|
||||
// Calculate the left and right halo sizes as described in the comments
|
||||
// above. It calculates the halo sizes with dilation, so we apply
|
||||
// CeilOfRatio({left,right}_halo_size, window_dilation).
|
||||
auto wd = conv_window.dimensions(i);
|
||||
const auto& wd = conv_window.dimensions(i);
|
||||
int64 padding_low = wd.padding_low();
|
||||
int64 padding_high = wd.padding_high();
|
||||
int64 base = lhs.base_shape().dimensions(lhs_dimension);
|
||||
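A short arithmetic note on the CeilOfRatio adjustment mentioned in the comment above, written out locally here as a sketch rather than using the library helper of the same name.

  #include <cstdint>

  // Halo sizes are computed in dilated units; dividing by window_dilation
  // and rounding up converts them back to element counts.
  int64_t CeilOfRatio(int64_t numerator, int64_t denominator) {
    return (numerator + denominator - 1) / denominator;
  }

  // Example: a dilated left halo of 5 with window_dilation == 2 needs
  // CeilOfRatio(5, 2) == 3 elements from the neighboring shard.
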
@ -628,7 +534,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
|
||||
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
|
||||
}
|
||||
|
||||
Window window = conv_window;
|
||||
const Window& window = conv_window;
|
||||
std::vector<int64> reversed_rhs_dims;
|
||||
for (int64 i = 0; i < window.dimensions_size(); ++i) {
|
||||
if (window.dimensions(i).window_reversal()) {
|
||||
@ -703,7 +609,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
|
||||
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
|
||||
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
|
||||
int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
|
||||
auto wd = window.dimensions(i);
|
||||
const auto& wd = window.dimensions(i);
|
||||
if (wd.base_dilation() != 1) {
|
||||
// TODO(wangtao): support parallel dim if it is replicate here.
|
||||
return nullptr;
|
||||
@ -738,7 +644,7 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
|
||||
|
||||
// Calculate the left and right halo sizes as described in the comments
|
||||
// above.
|
||||
auto wd = window.dimensions(i);
|
||||
const auto& wd = window.dimensions(i);
|
||||
int64 padding_low = wd.padding_low();
|
||||
int64 padding_high = wd.padding_high();
|
||||
int64 base = lhs.base_shape().dimensions(lhs_dimension);
|
||||
@ -890,116 +796,6 @@ StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
|
||||
shard_shape.dimensions()));
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> PartitionConvolutionGroupOnParallelDim(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, const ConvolutionDimsMapping& dims_mapping,
|
||||
int64 num_partitions, const SpmdPartitionerOptions& options,
|
||||
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
|
||||
std::vector<int64> lhs_dims;
|
||||
std::vector<int64> rhs_dims;
|
||||
std::vector<int64> output_dims;
|
||||
auto lhs_sharding_dims_adjusted_to_output =
|
||||
lhs.sharding().IsReplicated()
|
||||
? std::vector<int64>(lhs.base_shape().rank(), 1)
|
||||
: lhs.sharding().tile_assignment().dimensions();
|
||||
auto rhs_sharding_dims_adjusted_to_output =
|
||||
rhs.sharding().IsReplicated()
|
||||
? std::vector<int64>(rhs.base_shape().rank(), 1)
|
||||
: rhs.sharding().tile_assignment().dimensions();
|
||||
auto output_sharding_dims_adjusted_to_lhs =
|
||||
output_sharding.tile_assignment().dimensions();
|
||||
bool lhs_rhs_dims_matching = true;
|
||||
for (const auto& dim : dims_mapping.parallel_spatial_dims) {
|
||||
lhs_dims.push_back(dim.lhs);
|
||||
rhs_dims.push_back(dim.rhs);
|
||||
output_dims.push_back(dim.output);
|
||||
if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
|
||||
rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
|
||||
lhs_rhs_dims_matching = false;
|
||||
}
|
||||
lhs_sharding_dims_adjusted_to_output[dim.lhs] =
|
||||
output_sharding.tile_assignment().dim(dim.output);
|
||||
rhs_sharding_dims_adjusted_to_output[dim.rhs] =
|
||||
output_sharding.tile_assignment().dim(dim.output);
|
||||
output_sharding_dims_adjusted_to_lhs[dim.output] =
|
||||
lhs.sharding().tile_assignment().dim(dim.lhs);
|
||||
}
|
||||
auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
|
||||
auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims);
|
||||
auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
|
||||
if (lhs_rhs_dims_matching) {
|
||||
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) >
|
||||
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
|
||||
rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
|
||||
rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
|
||||
} else {
|
||||
lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
|
||||
lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
|
||||
}
|
||||
auto reshaped_output_tiling = output_sharding.tile_assignment();
|
||||
reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
|
||||
output_grouped = AlignGroupsWith(
|
||||
GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling),
|
||||
output_dims),
|
||||
lhs_grouped);
|
||||
} else {
|
||||
auto reshaped_lhs_tiling = lhs.sharding().tile_assignment();
|
||||
reshaped_lhs_tiling.Reshape(lhs_sharding_dims_adjusted_to_output);
|
||||
lhs_grouped = AlignGroupsWith(
|
||||
GroupShardingOnDims(HloSharding::Tile(reshaped_lhs_tiling), lhs_dims),
|
||||
output_grouped);
|
||||
lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
|
||||
auto reshaped_rhs_tiling = rhs.sharding().tile_assignment();
|
||||
reshaped_rhs_tiling.Reshape(rhs_sharding_dims_adjusted_to_output);
|
||||
rhs_grouped = AlignGroupsWith(
|
||||
GroupShardingOnDims(HloSharding::Tile(reshaped_rhs_tiling), rhs_dims),
|
||||
output_grouped);
|
||||
rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
|
||||
}
|
||||
|
||||
// Update LHS and RHS sharding and shape.
|
||||
lhs.hlo()->set_sharding(lhs_grouped.sharding);
|
||||
rhs.hlo()->set_sharding(rhs_grouped.sharding);
|
||||
CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding);
|
||||
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
|
||||
lhs.state(), lhs_grouped.device_groups, b);
|
||||
auto grouped_lhs_base_shape =
|
||||
GetPerGroupBaseShape(lhs_grouped, lhs.base_shape());
|
||||
auto grouped_lhs_shard_shape =
|
||||
MakePartitionedShape(grouped_lhs_base_shape, lhs.sharding());
|
||||
// Update convolution window with the new shape
|
||||
auto new_window = conv_window;
|
||||
for (const auto& dim : dims_mapping.parallel_spatial_dims) {
|
||||
auto wd = new_window.mutable_dimensions(dim.spatial);
|
||||
wd->set_size(grouped_lhs_shard_shape.dimensions(dim.lhs));
|
||||
wd->set_stride(std::max<int64>(1, wd->size() - 1));
|
||||
wd->set_base_dilation(wd->size());
|
||||
}
|
||||
|
||||
auto new_partition_id =
|
||||
lhs.state().collective_ops_creator.create_partition_id(b);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto conv,
|
||||
PartitionConvolution(
|
||||
PartitionedHlo(lhs.hlo(), grouped_lhs_base_shape,
|
||||
per_group_partitioner_state),
|
||||
PartitionedHlo(rhs.hlo(),
|
||||
GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
|
||||
per_group_partitioner_state),
|
||||
GetPerGroupBaseShape(output_grouped, output_base_shape),
|
||||
output_grouped.sharding, new_window, original_hlo,
|
||||
num_partitions / output_grouped.device_groups.size(), options,
|
||||
new_partition_id, module, b));
|
||||
// Reset the LHS sharding to the ungrouped one.
|
||||
lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped));
|
||||
rhs.hlo()->set_sharding(UngroupSharding(rhs_grouped));
|
||||
conv->set_sharding(UngroupSharding(output_grouped));
|
||||
return PartitionedHlo(conv, output_base_shape, lhs.state())
|
||||
.Reshard(output_sharding)
|
||||
.hlo();
|
||||
}
|
||||
|
||||
// Partition convolution with only one kind of dims partitioned.
|
||||
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
@ -1009,7 +805,7 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
HloModule* module, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
|
||||
// Case 0: Handle depthwise convolution with batch group count or
|
||||
// Case 1: Handle depthwise convolution with batch group count or
|
||||
// feature group count.
|
||||
if (original_hlo->batch_group_count() > 1) {
|
||||
TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
|
||||
@ -1031,15 +827,6 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
}
|
||||
}
|
||||
|
||||
// Case 1: Either RHS or LHS is only partitioned at parallel dimensions.
|
||||
TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
|
||||
PartitionConvolutionWithParallelDimension(
|
||||
lhs, rhs, output_base_shape, output_sharding,
|
||||
conv_window, original_hlo, num_partitions, b));
|
||||
if (parallel_partitioned_conv) {
|
||||
return parallel_partitioned_conv;
|
||||
}
|
||||
|
||||
// Case 2: both RHS and LHS are tiled.
|
||||
// Handling cases where both operands' shardings are aligned. We check that
|
||||
// the LHS batch dimension is not partitioned because it is mapped to the
|
||||
@ -1082,13 +869,15 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Partition convolution.
|
||||
StatusOr<HloInstruction*> PartitionConvolution(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, int64 num_partitions,
|
||||
const SpmdPartitionerOptions& options, HloInstruction* partition_id,
|
||||
HloModule* module, SpmdBuilder* b) {
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
const Window& conv_window, HloInstruction* original_hlo,
|
||||
int64 num_partitions, const SpmdPartitionerOptions& options,
|
||||
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -1100,133 +889,57 @@ StatusOr<HloInstruction*> PartitionConvolution(
|
||||
return try_partitioned_conv;
|
||||
}
|
||||
|
||||
const auto& dnums = original_hlo->convolution_dimension_numbers();
|
||||
spmd::ConvolutionDimsMapping mapping;
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
int64 lhs_dim = dnums.input_spatial_dimensions(i);
|
||||
int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
|
||||
const auto& wd = original_hlo->window().dimensions(i);
|
||||
int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
|
||||
int64 output_dim = dnums.output_spatial_dimensions(i);
|
||||
if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
|
||||
mapping.parallel_spatial_dims.emplace_back();
|
||||
mapping.parallel_spatial_dims.back().lhs = lhs_dim;
|
||||
mapping.parallel_spatial_dims.back().rhs = rhs_dim;
|
||||
mapping.parallel_spatial_dims.back().output = output_dim;
|
||||
mapping.parallel_spatial_dims.back().spatial = i;
|
||||
} else {
|
||||
mapping.non_parallel_spatial_dims.emplace_back();
|
||||
mapping.non_parallel_spatial_dims.back().lhs = lhs_dim;
|
||||
mapping.non_parallel_spatial_dims.back().rhs = rhs_dim;
|
||||
mapping.non_parallel_spatial_dims.back().output = output_dim;
|
||||
mapping.non_parallel_spatial_dims.back().spatial = i;
|
||||
}
|
||||
}
|
||||
|
||||
// lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
|
||||
auto get_partitions_for_dims =
|
||||
[&](const HloSharding& sharding,
|
||||
absl::Span<const ConvolutionDimsMapping::DimsMapping> dims,
|
||||
int lhs_rhs_or_output) {
|
||||
int64 partitions = 1;
|
||||
if (sharding.IsTileMaximal()) {
|
||||
return partitions;
|
||||
}
|
||||
for (const auto& dim : dims) {
|
||||
if (lhs_rhs_or_output == 0) {
|
||||
partitions *= sharding.tile_assignment().dim(dim.lhs);
|
||||
} else if (lhs_rhs_or_output == 1) {
|
||||
partitions *= sharding.tile_assignment().dim(dim.rhs);
|
||||
} else {
|
||||
CHECK_EQ(lhs_rhs_or_output, 2);
|
||||
partitions *= sharding.tile_assignment().dim(dim.output);
|
||||
}
|
||||
}
|
||||
return partitions;
|
||||
};
|
||||
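The get_partitions_for_dims lambda above reduces a group of mapped dimensions to a single partition count: the product of the tile-assignment sizes along the lhs, rhs, or output side of each mapping. A standalone sketch of the same computation with simplified stand-in types; the real code also returns 1 early for tile-maximal shardings.

  #include <cstdint>
  #include <vector>

  struct DimsMappingSketch {  // stand-in for the DimsMapping entries above
    int64_t lhs;
    int64_t rhs;
    int64_t output;
  };

  // tile_dims[i] is the number of tiles along dimension i of the chosen
  // operand or output; lhs_rhs_or_output selects which index to read.
  int64_t PartitionsForDims(const std::vector<int64_t>& tile_dims,
                            const std::vector<DimsMappingSketch>& dims,
                            int lhs_rhs_or_output) {  // 0 lhs, 1 rhs, 2 output
    int64_t partitions = 1;
    for (const auto& d : dims) {
      const int64_t index = lhs_rhs_or_output == 0   ? d.lhs
                            : lhs_rhs_or_output == 1 ? d.rhs
                                                     : d.output;
      partitions *= tile_dims[index];
    }
    return partitions;
  }
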
|
||||
const int64 lhs_parallel_spatial_partitions =
|
||||
get_partitions_for_dims(lhs.sharding(), mapping.parallel_spatial_dims, 0);
|
||||
const int64 rhs_parallel_spatial_partitions =
|
||||
get_partitions_for_dims(rhs.sharding(), mapping.parallel_spatial_dims, 1);
|
||||
const int64 output_parallel_spatial_partitions = get_partitions_for_dims(
|
||||
original_hlo->sharding(), mapping.parallel_spatial_dims, 2);
|
||||
|
||||
// Recursively partition on different types of dimensions.
|
||||
//
|
||||
// Case 1: Group partitions by parallel spatial dims.
|
||||
if (lhs_parallel_spatial_partitions == rhs_parallel_spatial_partitions &&
|
||||
lhs_parallel_spatial_partitions == output_parallel_spatial_partitions &&
|
||||
lhs_parallel_spatial_partitions > 1) {
|
||||
TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
|
||||
PartitionConvolutionGroupOnParallelDim(
|
||||
lhs, rhs, output_base_shape, output_sharding,
|
||||
conv_window, original_hlo, mapping, num_partitions,
|
||||
options, partition_id, module, b));
|
||||
if (try_partitioned_conv) {
|
||||
return try_partitioned_conv;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
|
||||
auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo);
|
||||
if (dot_dnums) {
|
||||
// Use HandleDotHelper() for convs that are actually einsums.
|
||||
spmd::DotGeneralDimsMapping mapping;
|
||||
for (const auto& dims : dot_dnums->batch_dims) {
|
||||
mapping.batch_dims.emplace_back();
|
||||
mapping.batch_dims.back().lhs = dims.lhs;
|
||||
mapping.batch_dims.back().rhs = dims.rhs;
|
||||
mapping.batch_dims.back().output = dims.output;
|
||||
}
|
||||
for (const auto& dims : dot_dnums->contracting_dims) {
|
||||
mapping.contracting_dims.emplace_back();
|
||||
mapping.contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.contracting_dims.back().output = dims.output;
|
||||
}
|
||||
for (const auto& dims : dot_dnums->lhs_non_contracting_dims) {
|
||||
mapping.lhs_non_contracting_dims.emplace_back();
|
||||
mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.lhs_non_contracting_dims.back().output = dims.output;
|
||||
}
|
||||
for (const auto& dims : dot_dnums->rhs_non_contracting_dims) {
|
||||
mapping.rhs_non_contracting_dims.emplace_back();
|
||||
mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.rhs_non_contracting_dims.back().output = dims.output;
|
||||
}
|
||||
auto create_sharded_conv =
|
||||
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
|
||||
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharded_conv,
|
||||
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
|
||||
*hlo, *dot_dnums, lhs_hlo, rhs_hlo));
|
||||
return b->AddInstruction(std::move(sharded_conv));
|
||||
};
|
||||
return HandleDotHelper(hlo, mapping, create_sharded_conv);
|
||||
auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo);
|
||||
spmd::DotConvDimsMapping mapping;
|
||||
for (const auto& dims : dims_info.batch_dims) {
|
||||
mapping.batch_dims.emplace_back();
|
||||
mapping.batch_dims.back().lhs = dims.lhs;
|
||||
mapping.batch_dims.back().rhs = dims.rhs;
|
||||
mapping.batch_dims.back().output = dims.output;
|
||||
mapping.batch_dims.back().spatial = dims.spatial_dim;
|
||||
}
|
||||
|
||||
auto lhs = GetPartitionedHlo(hlo->operand(0));
|
||||
auto rhs = GetPartitionedHlo(hlo->operand(1));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto partitioned_conv,
|
||||
PartitionConvolution(lhs, rhs, hlo->shape(), hlo->sharding(),
|
||||
hlo->window(), hlo, num_partitions_, options_,
|
||||
partition_id_, module_, &b_));
|
||||
|
||||
if (partitioned_conv) {
|
||||
SetPartitionedHlo(hlo, [&] { return partitioned_conv; });
|
||||
return Status::OK();
|
||||
for (const auto& dims : dims_info.contracting_dims) {
|
||||
mapping.contracting_dims.emplace_back();
|
||||
mapping.contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.contracting_dims.back().output = dims.output;
|
||||
mapping.contracting_dims.back().spatial = dims.spatial_dim;
|
||||
}
|
||||
return DefaultAction(hlo);
|
||||
for (const auto& dims : dims_info.lhs_non_contracting_dims) {
|
||||
mapping.lhs_non_contracting_dims.emplace_back();
|
||||
mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.lhs_non_contracting_dims.back().output = dims.output;
|
||||
mapping.lhs_non_contracting_dims.back().spatial = dims.spatial_dim;
|
||||
}
|
||||
for (const auto& dims : dims_info.rhs_non_contracting_dims) {
|
||||
mapping.rhs_non_contracting_dims.emplace_back();
|
||||
mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.rhs_non_contracting_dims.back().output = dims.output;
|
||||
mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim;
|
||||
}
|
||||
for (const auto& dims : dims_info.conv_spatial_dims) {
|
||||
mapping.conv_spatial_dims.emplace_back();
|
||||
mapping.conv_spatial_dims.back().lhs = dims.lhs;
|
||||
mapping.conv_spatial_dims.back().rhs = dims.rhs;
|
||||
mapping.conv_spatial_dims.back().output = dims.output;
|
||||
mapping.conv_spatial_dims.back().spatial = dims.spatial_dim;
|
||||
}
|
||||
auto create_sharded_conv =
|
||||
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
|
||||
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharded_conv,
|
||||
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
|
||||
*hlo, dims_info, lhs_hlo, rhs_hlo));
|
||||
return b->AddInstruction(std::move(sharded_conv));
|
||||
};
|
||||
return HandleDotHelper(hlo, mapping, create_sharded_conv);
|
||||
}
|
||||
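The mapping population above is a mechanical field-by-field copy from DotConvolutionDimsInfo entries to DotConvDimsMapping entries. A sketch of a helper that could express it more compactly; the structs are simplified stand-ins with the same field names and are not part of this change.

  #include <cstdint>
  #include <vector>

  struct DimsInfoSketch { int64_t lhs, rhs, output, spatial_dim; };
  struct DimsMappingSketch { int64_t lhs, rhs, output, spatial; };

  // Appends one DimsMapping per DimsInfo entry, copying the shared fields
  // and renaming spatial_dim to spatial.
  void AppendDims(const std::vector<DimsInfoSketch>& src,
                  std::vector<DimsMappingSketch>* dst) {
    for (const auto& dims : src) {
      dst->push_back({dims.lhs, dims.rhs, dims.output, dims.spatial_dim});
    }
  }
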
|
||||
} // namespace spmd
|
||||
|
tensorflow/compiler/xla/service/spmd/convolution_handler.h (new file)
@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_

#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

namespace xla {
namespace spmd {

// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
    const Window& conv_window, HloInstruction* original_hlo,
    int64 num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);

}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_

@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
|
||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
|
||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
@ -36,7 +37,7 @@ namespace xla {
|
||||
namespace spmd {
|
||||
|
||||
Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
|
||||
DotGeneralDimsMapping mapping;
|
||||
DotConvDimsMapping mapping;
|
||||
const auto& dnums = hlo->dot_dimension_numbers();
|
||||
int64 next_output_dim = 0;
|
||||
for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
|
||||
@ -88,8 +89,8 @@ namespace {
|
||||
|
||||
StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding,
|
||||
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
|
||||
@ -98,7 +99,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
|
||||
int64 output_lhs_non_contracting_partitions,
|
||||
int64 output_rhs_non_contracting_partitions,
|
||||
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops,
|
||||
bool may_reshard_without_detecting_match) {
|
||||
@ -116,7 +117,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
std::vector<int64> output_to_lhs_indices(output_base_shape.rank(), -1);
|
||||
std::vector<int64> output_to_rhs_indices(output_base_shape.rank(), -1);
|
||||
auto populate_indices_mapping =
|
||||
[&](const DotGeneralDimsMapping::DimsMapping& mapping) {
|
||||
[&](const DotConvDimsMapping::DimsMapping& mapping) {
|
||||
if (mapping.lhs >= 0) {
|
||||
lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
|
||||
lhs_to_output_indices[mapping.lhs] = mapping.output;
|
||||
@ -142,6 +143,9 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
|
||||
populate_indices_mapping(mapping);
|
||||
}
|
||||
for (const auto& mapping : dims_mapping.conv_spatial_dims) {
|
||||
populate_indices_mapping(mapping);
|
||||
}
|
||||
auto lhs_sharding_transposed_to_match_rhs =
|
||||
hlo_sharding_util::TransposeShardingWithCollapsedDims(
|
||||
lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices);
|
||||
@ -408,7 +412,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
if (output_lhs_non_contracting_partitions == num_partitions &&
|
||||
output_sharding_transposed_to_match_lhs == lhs_sharding &&
|
||||
ShapeSizeInBytes(rhs.base_shape()) >=
|
||||
threshold_for_windowed_einsum_mib * 1024 * 1024) {
|
||||
options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
|
||||
if (rhs_contracting_partitions == num_partitions) {
|
||||
return emit_windowed_dot_general(0, 1, true, false);
|
||||
}
|
||||
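The recurring signature change in this section replaces the bare threshold_for_windowed_einsum_mib parameter with the SpmdPartitionerOptions struct, and callees read the threshold from it as in the checks above. A simplified sketch of that check; the struct and its default value here are illustrative stand-ins, not taken from this commit.

  #include <cstdint>

  struct SpmdPartitionerOptionsSketch {
    int64_t threshold_for_windowed_einsum_mib = 256;  // illustrative default
  };

  // Mirrors: ShapeSizeInBytes(operand) >=
  //          options.threshold_for_windowed_einsum_mib * 1024 * 1024
  bool ExceedsWindowedEinsumThreshold(
      int64_t operand_bytes, const SpmdPartitionerOptionsSketch& options) {
    return operand_bytes >=
           options.threshold_for_windowed_einsum_mib * 1024 * 1024;
  }
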
@ -422,7 +426,7 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
if (output_rhs_non_contracting_partitions == num_partitions &&
|
||||
output_sharding_transposed_to_match_rhs == rhs_sharding &&
|
||||
ShapeSizeInBytes(lhs.base_shape()) >=
|
||||
threshold_for_windowed_einsum_mib * 1024 * 1024) {
|
||||
options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
|
||||
if (lhs_contracting_partitions == num_partitions) {
|
||||
return emit_windowed_dot_general(1, 0, true, false);
|
||||
}
|
||||
@ -572,26 +576,26 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
|
||||
StatusOr<HloInstruction*> PartitionDot(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding,
|
||||
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops);
|
||||
|
||||
StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding,
|
||||
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
|
||||
int64 lhs_contracting_partitions, int64 rhs_contracting_partitions,
|
||||
int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions, int64 lhs_contracting_partitions,
|
||||
int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
|
||||
int64 rhs_non_contracting_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
bool require_matching_devices_to_group,
|
||||
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops) {
|
||||
std::vector<std::pair<HloInstruction*, HloSharding>>
|
||||
@ -804,8 +808,7 @@ StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
|
||||
GetPerGroupBaseShape(output_grouped, output_base_shape),
|
||||
output_grouped.sharding, dims_mapping,
|
||||
num_partitions / output_grouped.device_groups.size(),
|
||||
create_sharded_dot, module, original_hlo,
|
||||
threshold_for_windowed_einsum_mib, b,
|
||||
create_sharded_dot, module, original_hlo, options, b,
|
||||
windowed_dot_general_loops));
|
||||
dot->set_sharding(UngroupSharding(output_grouped));
|
||||
return PartitionedHlo(dot, output_base_shape, lhs.state())
|
||||
@ -816,17 +819,17 @@ StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
|
||||
StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
|
||||
bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
|
||||
int64 matching_contracting_partitions, int64 other_contracting_partitions,
|
||||
absl::Span<const DotGeneralDimsMapping::DimsMapping>
|
||||
absl::Span<const DotConvDimsMapping::DimsMapping>
|
||||
partitioned_non_contractin_dims,
|
||||
int64 other_non_contracting_partitions,
|
||||
int64 output_other_non_contracting_partitions,
|
||||
const Shape& output_base_shape, const HloSharding& output_sharding,
|
||||
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
|
||||
const DotConvDimsMapping& dims_mapping, int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
bool require_matching_devices_to_group,
|
||||
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops) {
|
||||
std::vector<std::pair<HloInstruction*, HloSharding>>
|
||||
@ -949,25 +952,24 @@ StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
|
||||
GetPerGroupBaseShape(output_grouped, output_base_shape),
|
||||
output_grouped.sharding, dims_mapping,
|
||||
num_partitions / matching_grouped.device_groups.size(),
|
||||
create_sharded_dot, module, original_hlo,
|
||||
threshold_for_windowed_einsum_mib, b,
|
||||
create_sharded_dot, module, original_hlo, options, b,
|
||||
windowed_dot_general_loops));
|
||||
return dot;
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs,
|
||||
absl::Span<const DotGeneralDimsMapping::DimsMapping>
|
||||
absl::Span<const DotConvDimsMapping::DimsMapping>
|
||||
partitioned_contractin_dims,
|
||||
int64 output_batch_partitions, int64 output_lhs_non_contracting_partitions,
|
||||
int64 output_rhs_non_contracting_partitions, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding,
|
||||
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
bool require_matching_devices_to_group,
|
||||
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops) {
|
||||
std::vector<std::pair<HloInstruction*, HloSharding>>
|
||||
@ -1090,8 +1092,8 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
|
||||
inner_state),
|
||||
MakePartitionedShape(output_base_shape, outer_output_tmp_sharding),
|
||||
inner_output_sharding, dims_mapping, num_partitions / group_count,
|
||||
create_sharded_dot, module, original_hlo,
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
|
||||
create_sharded_dot, module, original_hlo, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (!dot) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -1119,19 +1121,19 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
|
||||
// in-group dot.
|
||||
StatusOr<HloInstruction*> PartitionDot(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding,
|
||||
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
bool require_matching_devices_to_group,
|
||||
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops) {
|
||||
// lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
|
||||
auto get_partitions_for_dims =
|
||||
[&](const HloSharding& sharding,
|
||||
absl::Span<const DotGeneralDimsMapping::DimsMapping> dims,
|
||||
absl::Span<const DotConvDimsMapping::DimsMapping> dims,
|
||||
int lhs_rhs_or_output) {
|
||||
int64 partitions = 1;
|
||||
if (sharding.IsTileMaximal()) {
|
||||
@ -1169,6 +1171,39 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
|
||||
// Before we find partial matches along the dimensions, invoke base case again
|
||||
// without may_reshard_without_detecting_match.
|
||||
|
||||
// Try partitioning the purely spatially-partitioned convolution first.
|
||||
if (!dims_mapping.conv_spatial_dims.empty()) {
|
||||
const auto& conv_dnums = original_hlo->convolution_dimension_numbers();
|
||||
auto window = original_hlo->window();
|
||||
|
||||
// TODO(wangtao): remove this hack by passing create_sharded_conv to
|
||||
// PartitionConv.
|
||||
// Update convolution window when it is in the recursive call for
|
||||
// batch_dims.
|
||||
if (original_hlo->batch_group_count() == 1 &&
|
||||
original_hlo->feature_group_count() == 1 &&
|
||||
!ShapeUtil::Compatible(original_hlo->shape(), output_base_shape)) {
|
||||
for (const auto& dim : dims_mapping.batch_dims) {
|
||||
auto wd = window.mutable_dimensions(dim.spatial);
|
||||
wd->set_size(lhs.hlo()->shape().dimensions(
|
||||
conv_dnums.input_spatial_dimensions(dim.spatial)));
|
||||
wd->set_stride(std::max<int64>(1, wd->size() - 1));
|
||||
wd->set_base_dilation(wd->size());
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto partitioned_conv,
|
||||
PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
|
||||
dims_mapping, window, original_hlo, num_partitions,
|
||||
options, lhs.state().partition_id, module, b));
|
||||
|
||||
if (partitioned_conv) {
|
||||
return partitioned_conv;
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto try_partitioned_dot,
|
||||
PartitionBaseCase(
|
||||
@ -1178,8 +1213,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
lhs_contracting_partitions, rhs_contracting_partitions,
|
||||
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions,
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops,
|
||||
output_rhs_non_contracting_partitions, options, b,
|
||||
windowed_dot_general_loops,
|
||||
/*may_reshard_without_detecting_match=*/false));
|
||||
if (try_partitioned_dot) {
|
||||
return try_partitioned_dot;
|
||||
@ -1198,8 +1233,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
num_partitions, lhs_contracting_partitions,
|
||||
rhs_contracting_partitions, lhs_non_contracting_partitions,
|
||||
rhs_non_contracting_partitions, create_sharded_dot, module,
|
||||
original_hlo, require_matching_devices_to_group,
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
|
||||
original_hlo, require_matching_devices_to_group, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1239,8 +1274,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
: output_lhs_non_contracting_partitions,
|
||||
output_base_shape, output_sharding, dims_mapping, num_partitions,
|
||||
create_sharded_dot, module, original_hlo,
|
||||
require_matching_devices_to_group,
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
|
||||
require_matching_devices_to_group, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1248,7 +1283,7 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
if (lhs_non_contracting_partitions > 1 &&
|
||||
output_lhs_non_contracting_partitions > 1) {
|
||||
// If part of LHS non-contracting dims match output, try them.
|
||||
std::vector<DotGeneralDimsMapping::DimsMapping> matching_dims;
|
||||
std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
|
||||
for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
|
||||
int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
|
||||
if (lhs_partitions > 1 &&
|
||||
@ -1265,9 +1300,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
rhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions, output_base_shape,
|
||||
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
|
||||
module, original_hlo, require_matching_devices_to_group,
|
||||
threshold_for_windowed_einsum_mib, b,
|
||||
windowed_dot_general_loops));
|
||||
module, original_hlo, require_matching_devices_to_group, options,
|
||||
b, windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1276,7 +1310,7 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
if (rhs_non_contracting_partitions > 1 &&
|
||||
output_rhs_non_contracting_partitions > 1) {
|
||||
// If part of RHS non-contracting dims match output, try them.
|
||||
std::vector<DotGeneralDimsMapping::DimsMapping> matching_dims;
|
||||
std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
|
||||
for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
|
||||
int64 rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs);
|
||||
if (rhs_partitions > 1 &&
|
||||
@ -1293,9 +1327,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
lhs_non_contracting_partitions,
|
||||
output_lhs_non_contracting_partitions, output_base_shape,
|
||||
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
|
||||
module, original_hlo, require_matching_devices_to_group,
|
||||
threshold_for_windowed_einsum_mib, b,
|
||||
windowed_dot_general_loops));
|
||||
module, original_hlo, require_matching_devices_to_group, options,
|
||||
b, windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1312,15 +1345,15 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions, output_base_shape,
|
||||
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
|
||||
module, original_hlo, require_matching_devices_to_group,
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
|
||||
module, original_hlo, require_matching_devices_to_group, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
}
|
||||
if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) {
|
||||
// If part of contracting dims match, try them.
|
||||
std::vector<DotGeneralDimsMapping::DimsMapping> matching_dims;
|
||||
std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
|
||||
for (const auto& dim : dims_mapping.contracting_dims) {
|
||||
int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
|
||||
if (lhs_partitions > 1 &&
|
||||
@ -1336,9 +1369,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions, output_base_shape,
|
||||
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
|
||||
module, original_hlo, require_matching_devices_to_group,
|
||||
threshold_for_windowed_einsum_mib, b,
|
||||
windowed_dot_general_loops));
|
||||
module, original_hlo, require_matching_devices_to_group, options,
|
||||
b, windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1359,8 +1391,7 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
|
||||
output_base_shape, grouped_output.sharding, dims_mapping,
|
||||
output_sharding.NumTiles(), create_sharded_dot, module,
|
||||
original_hlo, threshold_for_windowed_einsum_mib, b,
|
||||
windowed_dot_general_loops));
|
||||
original_hlo, options, b, windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1377,8 +1408,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
lhs_contracting_partitions, rhs_contracting_partitions,
|
||||
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions,
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops,
|
||||
output_rhs_non_contracting_partitions, options, b,
|
||||
windowed_dot_general_loops,
|
||||
/*may_reshard_without_detecting_match=*/true));
|
||||
if (dot) {
|
||||
return dot;
|
||||
@ -1388,12 +1419,12 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
|
||||
StatusOr<HloInstruction*> PartitionDot(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding,
|
||||
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops) {
|
||||
// First try partitioning without resharding the groups, then try allow
|
||||
@ -1403,8 +1434,7 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
auto try_partition,
|
||||
PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, create_sharded_dot, module, original_hlo,
|
||||
require_matching_devices_to_group,
|
||||
threshold_for_windowed_einsum_mib, b,
|
||||
require_matching_devices_to_group, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (try_partition) {
|
||||
return try_partition;
|
||||
@ -1423,7 +1453,7 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
} // namespace
|
||||
|
||||
Status SpmdPartitioningVisitor::HandleDotHelper(
|
||||
HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
|
||||
HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
|
||||
auto& lhs = GetPartitionedHlo(hlo->operand(0));
|
||||
@ -1431,9 +1461,8 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto partitioned_dot,
|
||||
PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
|
||||
num_partitions_, create_sharded_dot, module_, hlo,
|
||||
options_.threshold_for_windowed_einsum_mib, &b_,
|
||||
&windowed_dot_general_loops_));
|
||||
num_partitions_, create_sharded_dot, module_, hlo, options_,
|
||||
&b_, &windowed_dot_general_loops_));
|
||||
SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
|
||||
return Status::OK();
|
||||
}
|
||||
|
tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -330,27 +330,11 @@ class PartitionedHlo {
  PartitioningState state_;
};

struct DotGeneralDimsMapping {
struct DotConvDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimsMapping {
    int64 lhs;
    int64 rhs;
    int64 output;
  };
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
};

struct ConvolutionDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, parallel, non-parallel). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimsMapping {
    int64 lhs;
    int64 rhs;
@@ -358,8 +342,11 @@ struct ConvolutionDimsMapping {
    // input mapped to index in input_spatial_dimensions().
    int64 spatial;
  };
  std::vector<DimsMapping> parallel_spatial_dims;
  std::vector<DimsMapping> non_parallel_spatial_dims;
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
  std::vector<DimsMapping> conv_spatial_dims;
};

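An illustrative fill of the merged mapping for a plain batched matmul lhs[B, M, K] x rhs[B, K, N] -> out[B, M, N], i.e. a dot with no convolutional spatial dims. Example only: the dimension numbers are assumed for this shape, -1 marks a dimension an operand or the output does not have, and spatial is -1 because the op is not a convolution.

  DotConvDimsMapping mapping;
  // B appears on both operands and the output.
  mapping.batch_dims.emplace_back();
  mapping.batch_dims.back().lhs = 0;
  mapping.batch_dims.back().rhs = 0;
  mapping.batch_dims.back().output = 0;
  mapping.batch_dims.back().spatial = -1;
  // K is contracted away and has no output dimension.
  mapping.contracting_dims.emplace_back();
  mapping.contracting_dims.back().lhs = 2;
  mapping.contracting_dims.back().rhs = 1;
  mapping.contracting_dims.back().output = -1;
  mapping.contracting_dims.back().spatial = -1;
  // M and N would go into lhs_/rhs_non_contracting_dims the same way, and
  // conv_spatial_dims stays empty because the op is not a convolution.
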
class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
@@ -404,7 +391,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {

  // Implementation of dot partitioning given DotConvDimsMapping.
  Status HandleDotHelper(
      HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
      HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
      const std::function<StatusOr<HloInstruction*>(
          HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);
