diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 3349528ebc2..126b62a8eb2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -460,6 +460,37 @@ cc_library( ], ) +cc_library( + name = "hlo_sharding_util", + srcs = [ + "hlo_sharding_util.cc", + ], + hdrs = [ + "hlo_sharding_util.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "hlo_sharding_util_test", + srcs = [ + "hlo_sharding_util_test.cc", + ], + deps = [ + ":hlo_sharding_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "dynamic_parameter_binding_test", srcs = ["dynamic_parameter_binding_test.cc"], diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc new file mode 100644 index 00000000000..129091ca06f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -0,0 +1,574 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" + +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace hlo_sharding_util { + +absl::optional SelectDominantDevice( + const std::map& device_map, int64* top_count) { + int64 device = 0; + int64 count = 0; + for (auto& it : device_map) { + if (it.second > count) { + count = it.second; + device = it.first; + } + } + if (top_count != nullptr) { + *top_count = count; + } + return count > 0 ? absl::optional(device) : absl::optional(); +} + +Status AssignComputationDevice(HloComputation* computation, int64 device) { + VLOG(4) << "Assigning device " << device << " to " << computation->name() + << " computation"; + for (HloInstruction* instruction : computation->instructions()) { + if (!instruction->has_sharding()) { + VLOG(4) << "Assigning device " << device << " to " << instruction->name(); + instruction->set_device_sharding(device); + } + } + return Status::OK(); +} + +absl::optional GetMostOccurringDevice( + absl::Span instructions) { + std::map device_map; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + for (auto& it : instruction->sharding().UsedDevices(nullptr)) { + // The UsedDevices() API returns a map. + device_map[it.first] += it.second; + } + } + } + return SelectDominantDevice(device_map, nullptr); +} + +StatusOr> GetDominantDevice( + absl::Span computations, double dominant_factor) { + int64 instruction_count = 0; + std::map device_map; + for (HloComputation* computation : computations) { + for (HloInstruction* instruction : computation->instructions()) { + int64 count = 1; + if (instruction->has_sharding()) { + for (auto& it : instruction->sharding().UsedDevices(&count)) { + // The UsedDevices() API returns a map. + device_map[it.first] += it.second; + } + } + instruction_count += count; + } + } + int64 count; + absl::optional device = SelectDominantDevice(device_map, &count); + absl::optional dominant_device; + if (device) { + double factor = + static_cast(count) / static_cast(instruction_count); + if (factor >= dominant_factor) { + dominant_device = device; + } + } + return dominant_device; +} + +HloSharding TransposeSharding(const HloSharding& sharding, + const std::vector& dimensions) { + if (sharding.IsTileMaximal()) { + return sharding; + } + const int64 rank = dimensions.size(); + std::vector tile_assignment_dim(rank); + for (int64 i = 0; i < rank; ++i) { + tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]); + } + Array tile_assignment = sharding.tile_assignment(); + tile_assignment.Reshape(tile_assignment_dim); + tile_assignment.Each([&](absl::Span indices, int64* value) { + std::vector src_indices(indices.size(), -1); + for (int64 i = 0; i < indices.size(); ++i) { + src_indices[dimensions[i]] = indices[i]; + } + *value = sharding.tile_assignment()(src_indices); + }); + return HloSharding::Tile(tile_assignment); +} + +absl::optional ReshapeSharding(const Shape& source_shape, + const Shape& target_shape, + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return sharding; + } + + // In case of a tiled sharding the reshaped sharding will be a valid if the + // reshape is composed from the following operations: + // * Adding or removing dimensions with size 1. + // * Merging consecutive dimensions where only the most major is sharded. + // * Splitting a dimension to consecutive dimensions. + // * Any reshaping of unsharded dimensions. + // Note that merge and split can happen consecutively on the same dimension, + // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024 + // gets split into 128 and 8, but 8 then gets merged with 256. We use stacks + // to make supporting such cases easy. + const Shape tile_shape = sharding.TileShape(source_shape); + std::vector target_tile_assignment_dimensions; + std::vector source_dims_stack(source_shape.rank()); + std::vector target_dims_stack(target_shape.rank()); + std::vector sharding_tile_dims_stack(source_shape.rank()); + for (int64 i = 0; i < source_shape.rank(); ++i) { + source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i); + sharding_tile_dims_stack[i] = + sharding.tile_assignment().dim(source_shape.rank() - 1 - i); + } + for (int64 i = 0; i < target_shape.rank(); ++i) { + target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i); + } + while (!source_dims_stack.empty() || !target_dims_stack.empty()) { + if (target_dims_stack.empty()) { + if (Product(sharding_tile_dims_stack) != 1) { + return absl::nullopt; + } + break; + } + int64 s_size = 1; + int64 t_size = 1; + int64 s_partitions = 1; + if (!source_dims_stack.empty()) { + s_size = source_dims_stack.back(); + source_dims_stack.pop_back(); + s_partitions = sharding_tile_dims_stack.back(); + sharding_tile_dims_stack.pop_back(); + } + t_size = target_dims_stack.back(); + target_dims_stack.pop_back(); + if (s_partitions * Product(sharding_tile_dims_stack) == 1) { + // No more partitions left. + target_tile_assignment_dimensions.push_back(1); + continue; + } + if (s_size == t_size) { + // Same dimension. + target_tile_assignment_dimensions.push_back(s_partitions); + } else if (t_size == 1) { + // Trivial dimension added. + target_tile_assignment_dimensions.push_back(1); + source_dims_stack.push_back(s_size); + sharding_tile_dims_stack.push_back(s_partitions); + } else if (s_size == 1) { + // Trivial dimension removed. + if (s_partitions != 1) { + return absl::nullopt; + } + target_dims_stack.push_back(t_size); + } else if (s_size > t_size) { + // Dimension split. + if (s_size % t_size != 0 || t_size % s_partitions != 0) { + return absl::nullopt; + } + target_tile_assignment_dimensions.push_back(s_partitions); + // We have part of the s_size unprocessed, so put it back to stack. + source_dims_stack.push_back(s_size / t_size); + sharding_tile_dims_stack.push_back(1); + } else { + // Dimension merge. Also merge the source dimension with the next, and + // process it next time. + if (s_size % s_partitions != 0) { + return absl::nullopt; + } + CHECK(!source_dims_stack.empty()); + if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) { + // If the next dimension to combine is sharded, we require that the + // current dimension's shard size to be 1. Otherwise, the new shard + // would be non-contiguous. + return absl::nullopt; + } + source_dims_stack.back() *= s_size; + sharding_tile_dims_stack.back() *= s_partitions; + target_dims_stack.push_back(t_size); + } + } + Array new_tile_assignment = sharding.tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, + absl::Span dims) { + CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal()); + CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims"; + // We optimize the tile assignment on the single dimension dim in a way to + // minimize communication among devices caused by the reshard: + // +---+---+ +---+---+ +-+-+-+-+ + // | | | | 0 | | | | | | + // | 0 | 1 | +-------+ | | | | | + // | | | reshape on | 1 | reshape on | | | | | + // +---+---+ dim 0 => +-------+ dim 1 => |0|2|1|3| + // | | | | 2 | | | | | | + // | 2 | 3 | +-------+ | | | | | + // | | | | 3 | | | | | | + // +---+---+ +---+---+ +-+-+-+-+ + + std::vector tile_dims(sharding.tile_assignment().num_dimensions(), 1); + // Handle ignore dimensions. + std::vector ignore_sizes; + int64 ignore_size = 1; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (absl::c_find(dims, i) == dims.end()) { + int64 size = sharding.tile_assignment().dim(i); + ignore_sizes.push_back(size); + tile_dims[i] = size; + ignore_size *= size; + } + } + + using Buckets = std::vector>; + Array buckets(ignore_sizes, + Buckets(sharding.tile_assignment().dim(dim))); + sharding.tile_assignment().Each( + [&](absl::Span index, int64 device) { + std::vector ignore_index; + for (int64 i = 0; i < index.size(); ++i) { + if (absl::c_find(dims, i) == dims.end()) { + ignore_index.push_back(index[i]); + } + } + buckets(ignore_index)[index[dim]].push_back(device); + }); + std::vector devices; + buckets.Each([&](absl::Span index, const Buckets& buckets) { + for (auto& bucket : buckets) { + devices.insert(devices.end(), bucket.begin(), bucket.end()); + } + }); + tile_dims[dim] = devices.size() / ignore_size; + Array tile_assignment(tile_dims); + tile_assignment.SetValues(devices); + return HloSharding::Tile(tile_assignment); +} + +bool ContainsTileSharding(const HloModule& module) { + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->has_sharding() && + !instruction->sharding().IsTileMaximal()) { + return true; + } + } + } + return false; +} + +HloSharding GatherOutputSharding(const HloSharding& index_sharding, + const HloInstruction* hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); + std::vector output_tile_assignment_dims; + for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { + if (absl::c_binary_search(dnums.offset_dims(), i)) { + output_tile_assignment_dims.push_back(1); + } else { + output_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dim(index_dim)); + index_dim++; + } + } + Array new_tile_assignment = index_sharding.tile_assignment(); + new_tile_assignment.Reshape(output_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding GatherIndexSharding(const HloSharding& output_sharding, + const HloInstruction* hlo) { + if (output_sharding.IsTileMaximal()) { + return output_sharding; + } + + const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); + std::vector index_tile_assignment_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + index_tile_assignment_dims.push_back( + output_sharding.tile_assignment().dim(i)); + } + } + Array new_tile_assignment = output_sharding.tile_assignment(); + new_tile_assignment.Reshape(index_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { + if (hlo.sharding().IsTileMaximal()) { + return hlo.sharding(); + } + + const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers(); + std::vector tile_assignment_dims(hlo.shape().rank()); + int64 num_elements = 1; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i); + num_elements *= hlo.sharding().tile_assignment().dim(i); + } else { + tile_assignment_dims[i] = 1; + } + } + if (num_elements == hlo.sharding().tile_assignment().num_elements()) { + // Output sharding is only on non offset dimensions. We use output sharding + // to shard this gather op directly. + return hlo.sharding(); + } + + if (num_elements == 1) { + // Output sharding is only on offset dimensions. We do not shard this gather + // op. Return a tile maximal sharding with the first device in output + // sharding tile assignment. + return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin()); + } + + // Output sharding is on both offset and non offset dimensions. We shard the + // gather op only on non offset dimensions. + // For example: + // - the gather op has sharding [2,2]{0,1,2,3}, + // - first dimension is non offset dimension, + // - second dimension is offset dimension, + // Then the result sharding will be [2,1]{0,2}. + std::vector slice_starts(hlo.shape().rank(), 0LL), + slice_limits(hlo.shape().rank()); + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + slice_limits[i] = hlo.sharding().tile_assignment().dim(i); + } else { + slice_limits[i] = 1; + } + } + Array tile_assignment = + hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits); + return HloSharding::Tile(tile_assignment); +} + +HloSharding ScatterIndexSharding(const HloSharding& data_sharding, + const HloInstruction* hlo) { + if (data_sharding.IsTileMaximal()) { + return data_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers(); + std::vector index_tile_assignment_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.update_window_dims(), i)) { + index_tile_assignment_dims.push_back( + data_sharding.tile_assignment().dim(i)); + } + } + if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) { + index_tile_assignment_dims.push_back(1); + } + Array new_tile_assignment = data_sharding.tile_assignment(); + new_tile_assignment.Reshape(index_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ScatterDataSharding(const HloSharding& index_sharding, + const HloInstruction* hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers(); + std::vector data_tile_assignment_dims; + for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { + if (absl::c_binary_search(dnums.update_window_dims(), i)) { + data_tile_assignment_dims.push_back(1); + } else { + data_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dim(index_dim)); + index_dim++; + } + } + Array new_tile_assignment = index_sharding.tile_assignment(); + new_tile_assignment.Reshape(data_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, + const HloInstruction& hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + // Only shard on first "number of scatter_window_dims" dimensions. + const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers(); + int64 num_elements = 1; + int64 index_dim = 0; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { + num_elements *= index_sharding.tile_assignment().dim(index_dim); + index_dim++; + } + } + if (num_elements == index_sharding.tile_assignment().num_elements()) { + // Index sharding is only on scatter_window_dims. We use this index sharding + // directly. + return index_sharding; + } + + // Index sharding is only on update_window_dims. We do not shard this scatter + // op. Return a tile maximal sharding with the first device in index sharding + // tile assignment. + if (num_elements == 1) { + return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin()); + } + + const int64 index_rank = hlo.operand(1)->shape().rank(); + std::vector slice_starts(index_rank, 0LL), slice_limits(index_rank); + for (int64 i = 0; i < index_rank; ++i) { + if (i < index_dim) { + slice_limits[i] = index_sharding.tile_assignment().dim(i); + } else { + slice_limits[i] = 1; + } + } + Array tile_assignment = + index_sharding.tile_assignment().Slice(slice_starts, slice_limits); + return HloSharding::Tile(tile_assignment); +} + +HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, + const HloInstruction& hlo) { + if (data_sharding.IsTileMaximal()) { + return data_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers(); + const int64 data_rank = hlo.operand(2)->shape().rank(); + std::vector tile_assignment_dims(data_rank, 1LL); + int64 num_elements = 1; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { + CHECK_LT(i, data_rank); + tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i); + num_elements *= data_sharding.tile_assignment().dim(i); + } + } + if (num_elements == data_sharding.tile_assignment().num_elements()) { + // Data sharding is only on scatter_window_dims. We use this data sharding + // directly. + return data_sharding; + } + + if (num_elements == 1) { + // Data sharding is only on update_window_dims. We do not shard this + // scatter op. Return a tile maximal sharding with the first device in + // data sharding tile assignment. + return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin()); + } + + // Data sharding is on both update_window_dims and scatter_window_dims. We + // shard the scatter op only on scatter_window_dims. For example: + // - the scatter data has sharding [2,2]{0,1,2,3}, + // - first dimension is scatter_window_dims, + // - second dimension is update_window_dims, + // Then the result sharding will be [2,1]{0,2}. + std::vector slice_starts(data_rank, 0LL); + Array tile_assignment = + data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims); + return HloSharding::Tile(tile_assignment); +} + +StatusOr, HloOpcode>> +IdentityValueAndHloOpcodeForScatterReduceComputation( + const HloScatterInstruction& scatter) { + auto computation = scatter.to_apply(); + // We only handle computations with 2 parameters and only 1 calculation. + if (computation->instruction_count() != 3) { + return Status( + tensorflow::error::Code::INVALID_ARGUMENT, + "Expected scatter reduce computation with 2 parameters and only 1 " + "calculation"); + } + + auto root_instruction = computation->root_instruction(); + if (root_instruction->opcode() == HloOpcode::kAdd || + root_instruction->opcode() == HloOpcode::kOr) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero( + scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMultiply || + root_instruction->opcode() == HloOpcode::kAnd) { + return std::make_pair(HloInstruction::CreateConstant( + LiteralUtil::One(scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMaximum) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue( + scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMinimum) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue( + scatter.shape().element_type())), + root_instruction->opcode()); + } + + return Status(tensorflow::error::Code::INVALID_ARGUMENT, + "Expected scatter reduce computation which is " + "add/or/multiply/add/min/max"); +} + +std::vector DevicesForSharding( + const HloSharding& sharding, const std::vector& available_devices) { + std::vector devices; + if (sharding.IsReplicated()) { + for (int64 d : available_devices) { + if (!HloSharding::IsReservedDevice(d)) { + devices.push_back(d); + } + } + return devices; + } + + for (int64 i : available_devices) { + if (sharding.UsesDevice(i)) { + devices.push_back(i); + } + } + DCHECK(std::all_of(sharding.tile_assignment().begin(), + sharding.tile_assignment().end(), [&](int64 device) { + return std::find(available_devices.begin(), + available_devices.end(), + device) != available_devices.end(); + })); + return devices; +} + +} // namespace hlo_sharding_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h new file mode 100644 index 00000000000..00d9434a34d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h @@ -0,0 +1,143 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +namespace xla { +namespace hlo_sharding_util { + +// Given a map, selects the device with higher +// occurrence count (if any). If top_count in not nullptr, it will receive the +// count of the dominant device returned. +absl::optional SelectDominantDevice( + const std::map& device_map, int64* top_count); + +// Assigns all the instructions of a computation, to a given device. +// This API does not recurse into called computations, and does not assign +// instructions which already have sharding. +Status AssignComputationDevice(HloComputation* computation, int64 device); + +// Given an instruction container, returns the device which is most commonly +// occurring among the instructions. +absl::optional GetMostOccurringDevice( + absl::Span instructions); + +// Given a set of computations, tries to extract the dominant device. A device +// is dominant if the combined occurrence among all the instructions of the +// input computations, is greater/equal than/to dominant_factor (real number +// from 0 to 1). +// This API does not recurse into called computations. +// If no device exists that satisfies the condition, the returned optional will +// hold no value. +StatusOr> GetDominantDevice( + absl::Span computations, double dominant_factor); + +// Returns the HloSharding with the tile dimensions and tile assignment +// transposed based on the specified dimension numbers. In case of a tile +// maximal sharding returns the original sharding. +HloSharding TransposeSharding(const HloSharding& sharding, + const std::vector& dimensions); + +// Returns the HloSharding with the tile shape reshaped based on the source and +// target shapes and the tile assignment adjusted to correspond to the new tile +// shape or absl::nullopt if the resulting reshape would create an invalid +// sharding (non continuous or non uniformly sized tiles). In case of a tile +// maximal sharding returns the original sharding. +absl::optional ReshapeSharding(const Shape& source_shape, + const Shape& target_shape, + const HloSharding& sharding); + +// Returns a sharding tiled on unique dimension dim by reshaping the tile +// assignment of the sharding argument. Only dimensions in the dims span +// argument are considered for reshaping, the others are ignored. +// Assumptions: sharding is tile sharded, and dim must be included in dims. +HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, + absl::Span dims); + +// Returns true if the provided module includes one or more instructions with +// a tile sharding. +bool ContainsTileSharding(const HloModule& module); + +// Returns the preferred output sharding for a gather op based on the sharding +// of the indces. +HloSharding GatherOutputSharding(const HloSharding& index_sharding, + const HloInstruction* hlo); + +// Returns the preferred index sharding for a gather op based on the sharding +// of the output. +HloSharding GatherIndexSharding(const HloSharding& output_sharding, + const HloInstruction* hlo); + +// Returns a new HloSharding for a gather op so that only non offset dimensions +// are sharded. Assume "result" is returned by this function. It is ensured that +// "GetIndexSharding(result, hlo)" will have the same number of elements as +// "result". +HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo); + +// Returns the preferred index sharding for a scatter op based on the sharding +// of the data. +HloSharding ScatterIndexSharding(const HloSharding& data_sharding, + const HloInstruction* hlo); + +// Returns the preferred data sharding for a scatter op based on the sharding +// of the index. +HloSharding ScatterDataSharding(const HloSharding& index_sharding, + const HloInstruction* hlo); + +// Returns a new index sharding for a scatter op so that we only shard on first +// "number of scatter_window_dims" dimensions. Assume "result" is returned by +// this function. It is ensured that "ScatterDataSharding(result, hlo)" will +// have the same number of elements as "result". +HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, + const HloInstruction& hlo); + +// Returns a new data sharding for a scatter op so that we only shard on +// scatter_window_dims. Assume "result" is returned by this function. It is +// ensured that "ScatterIndexSharding(result, hlo)" will have the same number of +// elements as "result". +HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, + const HloInstruction& hlo); + +// Returns an identity value and an HloOpcode for reduce computation of scatter +// instruction. +// - If computation is add/or, return 0/false with corresponding op code; +// - If computation is multiply/and, return 1/true with corresponding op code. +// - If computation is min/max, return max value/min value with corresponding op +// code. +// - Otherwise, return error status. +StatusOr, HloOpcode>> +IdentityValueAndHloOpcodeForScatterReduceComputation( + const HloScatterInstruction& scatter); + +// Given a sharding and a list of devices in the topology, return a +// list of the devices that `sharding` applies to. +std::vector DevicesForSharding( + const HloSharding& sharding, const std::vector& available_devices); + +} // namespace hlo_sharding_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc new file mode 100644 index 00000000000..02496c75965 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace hlo_sharding_util { +namespace { + +TEST(HloShardingUtilTest, TransposeShardingReplicated) { + EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}), + HloSharding::Replicate()); +} + +TEST(HloShardingUtilTest, TransposeShardingTiled) { + HloSharding input = HloSharding::Tile(Array4D({{{{0, 1}}, {{2, 3}}}})); + HloSharding output = + HloSharding::Tile(Array4D({{{{0}, {2}}}, {{{1}, {3}}}})); + EXPECT_EQ(TransposeSharding(input, {3, 0, 1, 2}), output); +} + +TEST(HloShardingUtilTest, ReshapeShardingMaximal) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); + HloSharding sharding = HloSharding::AssignDevice(7); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_FALSE(result.has_value()); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) { + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {20, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + HloSharding output_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7}); + HloSharding input_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7, 5, 3}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 15, 2, 14}); + Array sharding_array({2, 1, 1, 1}); + sharding_array(0, 0, 0, 0) = 0; + sharding_array(1, 0, 0, 0) = 1; + HloSharding sharding = HloSharding::Tile(sharding_array); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) { + Shape input_shape = ShapeUtil::MakeShape(F32, {3, 1, 5, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 1, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array4D({{{{0}, {1}}}})); + HloSharding output_sharding = + HloSharding::Tile(Array4D({{{{0}}, {{1}}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTrivialDImensionInsertedToEnd) { + Shape input_shape = ShapeUtil::MakeShape(F32, {8, 16}); + Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 1}); + HloSharding input_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) { + Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = ReshapeSharding(shape, shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingScalar) { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1}); + Shape output_shape = ShapeUtil::MakeShape(F32, {}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_FALSE(result.has_value()); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim0) { + HloSharding sharding = HloSharding::Tile(Array2D({{0, 1}, {2, 3}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1}); + EXPECT_EQ(result.tile_assignment(), Array2D({{0}, {1}, {2}, {3}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim1) { + HloSharding sharding = HloSharding::Tile(Array2D({{0, 1}, {2, 3}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1}); + EXPECT_EQ(result.tile_assignment(), Array2D({{0, 2, 1, 3}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim0) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2}); + EXPECT_EQ( + result.tile_assignment(), + Array3D({{{0}}, {{1}}, {{2}}, {{3}}, {{4}}, {{5}}, {{6}}, {{7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim1) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2}); + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0}, {1}, {4}, {5}, {2}, {3}, {6}, {7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim2) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2}); + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0, 2, 4, 6, 1, 3, 5, 7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim2_Batch1) { + // Tile sharding in batch dimension, i.e. + // sharding={devices[2,2,2]0,1,2,3,4,5,6,7,8}. + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + // Reshape on dimensions {1, 2} only, therefore ignoring batch dimension 0. + HloSharding result = ReshapeToTileDimension(sharding, /*dim=*/2, + /*dims=*/{1, 2}); + // Expected result is {devices=[2,1,4]0,2,1,3,4,6,5,7}, i.e. the two + // non-batch dimensions {{0, 1}, {2, 3}} and {{4, 5}, {6, 7}} are individually + // reshaped to tile dimension 2, i.e. {{0, 2, 1, 3}}, {{4, 6, 5, 7}}. + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0, 2, 1, 3}}, {{4, 6, 5, 7}}})); +} + +} // namespace +} // namespace hlo_sharding_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD new file mode 100644 index 00000000000..5be6a04f934 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -0,0 +1,69 @@ +# Description: SPMD partitioning pass. + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +cc_library( + name = "spmd_partitioner", + srcs = [ + "spmd_partitioner.cc", + "spmd_partitioner_util.cc", + ], + hdrs = [ + "spmd_partitioner.h", + "spmd_partitioner_util.h", + ], + deps = [ + "//tensorflow/compiler/xla:comparison_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_query", + "//tensorflow/compiler/xla/service:hlo_sharding_util", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/core/platform:numbers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "spmd_partitioner_test", + srcs = ["spmd_partitioner_test.cc"], + deps = [ + ":spmd_partitioner", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc new file mode 100644 index 00000000000..fd865342ca3 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -0,0 +1,4655 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +#include + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/numbers.h" + +namespace xla { +namespace spmd { + +string SpmdLogger::MakeReport() { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory during transformation *****\n"); + + std::sort(entries_.begin(), entries_.end(), + [](auto const& entry0, auto const& entry1) { + return entry0.first > entry1.first; + }); + for (int64 i = 0; + i < std::min(report_instruction_count_, entries_.size()); ++i) { + absl::StrAppend( + &report, "\n ", + tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ", + entries_[i].second, "\n"); + } + + return report; +} + +void SpmdLogger::RegisterLogEntry(HloInstruction* hlo, + const std::vector& group) { + string report = hlo->ToString(); + int64 max_value = -1; + for (HloInstruction* inst : group) { + if (inst->shape().IsTuple()) { + continue; + } + max_value = + std::max(max_value, ShapeUtil::ByteSizeOf(inst->shape(), 4)); + absl::StrAppend(&report, " * ", inst->ToString(), "\n"); + } + entries_.push_back(std::make_pair(max_value, report)); +} + +/* static */ string SpmdLogger::ReportBeforePartition( + const HloModule& module, int64 report_instruction_count) { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory usage before partition *****\n"); + absl::StrAppend(&report, "\n ** Replicated instructions\n"); + absl::StrAppend(&report, ReportMemoryUsage( + module, + [](const HloInstruction* hlo) { + return !hlo->has_sharding() || + hlo->sharding().IsReplicated(); + }, + report_instruction_count)); + absl::StrAppend(&report, "\n ** All instructions\n"); + absl::StrAppend(&report, + ReportMemoryUsage( + module, [](const HloInstruction* hlo) { return true; }, + report_instruction_count)); + return report; +} + +/* static */ string SpmdLogger::ReportAfterPartition( + const HloModule& module, int64 report_instruction_count) { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory usage after partition *****\n"); + absl::StrAppend(&report, + ReportMemoryUsage( + module, [](const HloInstruction* hlo) { return true; }, + report_instruction_count)); + return report; +} + +template +/* static */ string SpmdLogger::ReportMemoryUsage( + const HloModule& module, const F& filter, int64 report_instruction_count) { + string report; + std::vector instructions; + instructions.reserve(module.instruction_count()); + + for (auto computation : module.computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (auto hlo : computation->instructions()) { + if (hlo->shape().IsTuple() || + ShapeUtil::IsEffectiveScalar(hlo->shape())) { + continue; + } + if (filter(hlo)) { + instructions.push_back(hlo); + } + } + } + + const auto add_report = [&](std::vector* insts) { + std::sort(insts->begin(), insts->end(), + [](const HloInstruction* inst0, const HloInstruction* inst1) { + return ShapeUtil::ByteSizeOf(inst0->shape()) > + ShapeUtil::ByteSizeOf(inst1->shape()); + }); + for (int64 i = 0; + i < std::min(report_instruction_count, insts->size()); ++i) { + absl::StrAppend(&report, " ", + tensorflow::strings::HumanReadableNumBytes( + ShapeUtil::ByteSizeOf((*insts)[i]->shape())), + " : ", (*insts)[i]->ToString(), "\n"); + } + }; + + add_report(&instructions); + return report; +} + +namespace { + +// Returns the replica group configuration where each replica belongs to its own +// group. +std::vector CreateReplicaGroups(int64 num_replicas) { + std::vector groups(num_replicas); + for (int64 i = 0; i < num_replicas; ++i) { + groups[i].add_replica_ids(i); + } + return groups; +} + +bool CanReshardWithAllToAll(const HloSharding& source, + const HloSharding& target) { + return UniqueTiledDim(source) && UniqueTiledDim(target) && + UniqueTiledDim(source) != UniqueTiledDim(target); +} + +bool CanReshardWithCollectivePermute(const HloSharding& source, + const HloSharding& target) { + return UniqueTiledDim(source) && UniqueTiledDim(target) && + UniqueTiledDim(source) == UniqueTiledDim(target) && source != target; +} + +// Clears all sharding attributes from instructions in the module. This must be +// called only after all SPMD transformation is complete. +Status ClearShardingAttributes(HloModule* module) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + // Keep sharding annotation on Infeed and entry parameters since they're + // used by HloReplicationAnalysis later (for ArCrsCombiner). + if (hlo->opcode() == HloOpcode::kInfeed) { + continue; + } + if (hlo->opcode() == HloOpcode::kParameter && + computation == module->entry_computation()) { + continue; + } + hlo->clear_sharding(); + } + } + return Status::OK(); +} + +} // namespace + +HloInstruction* SpmdBuilder::AddInstruction( + std::unique_ptr instruction) { + HloInstruction* hlo = + HloComputation::Builder::AddInstruction(std::move(instruction)); + if (visiting_hlo_) { + instructions_[visiting_hlo_].push_back(hlo); + } + return hlo; +} + +PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + for (auto& entry : cache) { + if (entry.first == target) { + return entry.second; + } + } + cache.emplace_back(target, ReshardNoCache(target)); + state_.reshard_cache->per_hlo_cache[cache.back().second.hlo()] + .reshard_cache.emplace_back(sharding(), *this); + return cache.back().second; +} + +PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { + VLOG(2) << "Resharding " << hlo_->ToString() << " from " + << hlo_->sharding().ToString() << " to " << target.ToString(); + const Shape& shape = hlo_->shape(); + CHECK(shape.IsTuple() || !target.IsTuple()); + + // Tuple shape instructions may have non-tuple sharding, which means that the + // same sharding applies to all the leaves. + if (shape.IsTuple() && !target.IsTuple()) { + return Reshard(target.GetTupleSharding(shape).ValueOrDie()); + } + + // For a tuple shape, recursively apply Reshard to all the leaves and return + // a tuple instruction. + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + auto subshape = ShapeUtil::GetTupleElementShape(shape, i); + auto element = state_.b->AddInstruction( + HloInstruction::CreateGetTupleElement(subshape, hlo(), i)); + element->set_sharding(sharding().GetSubSharding(shape, {i})); + elements.push_back( + PartitionedHlo( + element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_) + .Reshard(target.GetSubSharding(shape, {i})) + .hlo()); + } + auto tuple = + state_.b->AddInstruction(HloInstruction::CreateTuple(elements)); + tuple->set_sharding(target); + return PartitionedHlo(tuple, base_shape_, state_); + } + + if (sharding() == target) { + return *this; + } + + if (shape.element_type() == TOKEN) { + return *this; + } + + if (CanReshardWithCollectivePermute(sharding(), target)) { + return ReshardWithCollectivePermute(target); + } + + if (CanReshardWithAllToAll(sharding(), target)) { + return ReshardWithAllToAll(target); + } + + // If not replicated yet, first replicate and then reshard to use one of the + // two implementations below. + if (!sharding().IsReplicated()) { + return Replicate().Reshard(target); + } + + // 'Replicated' to 'SingleDevice'. + if (target.IsTileMaximal()) { + auto copy = state_.b->AddInstruction( + HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_)); + copy->set_sharding(target); + return PartitionedHlo(copy, base_shape_, state_); + } + + // 'Replicated' to 'Tiled'. + auto padded_hlo = + PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + auto shard_shape = MakePartitionedShape(shape, target); + auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, + MakePartitionOffsets(shape, target, state_.partition_id, state_.b), + shard_shape.dimensions())); + slice->set_sharding(target); + return PartitionedHlo(slice, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::PadWithValue(HloInstruction* pad_value) const { + const HloSharding& sharding = hlo_->sharding(); + const Shape& shape = hlo_->shape(); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) { + return *this; + } + CHECK(!sharding.IsTileMaximal()); + auto index_shape = ShapeUtil::ChangeElementType(shape, S32); + auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); + auto get_mask_for_dim = [&](int64 dim, HloInstruction* start_index) { + // Comparison: iota + start_index < valid_size + auto iota = + state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); + auto broadcast_start_index = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(index_shape, start_index, {})); + auto index_in_full_shape = + state_.b->AddInstruction(HloInstruction::CreateBinary( + index_shape, HloOpcode::kAdd, iota, broadcast_start_index)); + auto valid_size = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(base_shape_.dimensions(dim)))); + auto broadcast_valid_size = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(index_shape, valid_size, {})); + return state_.b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_full_shape, broadcast_valid_size, + ComparisonDirection::kLt)); + }; + + HloInstruction* mask = nullptr; + auto offsets = MakePartitionOffsets(base_shape_, sharding, + state_.partition_id, state_.b); + for (int64 i = 0; i < shape.rank(); ++i) { + if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0) { + continue; + } + if (mask == nullptr) { + mask = get_mask_for_dim(i, offsets[i]); + } else { + mask = state_.b->AddInstruction( + HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask, + get_mask_for_dim(i, offsets[i]))); + } + } + + if (mask == nullptr) { + return *this; + } + + auto broadcast_pad_value = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(shape, pad_value, {})); + auto result = state_.b->AddInstruction(HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value)); + result->set_sharding(sharding); + return PartitionedHlo(result, base_shape_, state_); +} + +absl::optional +PartitionedHlo::ReshardAsWindowedInput(const Window& window, + const HloSharding& target, + HloInstruction* pad_value, + bool mask_invalid_region) { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache; + for (auto& entry : cache) { + if (std::get<0>(entry) == target && + protobuf_util::ProtobufEquals(std::get<1>(entry), window)) { + return std::get<2>(entry); + } + } + auto update_cache = [&](WindowedInputShardReturnValue result) { + cache.emplace_back(target, window, std::move(result)); + return std::get<2>(cache.back()); + }; + VLOG(2) << "ReshardAsWindowedInput()\n" + << "\twindow:" << window_util::ToString(window) + << "\ttarget sharding:" << target.ToString(); + + CHECK(!target.IsTileMaximal()); + auto partition_ordinals = + MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b); + auto shard_shape = base_shape_; + + std::vector start_on_padded_calculations( + base_shape_.rank()); + std::vector limit_on_padded_calculations( + base_shape_.rank()); + std::vector dynamic_slice_offset_on_output( + base_shape_.rank(), nullptr); + + Window shard_window = window; + auto padded_shape = base_shape_; + std::vector offsets_on_padded_shape(base_shape_.rank()); + std::vector per_shard_window_counts(base_shape_.rank()); + std::vector explicit_left_padding(base_shape_.rank()); + for (int64 i = 0; i < base_shape_.rank(); ++i) { + // Do not pad non-partitioned dimensions. + int64 shard_count = target.tile_assignment().dim(i); + if (shard_count == 1) { + offsets_on_padded_shape[i] = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + continue; + } + const auto& wd = window.dimensions(i); + if (wd.window_dilation() != 1) { + // TODO(yuanzx): Support window dilation. + VLOG(2) << "Failed to reshard window operand due to window dilation"; + return absl::nullopt; + } + int64 full_size = + base_shape_.dimensions(i) + + (wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) + + wd.padding_high() + wd.padding_low(); + if (full_size < wd.size()) { + VLOG(2) << "Failed to reshard window operand because the window size is " + "larger than padded base size"; + return absl::nullopt; + } + int64 window_count = (full_size - wd.size()) / wd.stride() + 1; + per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count); + if (wd.stride() != 1 && + (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) { + // TODO(yuanzx): Support this case. + VLOG(2) << "Failed to reshard window operand due to non-trivial dilation"; + return absl::nullopt; + } + + // We use explicit padding for full dilations, then use padding_low and + // padding_high on the sharded op for the remaining. padding_low and + // padding_high are now given initial values, which will be later updated if + // dilation is not 1. + auto swd = shard_window.mutable_dimensions(i); + explicit_left_padding[i] = wd.padding_low() / wd.base_dilation(); + swd->set_padding_low(wd.padding_low() % wd.base_dilation()); + swd->set_padding_high(0); + + // Calculation for the first element needed on the 'padded-but-not-dilated' + // shape. The start on the dilated shape could be a hole, so we add + // wd.base_dilation() - 1 to the constant term to skip the leading holes. + start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( + wd.stride() * per_shard_window_counts[i], + wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation()); + int64 dilated_shard_size = + wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( + wd.stride() * per_shard_window_counts[i], + dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(), + wd.base_dilation()); + + offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate( + partition_ordinals[i], state_.b); + + auto shard_size_function = + limit_on_padded_calculations[i] - start_on_padded_calculations[i]; + int64 max_shard_size = shard_size_function.MaxInRange(0, shard_count); + shard_shape.set_dimensions(i, max_shard_size); + padded_shape.set_dimensions( + i, limit_on_padded_calculations[i].Calculate(shard_count - 1)); + + // For base dilation, calculate the needed padding_low and padding_high, as + // well as the offset for the output if a dynamic slice is needed after the + // sharded op. + if (wd.base_dilation() != 1) { + // Returns the offset of a shard's first valid element in the dilated + // shard. + auto get_first_valid_element_offset_on_dilated_shard = + [&](int64 shard_ordinal) { + return start_on_padded_calculations[i].Calculate(shard_ordinal) * + wd.base_dilation() + + swd->padding_low() - + wd.stride() * per_shard_window_counts[i] * shard_ordinal; + }; + CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0), + swd->padding_low()); + + // Determine swd->padding_high. + for (int64 shard_ordinal = 0; shard_ordinal < shard_count; + ++shard_ordinal) { + int64 wanted_limit_on_dilated_shard = + wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + int64 actual_limit_on_dilated_shard_without_pad_high = + get_first_valid_element_offset_on_dilated_shard(shard_ordinal) + + (max_shard_size - 1) * wd.base_dilation() + 1; + swd->set_padding_high(std::max( + swd->padding_high(), + wanted_limit_on_dilated_shard - + actual_limit_on_dilated_shard_without_pad_high)); + } + + // Determine swd->padding_low and output dynamic slice index. + if (wd.stride() == 1) { + int64 max_pad_low = get_first_valid_element_offset_on_dilated_shard(0); + bool all_same = true; + for (int64 shard_ordinal = 1; shard_ordinal < shard_count; + ++shard_ordinal) { + int64 start = + get_first_valid_element_offset_on_dilated_shard(shard_ordinal); + if (start != swd->padding_low()) { + all_same = false; + } + max_pad_low = std::max(max_pad_low, start); + } + if (!all_same) { + auto start_on_padded_input = + start_on_padded_calculations[i].Calculate(partition_ordinals[i], + state_.b); + // We will calculate + // max_pad_low - (first_window - required_first_window) + // which equals + // required_first_window - (first_window - max_pad_low) + auto first_window_minus_max_pad_low = + MultiplyAddDivideOffsetCalculation( + wd.base_dilation(), swd->padding_low() - max_pad_low, 1) + .Calculate(start_on_padded_input, state_.b); + auto required_first_window = + MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0, + 1) + .Calculate(partition_ordinals[i], state_.b); + dynamic_slice_offset_on_output[i] = + state_.b->AddInstruction(HloInstruction::CreateBinary( + required_first_window->shape(), HloOpcode::kSubtract, + required_first_window, first_window_minus_max_pad_low)); + } + swd->set_padding_low(max_pad_low); + } else { + CHECK_EQ( + (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation(), 0) + << "General base dilation not yet implemented."; + // padding_low on all shards should equal the initially assigned + // swd->padding_low(), i.e., the padding_low() on the original window. + } + } + } + + // Returns the output dynamic slice offset when needed, and absl::nullopt + // otherwise. + auto get_dynamic_slice_offset_on_output_if_needed = + [&]() -> absl::optional> { + if (absl::c_all_of( + dynamic_slice_offset_on_output, + [](HloInstruction* offset) { return offset == nullptr; })) { + return absl::nullopt; + } + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + for (int64 i = 0; i < dynamic_slice_offset_on_output.size(); ++i) { + if (dynamic_slice_offset_on_output[i] == nullptr) { + dynamic_slice_offset_on_output[i] = zero; + } + } + return dynamic_slice_offset_on_output; + }; + + // If the currrent HLO is replicated, pad then slice. + if (sharding().IsReplicated()) { + PaddingConfig padding_config; + for (int64 i = 0; i < base_shape_.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_interior_padding(0); + // Do not pad non-partitioned dimensions. + if (target.tile_assignment().dim(i) == 1) { + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_edge_padding_high(0); + continue; + } + padding_config_dim->set_edge_padding_low(explicit_left_padding[i]); + padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - + explicit_left_padding[i] - + base_shape_.dimensions(i)); + } + auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_) + ? hlo_ + : state_.b->AddInstruction(HloInstruction::CreatePad( + padded_shape, hlo_, pad_value, padding_config)); + auto sharded_input = + state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, offsets_on_padded_shape, + shard_shape.dimensions())); + return update_cache(WindowedInputShardReturnValue{ + sharded_input, shard_window, + get_dynamic_slice_offset_on_output_if_needed()}); + } + + if (target != sharding()) { + return Replicate().ReshardAsWindowedInput(window, target, pad_value); + } + + // Halo exchange. + HloInstruction* visiting_hlo = hlo_; + auto original_shard_shape = MakePartitionedShape(base_shape_, target); + + std::vector left_halo_size_functions(base_shape_.rank()); + std::vector right_halo_size_functions(base_shape_.rank()); + // TODO(yuanzx): We are concatenating on each sharded dimension one at time, + // and in the second dimension (and beyond) we create halos by slicing the + // concat in the previous dimension, which is not optimal. We should generate + // halos only concating slices, instead of slicing concats. + for (int dim = 0; dim < base_shape_.rank(); ++dim) { + int64 shard_count = target.tile_assignment().dim(dim); + if (shard_count == 1) { + continue; + } + int64 input_shard_size = + CeilOfRatio(base_shape_.dimensions(dim), shard_count); + + // Left halo. The size of the halo is derived by subtracting the first read + // element offset of the i'th partition from the limit of the (i-1)'th + // partition. + MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded( + input_shard_size, explicit_left_padding[dim], 1); + left_halo_size_functions[dim] = + shard_limit_of_previous_on_padded - start_on_padded_calculations[dim]; + + // Right halo. + MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded( + input_shard_size, input_shard_size + explicit_left_padding[dim], 1); + right_halo_size_functions[dim] = + limit_on_padded_calculations[dim] - shard_start_of_next_on_padded; + + auto resharded = ExchangeHaloAndGetValidData( + visiting_hlo, base_shape_, left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding[dim], + padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target, + offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim], + state_.collective_ops_creator, state_.next_channel_id, state_.b, + mask_invalid_region); + if (!resharded) { + VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo " + "is beyond the neighbor."; + return Replicate().ReshardAsWindowedInput(window, target, pad_value); + } + visiting_hlo = *resharded; + } + return update_cache(WindowedInputShardReturnValue{ + visiting_hlo, shard_window, + get_dynamic_slice_offset_on_output_if_needed()}); +} + +PartitionedHlo PartitionedHlo::Replicate() { + const HloSharding& sharding = hlo_->sharding(); + const Shape& shape = hlo_->shape(); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + + if (sharding.IsReplicated()) { + return *this; + } + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + for (auto& entry : cache) { + if (entry.first.IsReplicated()) { + return entry.second; + } + } + auto update_cache = [&](PartitionedHlo resharded) { + state_.reshard_cache->per_hlo_cache[resharded.hlo()] + .reshard_cache.emplace_back(sharding, *this); + cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); + return cache.back().second; + }; + // 'Single Device' to 'Repliated'. + if (sharding.IsTileMaximal()) { + return update_cache(Broadcast()); + } + + // 'Tiled' to 'Replicated'. + Shape padded_base_shape = shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(padded_base_shape, zero, {})); + auto dus = state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + padded_base_shape, zero_bcast, hlo_, + MakePartitionOffsets(padded_base_shape, sharding, state_.partition_id, + state_.b))); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); + + auto all_reduce = + state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, dus, reduction, NewChannel()); + HloInstruction* result = all_reduce; + if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) { + std::vector start_indices(shape.rank(), 0); + std::vector strides(shape.rank(), 1); + result = state_.b->AddInstruction(HloInstruction::CreateSlice( + base_shape_, result, start_indices, base_shape_.dimensions(), strides)); + } + result->set_sharding(HloSharding::Replicate()); + return update_cache(PartitionedHlo(result, base_shape_, state_)); +} + +PartitionedHlo PartitionedHlo::Broadcast() const { + const Shape& shape = hlo_->shape(); + const HloSharding& sharding = hlo_->sharding(); + CHECK(sharding.HasUniqueDevice()); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + + auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(sharding.GetUniqueDevice()))); + Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED); + auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast( + bcast_shape, + state_.b->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id, + ComparisonDirection::kEq)), + {})); + + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(shape, zero, {})); + auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast)); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); + + auto result = state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, operand, reduction, NewChannel()); + result->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(result, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::ReshardWithAllToAll( + const HloSharding& target) const { + int64 partition_count = sharding().tile_assignment().num_elements(); + absl::optional input_partition_dim = UniqueTiledDim(sharding()); + absl::optional output_partition_dim = UniqueTiledDim(target); + CHECK(input_partition_dim.has_value()); + CHECK(output_partition_dim.has_value()); + + // If the device order is different in the target, fix the order with + // ReshardWithCollectivePermute. + auto input_tile_fixed_device_order = target.tile_assignment(); + input_tile_fixed_device_order.Reshape( + sharding().tile_assignment().dimensions()); + auto input_sharding_fixed_device_order = + HloSharding::Tile(input_tile_fixed_device_order); + if (input_sharding_fixed_device_order != sharding()) { + auto fixed_order = + ReshardWithCollectivePermute(input_sharding_fixed_device_order); + return fixed_order.ReshardWithAllToAll(target); + } + + auto padded_hlo = + PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + + // The order of ids in the group must follow the target sharding. + std::vector groups(1); + for (int64 device : target.tile_assignment()) { + groups[0].add_replica_ids(device); + } + + HloInstruction* result = nullptr; + + // Split along the split dimension (output_partition_dim) of the all-to-all + // output. + std::vector dimensions; + for (int64 i = 0; i < base_shape_.rank(); ++i) { + if (i == *output_partition_dim) { + dimensions.push_back(partition_count); + dimensions.push_back(padded_hlo->shape().dimensions(i) / partition_count); + } else { + dimensions.push_back(padded_hlo->shape().dimensions(i)); + } + } + auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(base_shape_.element_type(), dimensions), + padded_hlo)); + // After the reshape, it is guaranteed to have at least 3 dimensions. + auto all_to_all = + state_.collective_ops_creator.create_cross_partition_all_to_all( + state_.b, {reshape}, groups, (*state_.next_channel_id)++, + output_partition_dim); + + // Reorder the split dimension of the reshape to be located in front of the + // input partition dimension, so the two dimensions can be combined. + int64 new_input_partition_dim = (*output_partition_dim < *input_partition_dim) + ? *input_partition_dim + 1 + : *input_partition_dim; + std::vector permutation; + for (int64 i = 0; i < all_to_all->shape().rank(); ++i) { + if (i == *output_partition_dim) { + continue; + } + if (i == new_input_partition_dim) { + permutation.push_back(*output_partition_dim); + } + permutation.push_back(i); + } + auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose( + ShapeInference::InferTransposeShape(all_to_all->shape(), permutation) + .ValueOrDie(), + all_to_all, permutation)); + + // Combine the split dimension and the input partition dimension. + auto new_shape = ShapeInference::InferAllToAllShape( + padded_hlo->shape(), *output_partition_dim, + *input_partition_dim, partition_count) + .ValueOrDie(); + result = state_.b->AddInstruction( + HloInstruction::CreateReshape(new_shape, transpose)); + + const Shape result_shape = MakePartitionedShape(base_shape_, target); + if (result_shape != result->shape()) { + result = state_.b->AddInstruction(HloInstruction::CreateSlice( + result_shape, result, std::vector(result_shape.rank(), 0), + result_shape.dimensions(), std::vector(result_shape.rank(), 1))); + } + result->set_sharding(target); + return PartitionedHlo(result, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute( + const HloSharding& target) const { + CHECK(CanReshardWithCollectivePermute(sharding(), target)); + std::vector> src_dst_pairs; + sharding().tile_assignment().Each( + [&](absl::Span indices, int64 src_device) { + int64 dst_device = target.tile_assignment()(indices); + if (dst_device != src_device) { + src_dst_pairs.emplace_back(src_device, dst_device); + } + }); + auto cp = + state_.collective_ops_creator.create_cross_partition_collective_permute( + state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++); + cp->set_sharding(target); + return PartitionedHlo(cp, base_shape_, state_); +} + +SpmdPartitioningVisitor::SpmdPartitioningVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options, + SpmdPartitioner* partitioner) + : changed_(false), + module_(computation->parent()), + num_partitions_(num_partitions), + num_replicas_(num_replicas), + collective_ops_creator_(collective_ops_creator), + next_channel_id_(next_channel_id), + b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)), + partition_id_(collective_ops_creator_.create_partition_id(&b_)), + logger_(logger), + options_(std::move(options)), + partitioner_(partitioner) {} + +Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) { + if (hlo->HasSideEffect()) { + return Unimplemented("Side-effect ops cannot be replicated: %s", + hlo->ToString()); + } + + if (hlo->IsElementwise() && hlo->operand_count() > 0) { + return HandleElementwise(hlo); + } + + if (!hlo->sharding().IsTileMaximal()) { + VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):" + << hlo->ToString(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + VLOG(1) << " operand " << i + << " sharding:" << hlo->operand(i)->sharding().ToString(); + } + } + + // If the instruction cannot be partitioned, replicate the instruction unless + // the instruction has side-effect. + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(HloSharding::Replicate()).hlo()); + } + auto clone = + b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + clone->set_sharding(HloSharding::Replicate()); + clone->set_metadata(hlo->metadata()); + SetPartitionedHlo(hlo, + PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding())); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) { + visiting_hlo_ = hlo; + b_.set_visiting_hlo(hlo); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) { + logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(), + b_.derived_instructions(hlo)); + visiting_hlo_ = nullptr; + b_.set_visiting_hlo(nullptr); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) { + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + const int64 dimension = hlo->concatenate_dimension(); + if (sharding.tile_assignment().dim(dimension) == 1) { + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(sharding).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(shard_shape, new_operands)); + }); + return Status::OK(); + } + + // If the concatenate dimension is along one of the partitioned dimensions, + // allocate the full output shape, each partition updates its owned region, + // all-reduce across partitions, and then slice its output region. + + // We currently don't support subgroup all-reduce along partitions, so more + // than 1 partitioned dimensions is not supported. + if (sharding.tile_assignment().dim(dimension) != num_partitions_) { + return DefaultAction(hlo); + } + + // temp_output_shape is the output shape where the concatenate dimension + // is changed to the full (and padded to shard count) dimension size. + auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding); + temp_output_shape.set_dimensions( + dimension, temp_output_shape.dimensions(dimension) * + sharding.tile_assignment().dim(dimension)); + auto temp_output = CreateZero(temp_output_shape, &b_); + + // Offset of each operand along the concatenate dimension. + int64 offset = 0; + for (HloInstruction* operand : hlo->operands()) { + auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo(); + std::vector start_indices( + hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(S32)))); + start_indices[dimension] = + MultiplyAddDivideOffsetCalculation( + spmd_operand->shape().dimensions(dimension), offset, 1) + .Calculate(MakeTiledPartitionOrdinals(sharding, partition_id_, + &b_)[dimension], + &b_); + temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + temp_output_shape, temp_output, spmd_operand, start_indices)); + offset += operand->shape().dimensions(dimension); + } + auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + SetPartitionedHlo(hlo, [&] { + auto start_indices = + MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_); + start_indices[dimension] = MultiplyAddDivideOffsetCalculation( + shard_shape.dimensions(dimension), 0, 1) + .Calculate(start_indices[dimension], &b_); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, all_reduce, start_indices, shard_shape.dimensions())); + }); + + return Status::OK(); +} + +// If partitioning in the operand only happens in dimensions in passthrough +// dimensions (offset dimensions in the gather output (or scatter update) that +// have the same size as the operand), returns the corresponding output (or +// update) sharding by passing through the input sharding. +absl::optional PassthroughOperandToGatherOutputOrScatterUpdate( + const PartitionedHlo& operand, const Shape& update_or_gather_shape, + absl::Span collapsed_or_inserted_dims, + absl::Span index_map, + absl::Span offset_or_window_dims, + absl::Span slice_size) { + if (operand.sharding().IsTileMaximal()) { + return operand.sharding(); + } + std::vector passthrough_tile(update_or_gather_shape.rank(), 1); + int64 collapsed = 0; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + int64 dim_partitions = operand.sharding().tile_assignment().dim(i); + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(index_map, i)) { + if (dim_partitions > 1) { + return absl::nullopt; + } + collapsed++; + continue; + } + if (slice_size[i] != operand.base_shape().dimensions(i) && + dim_partitions > 1) { + return absl::nullopt; + } + int64 offset_dim = offset_or_window_dims[i - collapsed]; + if (i - collapsed > 0 && + offset_dim < offset_or_window_dims[i - collapsed - 1]) { + // Output offsets are transposed, we do not support this case. + return absl::nullopt; + } + passthrough_tile[offset_dim] = dim_partitions; + } + Array tile_assignment = operand.sharding().tile_assignment(); + tile_assignment.Reshape(passthrough_tile); + return HloSharding::Tile(tile_assignment); +} + +// Returns whether partitioning in the operand only happens in dimensions with +// gather/scatter slice size 1. +bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + const PartitionedHlo& operand, absl::Span index_map, + absl::Span slice_size, int64 num_partitions) { + if (operand.sharding().IsTileMaximal()) { + return false; + } + int64 trivial_slice_dims_partitions = 1; + for (int64 dim : index_map) { + if (slice_size[dim] == 1) { + trivial_slice_dims_partitions *= + operand.sharding().tile_assignment().dim(dim); + } + } + return trivial_slice_dims_partitions == num_partitions; +} + +// Returns the min and max for the indices (replicated) in a scatter/gather +// which has the operand partitioned on trivial slice dimensions (slice size 1). +std::pair +IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + const PartitionedHlo& operand, const PartitionedHlo& replicated_indices, + HloInstruction* partition_id, absl::Span index_map, + int64 index_vector_dim, SpmdBuilder* b) { + auto operand_offsets = MakePartitionOffsets( + operand.base_shape(), operand.sharding(), partition_id, b); + // Find the per-dimension index bounds. + std::vector min_indices; + std::vector max_indices; + for (int64 i = 0; i < index_map.size(); ++i) { + int64 dim = index_map[i]; + int64 partitions = operand.sharding().tile_assignment().dim(dim); + if (partitions == 1) { + min_indices.push_back(CreateR0WithType( + replicated_indices.base_shape().element_type(), 0, b)); + max_indices.push_back(CreateR0WithType( + replicated_indices.base_shape().element_type(), + operand.base_shape().dimensions(dim), b)); + continue; + } + auto offset = operand_offsets[dim]; + if (offset->shape().element_type() != + replicated_indices.base_shape().element_type()) { + offset = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(replicated_indices.base_shape().element_type(), + {}), + offset)); + } + min_indices.push_back(offset); + auto partition_size_minus_1 = + CreateR0WithType(replicated_indices.base_shape().element_type(), + operand.hlo()->shape().dimensions(dim) - 1, b); + max_indices.push_back(b->AddInstruction(HloInstruction::CreateBinary( + offset->shape(), HloOpcode::kAdd, offset, partition_size_minus_1))); + } + // Broadcast the index bounds to the same shape as the indices. + HloInstruction* broadcast_min; + HloInstruction* broadcast_max; + if (index_vector_dim < replicated_indices.base_shape().rank()) { + // The index vector is an R1, we need to reshape individual bounds to + // [1], and concat them if there are more than one. + for (int64 i = 0; i < min_indices.size(); ++i) { + min_indices[i] = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(min_indices[i]->shape().element_type(), {1}), + min_indices[i])); + max_indices[i] = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(max_indices[i]->shape().element_type(), {1}), + max_indices[i])); + } + int64 slice_dims = max_indices.size(); + if (slice_dims > 1) { + min_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(min_indices[0]->shape().element_type(), + {slice_dims}), + min_indices, 0)); + max_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate( + min_indices[0]->shape(), max_indices, 0)); + } + broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), min_indices[0], {index_vector_dim})); + broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), max_indices[0], {index_vector_dim})); + } else { + CHECK_EQ(max_indices.size(), 1); + broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), min_indices[0], {})); + broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), max_indices[0], {})); + } + return {broadcast_min, broadcast_max}; +} + +Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { + auto scatter = Cast(hlo); + auto dnums = scatter->scatter_dimension_numbers(); + auto operand = GetPartitionedHlo(scatter->operand(0)); + auto indices = GetPartitionedHlo(scatter->operand(1)); + auto updates = GetPartitionedHlo(scatter->operand(2)); + std::vector slice_size(operand.base_shape().rank(), 1); + int64 num_update_window_dims = 0; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + if (absl::c_linear_search(dnums.inserted_window_dims(), i)) { + continue; + } + slice_size[i] = updates.base_shape().dimensions( + dnums.update_window_dims(num_update_window_dims++)); + } + std::vector inserted_window_dims(dnums.inserted_window_dims().begin(), + dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + if (!operand.sharding().IsTileMaximal()) { + auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( + operand, updates.base_shape(), inserted_window_dims, + scatter_dims_to_operand_dims, update_window_dims, slice_size); + // Handle pass through cases if we can use compatible sharding for update. + if (maybe_passthrough.has_value()) { + indices = indices.Reshard(HloSharding::Replicate()); + updates = updates.Reshard(*maybe_passthrough); + auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter( + operand.hlo()->shape(), operand.hlo(), indices.hlo(), updates.hlo(), + scatter->to_apply(), dnums, scatter->indices_are_sorted(), + scatter->unique_indices())); + pscatter->set_sharding(*maybe_passthrough); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + operand, scatter_dims_to_operand_dims, slice_size, + num_partitions_) && + ShapeUtil::ByteSizeOf(updates.base_shape()) < + ShapeUtil::ByteSizeOf(scatter->shape())) { + // Operand is sharded on trivial slice dims (update slice size 1). We can + // adjust the indices on each partition by subtracting the offsets. Then + // we execute a scatter on full updated indices, and out-of-bound accesses + // will have no effect on the result as guaranteed by the scatter + // semantics. + indices = indices.Reshard(HloSharding::Replicate()); + updates = updates.Reshard(HloSharding::Replicate()); + HloInstruction* indices_min; + HloInstruction* indices_max_unused; + std::tie(indices_min, indices_max_unused) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operand, indices, partition_id_, scatter_dims_to_operand_dims, + dnums.index_vector_dim(), &b_); + auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + indices_min)); + auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter( + operand.hlo()->shape(), operand.hlo(), adjusted_indices, + updates.hlo(), scatter->to_apply(), dnums, + scatter->indices_are_sorted(), scatter->unique_indices())); + pscatter->set_sharding(operand.sharding()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding); + + // Create a window config to represent the slice. + Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(hlo->slice_strides(i)); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_padding_low(-hlo->slice_starts(i)); + dim->set_padding_high(hlo->slice_limits(i) - + hlo->operand(0)->shape().dimensions(i)); + dim->set_base_dilation(1); + } + + auto reshard_operand = operand.ReshardAsWindowedInput( + window, sharding, + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value()); + const Shape& operand_shape = reshard_operand->sharded_input->shape(); + + std::vector start_indices = hlo->slice_starts(); + std::vector limit_indices = hlo->slice_limits(); + std::vector strides = hlo->slice_strides(); + bool need_slice = false; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + auto dim = reshard_operand->shard_window.dimensions(i); + start_indices[i] = -dim.padding_low(); + limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high(); + if (start_indices[i] != 0 || strides[i] != 1 || + limit_indices[i] != operand_shape.dimensions(i)) { + need_slice = true; + } + } + + SetPartitionedHlo(hlo, [&] { + if (need_slice) { + auto shard_shape = MakePartitionedShape(hlo->shape(), sharding); + return b_.AddInstruction(HloInstruction::CreateSlice( + shard_shape, reshard_operand->sharded_input, start_indices, + limit_indices, strides)); + } + return reshard_operand->sharded_input; + }); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { + HloSharding sharding = hlo->sharding(); + if (hlo->shape().IsTuple()) { + // Check that all elements are sharded in the same way. + if (hlo->shape().tuple_shapes_size() == 0) { + return DefaultAction(hlo); + } + sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + for (int64 i = 1; i < hlo->operand_count(); ++i) { + if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) { + return DefaultAction(hlo); + } + } + } + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 dim : hlo->dimensions()) { + if (sharding.tile_assignment().dim(dim) > 1) { + return DefaultAction(hlo); + } + } + // Reshard operands to the same as the output. + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) { + if (hlo->custom_call_target() == "SPMDFullToShardShape") { + // This op switches from auto partitioning to manual partitioning. + auto input_partitioned = GetPartitionedHlo(hlo->operand(0)); + if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) { + input_partitioned = input_partitioned.PadWithValue( + CreateR0WithType(hlo->shape().element_type(), 0, &b_)); + } + auto input = input_partitioned.hlo(); + CHECK(hlo->sharding().IsReplicated()); + CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape())); + auto copy = b_.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); + SetPartitionedHlo(hlo, [&] { return copy; }); + return Status::OK(); + } + if (hlo->custom_call_target() == "SPMDShardToFullShape") { + // This op switches from manual partitioning to auto partitioning. + auto input = GetPartitionedHlo(hlo->operand(0)).hlo(); + CHECK(input->sharding().IsReplicated()); + auto copy = b_.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); + CHECK(ShapeUtil::Compatible( + copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding()))); + SetPartitionedHlo(hlo, [&] { return copy; }); + return Status::OK(); + } + if (hlo->custom_call_target() != "TopK") { + return DefaultAction(hlo); + } + + if (!hlo->operand(0)->has_sharding()) { + return DefaultAction(hlo); + } + + const HloSharding& sharding = hlo->operand(0)->sharding(); + if (sharding.IsTileMaximal() || sharding.IsReplicated()) { + return DefaultAction(hlo); + } + + const int64 sort_dim = 1; + const int64 shard_count = sharding.tile_assignment().dim(sort_dim); + + if (shard_count <= 1) { + return DefaultAction(hlo); + } + + const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim); + const int64 batch_size = hlo->shape().tuple_shapes(0).dimensions(0); + const int64 k = hlo->shape().tuple_shapes(0).dimensions(sort_dim); + const int64 per_partition_size = CeilOfRatio(input_size, shard_count); + + if (k >= per_partition_size) { + return DefaultAction(hlo); + } + + auto input = hlo->operand(0); + const auto element_type = input->shape().element_type(); + + // Pad input with minimal value. + auto min_value = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MinValue(element_type))); + // TODO(wangtao): add test to see if -NaN < -Inf in BF16. + if (element_type == F32) { + auto float_pad_value = std::numeric_limits::quiet_NaN(); + min_value = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(-float_pad_value))); + } + auto partitioned_input = GetPartitionedHlo(input).PadWithValue(min_value); + + // Each partition needs to do TopK separately, thus the base shape + // becomes [batch_size, k * shard_count]. + const Shape replicated_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(), + {batch_size, k * shard_count}), + ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})}); + auto custom_call_sharding = + sharding.GetTupleSharding(replicated_shape).ValueOrDie(); + auto shard_shape = + MakePartitionedShape(replicated_shape, custom_call_sharding); + auto topk = b_.AddInstruction( + hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()})); + topk->set_sharding(custom_call_sharding); + // Partition customcall. + PartitionedHlo partitioned_topk(topk, replicated_shape, + MakePartitioningState()); + topk = partitioned_topk.hlo(); + + // Get value from TopK. + HloInstruction* value_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(0), topk, 0)); + value_gte->set_sharding(sharding); + // Partition GetTupleElement of value. + PartitionedHlo value_partitioned_gte( + value_gte, partitioned_topk.base_shape().tuple_shapes(0), + MakePartitioningState()); + // Reshard value to be replicated. + auto replicated_value_gte = + value_partitioned_gte.Reshard(HloSharding::Replicate()).hlo(); + + // Get index from TopK. + HloInstruction* index_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(1), topk, 1)); + auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()), + partition_id_)); + // Add per partition offset to index, index returned from CustomCall always + // starts from 0. + auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast( + index_gte->shape(), + b_.AddInstruction(HloInstruction::CreateBinary( + partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32, + b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(per_partition_size))))), + {})); + index_gte = b_.AddInstruction(HloInstruction::CreateBinary( + index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset)); + index_gte->set_sharding(sharding); + // Parttion GetTupleElement of index. + PartitionedHlo index_partitioned_gte( + index_gte, partitioned_topk.base_shape().tuple_shapes(1), + MakePartitioningState()); + // Reshard index to be replicated. + auto replicated_index_gte = + index_partitioned_gte.Reshard(HloSharding::Replicate()).hlo(); + + // Creates replicated sort to do TopK, the input is value and index pairs + // from all the partitions. The reason to use Sort instead of CustomCall TopK + // is CustomCall only takes value as input. There will be an extra Gather + // to get the correct index if CustomCall is used here. + + // Create comparator for the sort. + XlaBuilder b("Sort.Compare"); + XlaComputation comparator = CreateScalarComparisonComputation( + "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt}, + &b); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(comparator.proto(), config)); + HloCloneContext context(module_); + auto compare_computation = + module_->DeepCloneComputation(new_module->entry_computation(), &context); + auto sort = b_.AddInstruction(HloInstruction::CreateSort( + replicated_shape, sort_dim, {replicated_value_gte, replicated_index_gte}, + compare_computation, true)); + sort->set_sharding( + HloSharding::Replicate().GetTupleSharding(sort->shape()).ValueOrDie()); + PartitionedHlo replicated_sort(sort, replicated_shape, + MakePartitioningState()); + + // Slice value and index from top-k for output. + HloInstruction* sort_value_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(), + 0)); + HloInstruction* sort_index_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(), + 1)); + const Shape& hlo_shape = sort_value_gte->shape(); + auto hlo_dims = hlo_shape.dimensions(); + std::vector start_indices(hlo_shape.dimensions_size(), 0); + std::vector limit_indices(hlo_dims.begin(), hlo_dims.end()); + std::vector strides(hlo_shape.dimensions_size(), sort_dim); + limit_indices[sort_dim] = k; + auto output_shape = hlo_shape; + output_shape.set_dimensions(sort_dim, k); + // Slice value from final sort. + HloInstruction* slice_sort_value = + b_.AddInstruction(HloInstruction::CreateSlice( + output_shape, sort_value_gte, start_indices, limit_indices, strides)); + // Slice index from final sort. + auto index_output_shape = sort_index_gte->shape(); + index_output_shape.set_dimensions(sort_dim, k); + HloInstruction* slice_index_value = b_.AddInstruction( + HloInstruction::CreateSlice(index_output_shape, sort_index_gte, + start_indices, limit_indices, strides)); + auto create_tuple = b_.AddInstruction( + HloInstruction::CreateTuple({slice_sort_value, slice_index_value})); + create_tuple->set_sharding(HloSharding::Replicate()); + + SetPartitionedHlo(hlo, PartitionedHlo(create_tuple, create_tuple->shape(), + MakePartitioningState()) + .Reshard(hlo->sharding())); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + std::vector inverse_dimensions(hlo->shape().rank()); + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + inverse_dimensions[hlo->dimensions(i)] = i; + } + auto desired_operand_sharding = + hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions); + + auto operand = GetPartitionedHlo(hlo->operand(0)) + .Reshard(desired_operand_sharding) + .hlo(); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand})); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto operand = GetPartitionedHlo(hlo->operand(0)); + // The output shape is the source and the operand shape is the target to get + // the aligned sharding for the operand. + auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding( + hlo->shape(), hlo->operand(0)->shape(), hlo->sharding()); + if (desired_operand_sharding.has_value()) { + auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo(); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo})); + }); + return Status::OK(); + } + + // Try use halo exchange for certain split-dim/merge-dims cases. + // ReshapeSharding failed in these cases probably due to uneven partitioning, + // where halo exchange could help. Specifically we check the following + // conditions to detect supported cases: + // 1) Both input and output are partitioned on one dimension. + // 2) The combined size of dimensions before the partitioned dimension are the + // same on input and output. This means we don't need to consider the major + // dimensions. + // 3) Let A = the input size on the partitioned dimension, and + // B = the output size on the partitioned dimension; then + // either A % B == 0 (split dim) or B % A == 0 (merge dims). + auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding()); + auto maybe_output_sharded_dim = UniqueTiledDim(sharding); + if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) { + return DefaultAction(hlo); + } + int64 input_sharded_dim = *maybe_input_sharded_dim; + int64 output_sharded_dim = *maybe_output_sharded_dim; + // Check that the major dims before the sharded dim have the same total size + // for input and output. + int64 input_major_dims_size = 1; + for (int64 i = 0; i < input_sharded_dim; ++i) { + input_major_dims_size *= operand.base_shape().dimensions(i); + } + int64 output_major_dims_size = 1; + for (int64 i = 0; i < output_sharded_dim; ++i) { + output_major_dims_size *= hlo->shape().dimensions(i); + } + if (input_major_dims_size != output_major_dims_size) { + return DefaultAction(hlo); + } + // Fix potential device ordering mismatch in tile assignment. + Array new_input_tile_assignment = sharding.tile_assignment(); + new_input_tile_assignment.Reshape( + operand.sharding().tile_assignment().dimensions()); + operand = operand.Reshard(HloSharding::Tile(new_input_tile_assignment)); + + int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim); + int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim); + auto input_shard_shape = + MakePartitionedShape(operand.base_shape(), operand.sharding()); + auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding); + if (input_dim_size % output_dim_size == 0) { + // Split dim. + int64 split_factor = input_dim_size / output_dim_size; + int64 output_shard_size = output_shard_shape.dimensions(output_sharded_dim); + // Use halo exchange to fix misaligned data. + Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_base_dilation(1); + dim->set_padding_low(0); + if (i == input_sharded_dim) { + dim->set_padding_high(output_shard_size * split_factor * + num_partitions_ - + input_dim_size); + } else { + dim->set_padding_high(0); + } + } + + auto reshard_operand = operand.ReshardAsWindowedInput( + window, operand.sharding(), + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value()); + CHECK_EQ( + reshard_operand->sharded_input->shape().dimensions(input_sharded_dim), + output_shard_size * split_factor); + SetPartitionedHlo(hlo, [&] { + // Do a local reshape. + return b_.AddInstruction(HloInstruction::CreateReshape( + output_shard_shape, reshard_operand->sharded_input)); + }); + return Status::OK(); + } else if (output_dim_size % input_dim_size == 0) { + // Merge dims. + int64 merge_factor = output_dim_size / input_dim_size; + // First reshape locally. (The sharded dimension could include padded data.) + auto tmp_shard_shape = output_shard_shape; + tmp_shard_shape.set_dimensions( + output_sharded_dim, + input_shard_shape.dimensions(input_sharded_dim) * merge_factor); + auto tmp_reshape = b_.AddInstruction( + HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo())); + tmp_reshape->set_metadata(hlo->metadata()); + tmp_reshape->set_sharding(hlo->sharding()); + auto tmp_full_shape = tmp_shard_shape; + tmp_full_shape.set_dimensions( + output_sharded_dim, + tmp_shard_shape.dimensions(output_sharded_dim) * num_partitions_); + auto tmp_output = + PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState()); + + // Use halo exchange to fix misaligned data. + Window window; + for (int64 i = 0; i < tmp_shard_shape.rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_base_dilation(1); + dim->set_padding_low(0); + if (i == output_sharded_dim) { + dim->set_padding_high(output_dim_size - + tmp_shard_shape.dimensions(output_sharded_dim) * + num_partitions_); + } else { + dim->set_padding_high(0); + } + } + + auto reshard_output = tmp_output.ReshardAsWindowedInput( + window, sharding, + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_output.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value()); + CHECK_EQ( + reshard_output->sharded_input->shape().dimensions(input_sharded_dim), + output_shard_shape.dimensions(output_sharded_dim)); + SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + SetPartitionedHlo(hlo, [&] { + int64 dimension = Cast(hlo)->iota_dimension(); + auto iota = b_.AddInstruction(HloInstruction::CreateIota( + MakePartitionedShape(hlo->shape(), sharding), dimension)); + + if (sharding.tile_assignment().dim(dimension) > 1) { + auto partition_ordinals = + MakeTiledPartitionOrdinals(sharding, partition_id_, &b_); + auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(iota->shape().dimensions(dimension)))); + auto offset = b_.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, + partition_ordinals[dimension], multiplier)); + if (iota->shape().element_type() != S32) { + offset = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset)); + } + auto broadcast = b_.AddInstruction( + HloInstruction::CreateBroadcast(iota->shape(), offset, {})); + return b_.AddInstruction(HloInstruction::CreateBinary( + iota->shape(), HloOpcode::kAdd, iota, broadcast)); + } + + return iota; + }); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) { + TF_RET_CHECK(hlo->sharding().HasUniqueDevice()); + int64 device = hlo->sharding().GetUniqueDevice(); + const HloSharding sharding = HloSharding::AssignDevice(device); + + std::vector operands; + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); + operand_shapes.push_back(operand->shape()); + } + auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands)); + auto operand_shape = ShapeUtil::MakeTupleShape(operand_shapes); + + auto on_device = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(device))); + auto pred = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), partition_id_, on_device, + ComparisonDirection::kEq)); + + SpmdBuilder true_b("true_computation", visiting_hlo_); + HloComputation* true_computation; + { + auto param = true_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, operand_shape, "true_branch_param")); + std::vector new_operands; + for (int64 i = 0; i < operands.size(); ++i) { + new_operands.push_back(true_b.AddInstruction( + HloInstruction::CreateGetTupleElement(operand_shapes[i], param, i))); + } + auto root = true_b.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + true_computation = module_->AddEmbeddedComputation(true_b.Build(root)); + } + + SpmdBuilder false_b("false_computation", visiting_hlo_); + HloComputation* false_computation; + { + false_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, operand_shape, "false_branch_param")); + auto root = CreateZero(hlo->shape(), &false_b); + false_computation = module_->AddEmbeddedComputation(false_b.Build(root)); + } + + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateConditional( + hlo->shape(), pred, operand, true_computation, operand, + false_computation)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) { + if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) { + return HandleElementwise(hlo); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto& operand = GetPartitionedHlo(hlo->operand(0)); + + // Tiled output. + std::vector wanted_input_tile_size(operand.base_shape().rank()); + std::vector sharded_new_dims; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + wanted_input_tile_size[i] = + hlo->sharding().tile_assignment().dim(hlo->dimensions(i)); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_linear_search(hlo->dimensions(), i) && + hlo->sharding().tile_assignment().dim(i) > 1) { + sharded_new_dims.push_back(i); + } + } + if (sharded_new_dims.empty()) { + // The new dimensions are replicated, so that we can do the adjustment on + // the input. + Array wanted_input_tile_assignment(wanted_input_tile_size); + wanted_input_tile_assignment.Each( + [&](absl::Span indices, int64* val) { + std::vector indices_in_broadcast(hlo->shape().rank(), 0); + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + indices_in_broadcast[hlo->dimensions(i)] = indices[i]; + } + *val = hlo->sharding().tile_assignment()(indices_in_broadcast); + }); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + {operand.Reshard(HloSharding::Tile(wanted_input_tile_assignment)) + .hlo()})); + }); + } else { + auto input = operand.Reshard(HloSharding::Replicate()).hlo(); + // We pad and shard the input first, then broadcast to the final shard + // shape. + auto output_offsets = + MakePartitionOffsets(hlo->shape(), hlo->sharding(), partition_id_, &b_); + std::vector input_offsets(operand.base_shape().rank()); + auto output_shard_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto input_shard_shape = input->shape(); + auto padded_input_shape = input->shape(); + for (int64 i = 0; i < input_offsets.size(); ++i) { + input_offsets[i] = output_offsets[hlo->dimensions(i)]; + input_shard_shape.set_dimensions( + i, output_shard_shape.dimensions(hlo->dimensions(i))); + padded_input_shape.set_dimensions( + i, hlo->sharding().tile_assignment().dim(hlo->dimensions(i)) * + input_shard_shape.dimensions(i)); + } + auto padded_input = PadToShape(input, padded_input_shape, &b_); + auto input_shard = + ShapeUtil::Compatible(input_shard_shape, padded_input->shape()) + ? padded_input + : b_.AddInstruction(HloInstruction::CreateDynamicSlice( + input_shard_shape, padded_input, input_offsets, + input_shard_shape.dimensions())); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(output_shard_shape, {input_shard})); + }); + } + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) { + const Literal& literal = hlo->literal(); + if (literal.shape().IsTuple() || + (!hlo->sharding().IsTileMaximal() && + (!EvenlyPartitions(hlo->shape(), hlo->sharding()) || + !literal.IsAllFirst()))) { + return DefaultAction(hlo); + } + + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + std::vector start_indices(hlo->shape().rank(), 0); + auto constant = b_.AddInstruction(HloInstruction::CreateConstant( + literal.Slice(start_indices, shard_shape.dimensions()))); + *constant->mutable_shape() = shard_shape; + return constant; + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) != 1 && + (hlo->dynamic_slice_sizes()[i] != hlo->shape().dimensions(i) || + !hlo->operand(i + 1)->IsConstant() || + !hlo->operand(i + 1)->literal().IsZero({}))) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); + } + } + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + for (int64 i = 0; i < new_indices.size(); ++i) { + // Replicate the indices. + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)) + .Reshard(HloSharding::Replicate()) + .hlo(); + } + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + partitioned_shape, new_input, new_indices, + partitioned_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) != 1 && + (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i) || + !hlo->operand(i + 2)->IsConstant() || + !hlo->operand(i + 2)->literal().IsZero({}))) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); + } + } + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + auto new_update = + GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo(); + for (int64 i = 0; i < new_indices.size(); ++i) { + // Replicate the indices. + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)) + .Reshard(HloSharding::Replicate()) + .hlo(); + } + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + partitioned_shape, new_input, new_update, new_indices)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { + auto gather = Cast(hlo); + const auto& dnums = gather->gather_dimension_numbers(); + auto operand = GetPartitionedHlo(gather->operand(0)); + auto indices = GetPartitionedHlo(gather->operand(1)); + std::vector collapsed_slice_dims(dnums.collapsed_slice_dims().begin(), + dnums.collapsed_slice_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + if (!operand.sharding().IsTileMaximal()) { + auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( + operand, gather->shape(), collapsed_slice_dims, start_index_map, + offset_dims, gather->gather_slice_sizes()); + if (maybe_passthrough.has_value()) { + indices = indices.Reshard(HloSharding::Replicate()); + auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough); + std::vector pslice_sizes(gather->gather_slice_sizes().begin(), + gather->gather_slice_sizes().end()); + for (int64 i = 0; i < pslice_sizes.size(); ++i) { + if (operand.sharding().tile_assignment().dim(i) > 1) { + pslice_sizes[i] = operand.hlo()->shape().dimensions(i); + } + } + auto pgather = b_.AddInstruction(HloInstruction::CreateGather( + pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes, + gather->indices_are_sorted())); + pgather->set_sharding(*maybe_passthrough); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + operand, start_index_map, gather->gather_slice_sizes(), + num_partitions_) && + ShapeUtil::ByteSizeOf(gather->shape()) < + ShapeUtil::ByteSizeOf(gather->operand(0)->shape())) { + indices = indices.Reshard(HloSharding::Replicate()); + // Now the operand is partitioned in trivial slice dimensions, and the + // indices are replicated. We execute a gather on partitioned operand, + // with full number of indices, where out-of-bounds indices are clamped, + // and masked out with 0 in the result; then we use all-reduce to combine + // results. Although gather will not get faster, we avoided the need to + // replicate the operand. + HloInstruction* indices_min; + HloInstruction* indices_max; + std::tie(indices_min, indices_max) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operand, indices, partition_id_, start_index_map, + dnums.index_vector_dim(), &b_); + // Clamp the indices. + auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateTernary( + indices.base_shape(), HloOpcode::kClamp, indices_min, indices.hlo(), + indices_max)); + // Adjust the indices by subtracting the offset. + adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary( + indices.base_shape(), HloOpcode::kSubtract, adjusted_indices, + indices_min)); + // Gather on adjusted indices. + auto pgather = b_.AddInstruction(HloInstruction::CreateGather( + gather->shape(), operand.hlo(), adjusted_indices, dnums, + gather->gather_slice_sizes(), gather->indices_are_sorted())); + // Mask out invalid results. + auto filter = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(indices.base_shape(), PRED), + indices.hlo(), indices_min, ComparisonDirection::kLt)); + filter = b_.AddInstruction(HloInstruction::CreateBinary( + filter->shape(), HloOpcode::kOr, filter, + b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(indices.base_shape(), PRED), + indices.hlo(), indices_max, ComparisonDirection::kGt)))); + if (dnums.index_vector_dim() < indices.base_shape().rank()) { + std::vector reduced_filter_dims; + for (int64 i = 0; i < filter->shape().rank(); ++i) { + if (i != dnums.index_vector_dim()) { + reduced_filter_dims.push_back(filter->shape().dimensions(i)); + } + } + filter = b_.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter, + CreateR0WithType(PRED, false, &b_), {dnums.index_vector_dim()}, + MakeBinaryAdd(PRED, module_))); + } + std::vector batch_dims; + for (int64 i = 0; i < pgather->shape().rank(); ++i) { + if (!absl::c_linear_search(dnums.offset_dims(), i)) { + batch_dims.push_back(i); + } + } + auto broadcast_filter = b_.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter, + batch_dims)); + auto filtered = b_.AddInstruction(HloInstruction::CreateTernary( + pgather->shape(), HloOpcode::kSelect, broadcast_filter, + CreateZero(pgather->shape(), &b_), pgather)); + // Combine from different partitions. + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, filtered, + MakeBinaryAdd(filtered->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) { + const auto& tuple = GetPartitionedHlo(hlo->operand(0)); + auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()), + tuple.hlo(), hlo->tuple_index())); + SetPartitionedHlo(hlo, [&]() { + const auto source_sharding = tuple.sharding().GetSubSharding( + tuple.base_shape(), {hlo->tuple_index()}); + gte->set_sharding(source_sharding); + PartitionedHlo source_partitioned_gte(gte, hlo->shape(), + MakePartitioningState()); + return source_partitioned_gte.Reshard(hlo->sharding()).hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) { + const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0); + auto token = GetPartitionedHlo(hlo->operand(0)).hlo(); + if (ShapeUtil::GetLeafCount(shape) == 0) { + // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it + // requires one element for an empty tuple, but leaf-count number of + // elements for non-empty tuple. So if it has a nested empty tuple, we + // cannot invoke GetSubSharding() since it expects a sharding for the empty + // tuple. This is a workaround for that case. + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction( + HloInstruction::CreateInfeed(shape, token, hlo->infeed_config())); + }); + return Status::OK(); + } + auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + auto shard_shape = MakePartitionedShape(shape, sharding); + if (EvenlyPartitions(shape, sharding)) { + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateInfeed( + shard_shape, token, hlo->infeed_config())); + }); + return Status::OK(); + } + + if (hlo->sharding().HasUniqueDevice()) { + return HandleSingleDevice(hlo); + } + + // Create a branch for each unique partitioned shape. + std::vector per_branch_partitioned_shapes; + std::vector conditional_branch_indices(num_partitions_); + for (int64 i = 0; i < num_partitions_; ++i) { + auto partitioned_shape = + MakeNonPaddedShapeForGivenPartition(shape, sharding, i); + int64 matching_existing_index = 0; + for (; matching_existing_index < per_branch_partitioned_shapes.size(); + ++matching_existing_index) { + if (ShapeUtil::Compatible( + partitioned_shape, + per_branch_partitioned_shapes[matching_existing_index])) { + break; + } + } + if (matching_existing_index < per_branch_partitioned_shapes.size()) { + conditional_branch_indices[i] = matching_existing_index; + } else { + conditional_branch_indices[i] = per_branch_partitioned_shapes.size(); + per_branch_partitioned_shapes.push_back(std::move(partitioned_shape)); + } + } + + HloInstruction* branch_index; + if (per_branch_partitioned_shapes.size() == num_partitions_) { + // Use partition ID as the branch index if each partition has its own + // branch. + branch_index = partition_id_; + // PartitionId's output is U32 but conditional requires S32. + if (branch_index->shape().element_type() != S32) { + branch_index = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(branch_index->shape(), S32), + branch_index)); + } + } else { + // Otherwise, use a constant table to look up the branch index. + auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(conditional_branch_indices))); + branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_}, + {1})); + branch_index = b_.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {}), branch_index)); + } + + std::vector branches(per_branch_partitioned_shapes.size()); + for (int64 i = 0; i < branches.size(); ++i) { + SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_); + auto param = branch_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, token->shape(), "infeed_token_param")); + auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed( + per_branch_partitioned_shapes[i], param, hlo->infeed_config())); + branches[i] = module_->AddEmbeddedComputation(branch_b.Build(infeed)); + if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) { + TF_ASSIGN_OR_RETURN( + auto padded, + branches[i]->DeepCopyInstructionWithCustomCopier( + infeed, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + // Index {1} corresponds to the token. + if (leaf_index.empty() || leaf_index[0] != 0) { + return leaf; + } + ShapeIndexView subindex(leaf_index, 1); + if (ShapeUtil::Compatible( + ShapeUtil::GetSubshape(per_branch_partitioned_shapes[i], + subindex), + ShapeUtil::GetSubshape(shard_shape, subindex))) { + return leaf; + } + return PadToShape(leaf, + ShapeUtil::GetSubshape(shard_shape, subindex), + nullptr, comp); + })); + branches[i]->set_root_instruction(padded, + /*accept_different_shape=*/true); + } + } + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateConditional( + ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index, + branches, std::vector(branches.size(), token))); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + const auto& pd = hlo->padding_config().dimensions(i); + // Right now we only support non-padded dimensions to be partitioned. + if (hlo->sharding().tile_assignment().dim(i) > 1 && + (pd.edge_padding_high() != 0 || pd.edge_padding_low() != 0 || + pd.interior_padding() != 0)) { + return DefaultAction(hlo); + } + } + auto resharded_lhs = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)) + .Reshard(HloSharding::Replicate()) + .hlo(); + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(hlo->CloneWithNewOperands( + shard_shape, {resharded_lhs, replicated_rhs})); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) { + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto new_param = b_.AddInstruction(HloInstruction::CreateParameter( + hlo->parameter_number(), shard_shape, "param")); + if (hlo->parameter_replicated_at_leaf_buffers()) { + new_param->set_parameter_replicated_at_leaf_buffers( + *hlo->parameter_replicated_at_leaf_buffers()); + } + return new_param; + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { + int64 input_count = 1; + auto per_input_sharding = hlo->sharding(); + if (hlo->shape().IsTuple()) { + input_count = hlo->shape().tuple_shapes_size(); + CHECK_GT(input_count, 0); + per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + } + + std::vector inputs; + std::vector inits; + for (int64 operand_id = 0; operand_id < input_count; ++operand_id) { + inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count)) + .Reshard(HloSharding::Replicate()) + .hlo()); + inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id))); + if (operand_id > 0) { + // Make sure all operands are sharded in the same way. + inputs.back() = inputs.back().Reshard(inputs[0].sharding()); + } + if (!inputs[0].sharding().IsTileMaximal()) { + inputs.back() = inputs.back().PadWithValue(inits[operand_id]); + } + } + bool reduce_sharded_dimension = false; + if (!inputs[0].sharding().IsTileMaximal()) { + reduce_sharded_dimension = absl::c_any_of(hlo->dimensions(), [&](int64 i) { + return inputs[0].sharding().tile_assignment().dim(i) > 1; + }); + + // reduce_sharded_dimension is not supported for tuple-shaped reduces. + if (reduce_sharded_dimension && input_count > 1) { + return DefaultAction(hlo); + } + + // Currently we only support reducing all or none of the sharded + // dimensions. + if (reduce_sharded_dimension) { + for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) { + if (inputs[0].sharding().tile_assignment().dim(i) > 1 && + absl::c_count(hlo->dimensions(), i) == 0) { + return DefaultAction(hlo); + } + } + } + } + + std::vector new_operand_shapes(input_count * 2); + for (int64 i = 0; i < input_count; ++i) { + new_operand_shapes[i] = inputs[i].hlo()->mutable_shape(); + new_operand_shapes[i + input_count] = inits[i]->mutable_shape(); + } + // Create the shard shape of the reduce result. + TF_ASSIGN_OR_RETURN( + auto reduce_shape, + ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(), + hlo->to_apply()->ComputeProgramShape())); + *reduce_shape.mutable_layout() = hlo->shape().layout(); + + std::vector input_hlos(input_count); + for (int64 i = 0; i < input_count; ++i) { + input_hlos[i] = inputs[i].hlo(); + } + auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply())); + local_reduce->set_metadata(hlo->metadata()); + + SetPartitionedHlo(hlo, [&]() { + HloInstruction* reduce; + if (reduce_sharded_dimension) { + CHECK(local_reduce->shape().IsArray()); + reduce = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, local_reduce, hlo->to_apply(), NewChannel()); + reduce->set_sharding(HloSharding::Replicate()); + } else { + reduce = local_reduce; + if (inputs[0].sharding().IsTileMaximal()) { + reduce->set_sharding(inputs[0].sharding()); + } else { + // Remove tile assignment dimensions that are reduced. + std::vector tile_dimensions; + for (int64 i = 0; i < input_hlos[0]->shape().rank(); ++i) { + if (absl::c_count(hlo->dimensions(), i) == 0) { + tile_dimensions.push_back( + inputs[0].sharding().tile_assignment().dim(i)); + } + } + Array new_tile = inputs[0].sharding().tile_assignment(); + new_tile.Reshape(tile_dimensions); + auto sharding = HloSharding::Tile(new_tile); + if (input_count > 1) { + std::vector tuple(input_count, sharding); + sharding = HloSharding::Tuple(hlo->shape(), tuple); + } + reduce->set_sharding(sharding); + } + } + + return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) { + auto reverse = Cast(hlo); + if (reverse->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + if (absl::c_all_of(reverse->dimensions(), [&](int64 d) { + return reverse->sharding().tile_assignment().dim(d) == 1; + })) { + auto operand = + GetPartitionedHlo(reverse->operand(0)).Reshard(reverse->sharding()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(operand.hlo()->shape(), {operand.hlo()})); + }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + + // Shardings for the body parameter, body root, and cond parameter must be + // the same, and the condition root must be replicated so that all partitions + // follow the same control flow. + hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding); + hlo->while_body()->parameter_instruction(0)->set_sharding(sharding); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(hlo->while_condition(), + HloSharding::Replicate(), + next_channel_id_, logger_) + .status()); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(hlo->while_body(), sharding, + next_channel_id_, logger_) + .status()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateWhile( + MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(), + hlo->while_body(), + GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) { + std::vector branch_args; + for (int64 i = 0; i < hlo->branch_count(); ++i) { + HloComputation* computation = hlo->branch_computation(i); + + // Shardings of the branch computation parameter and its argument must be + // the same. + computation->parameter_instruction(0)->set_sharding( + hlo->operand(i + 1)->sharding()); + branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo()); + } + + // The root of the branch computations must follow the sharding of the + // conditional instruction. + for (int64 i = 0; i < hlo->branch_count(); ++i) { + HloComputation* computation = hlo->branch_computation(i); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(computation, hlo->sharding(), + next_channel_id_, logger_) + .status()); + } + + // We replicate the predicate of the conditional (the first operand) so that + // all partitions follow the same control flow. + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateConditional( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + GetPartitionedHlo(hlo->operand(0)) + .Reshard(HloSharding::Replicate()) + .hlo(), + hlo->called_computations(), branch_args)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) { + TF_RET_CHECK(hlo->sharding().HasUniqueDevice()); + return HandleSingleDevice(hlo); +} + +Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { + if (hlo->sharding().HasUniqueDevice()) { + return HandleSingleDevice(hlo); + } + + if (hlo->sharding().IsReplicated()) { + SetPartitionedHlo(hlo, [&] { + // Run on a single device (0) and distribute the data to all other cores. + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::AssignDevice(0)) + .hlo()); + } + auto clone = b_.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + clone->set_sharding(HloSharding::AssignDevice(0)); + return PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) + .Reshard(HloSharding::Replicate()) + .hlo(); + }); + return Status::OK(); + } + + TF_RET_CHECK(!hlo->sharding().IsTileMaximal()); + SetPartitionedHlo(hlo, [&] { + // Replicate the operands and run partitioned Rng on all devices. + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::Replicate()) + .hlo()); + } + return b_.AddInstruction(HloInstruction::CreateRng( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + hlo->random_distribution(), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) { + auto& operand = GetPartitionedHlo(hlo->operand(0)); + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + + // Replicate init + auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(1)) + .Reshard(HloSharding::Replicate()); + auto resharded_operand_and_window = operand.ReshardAsWindowedInput( + hlo->window(), hlo->sharding(), replicated_init.hlo()); + if (!resharded_operand_and_window.has_value()) { + return DefaultAction(hlo); + } + + TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape, + ShapeInference::InferReduceWindowShape( + resharded_operand_and_window->sharded_input->shape(), + replicated_init.hlo()->shape(), + resharded_operand_and_window->shard_window, + hlo->to_apply()->ComputeProgramShape())); + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + *sharded_rw_shape.mutable_layout() = shard_shape.layout(); + SetPartitionedHlo(hlo, [&]() { + auto sharded_rw = b_.AddInstruction(HloInstruction::CreateReduceWindow( + sharded_rw_shape, resharded_operand_and_window->sharded_input, + replicated_init.hlo(), resharded_operand_and_window->shard_window, + hlo->to_apply())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape())); + return sharded_rw; + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_rw, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + auto operand = GetPartitionedHlo(hlo->operand(0)); + auto source = GetPartitionedHlo(hlo->mutable_operand(1)); + if (hlo->sharding() != operand.sharding()) { + operand = operand.Reshard(hlo->sharding()); + } + if (hlo->sharding() != source.sharding()) { + source = source.Reshard(hlo->sharding()); + } + + // For F32 and BF16 types, we can use NaN padding to workaround the issue with + // low/high padding, since comparison will return false with NaN input. + if (hlo->shape().element_type() != F32 && + hlo->shape().element_type() != BF16) { + return DefaultAction(hlo); + } + + auto select = hlo->called_computations()[0]; + auto select_root = select->root_instruction(); + if (select_root->opcode() != HloOpcode::kCompare || + select_root->operand(0)->opcode() != HloOpcode::kParameter || + select_root->operand(1)->opcode() != HloOpcode::kParameter || + select_root->operand(0)->parameter_number() + + select_root->operand(1)->parameter_number() != + 1) { + return DefaultAction(hlo); + } + + float float_pad_value; + if (select_root->comparison_direction() == ComparisonDirection::kGe || + select_root->comparison_direction() == ComparisonDirection::kGt) { + if (select_root->operand(0)->parameter_number() == 0) { + float_pad_value = -std::numeric_limits::infinity(); + } else { + float_pad_value = std::numeric_limits::infinity(); + } + } else if (select_root->comparison_direction() == ComparisonDirection::kLe || + select_root->comparison_direction() == ComparisonDirection::kLt) { + if (select_root->operand(0)->parameter_number() == 0) { + float_pad_value = std::numeric_limits::infinity(); + } else { + float_pad_value = -std::numeric_limits::infinity(); + } + } else { + return DefaultAction(hlo); + } + + auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant( + hlo->shape().element_type() == BF16 + ? LiteralUtil::CreateR0( + static_cast(float_pad_value)) + : LiteralUtil::CreateR0(float_pad_value))); + + // Replicate init + auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2)) + .Reshard(HloSharding::Replicate()); + + auto partition_ordinals = + MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_); + + // The first window for each dimension that overlaps with the shard area. + std::vector first_window( + hlo->shape().rank()); + // The first window for each dimension that goes beyond with the shard area. + std::vector limit_window( + hlo->shape().rank()); + std::vector data_left_halo_sizes(hlo->shape().rank()); + std::vector data_right_halo_sizes(hlo->shape().rank()); + std::vector source_left_halo_sizes(hlo->shape().rank()); + std::vector source_right_halo_sizes(hlo->shape().rank()); + auto unpadded_data_shard_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto unpadded_source_shard_shape = + MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding()); + auto source_shard_hlo = source.hlo(); + auto data_shard_hlo = operand.hlo(); + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + int64 shard_count = hlo->sharding().tile_assignment().dim(i); + if (shard_count == 1) { + continue; + } + // If stride > window_size, there will be gaps between windows. These gaps + // will also exist in the output, so we keep them during halo exchange. + // + // TODO(yuanzx): This could introduce overhead if partitions start at + // different offsets in a gap. + auto wd = hlo->window().dimensions(i); + if (wd.stride() > wd.size()) { + wd.set_size(wd.stride()); + } + // shard_size * i < stride * k - pad_low + window_size => + // k > (shard_size * i + pad_low - window_size) / stride => + // first_k == (shard_size * i + pad_low - window_size + stride) / stride + first_window[i] = MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + wd.padding_low() - wd.size() + wd.stride(), wd.stride()); + // shard_size * (i + 1) <= stride * k - pad_low => + // k >= (shard_size * i + shard_size + pad_low) / stride => + // limit_k == (shard_size * i + shard_size + pad_low + stride - 1) / + // stride + limit_window[i] = MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + unpadded_data_shard_shape.dimensions(i) + wd.padding_low() + + wd.stride() - 1, + wd.stride()); + source_left_halo_sizes[i] = + MultiplyAddDivideOffsetCalculation( + unpadded_source_shard_shape.dimensions(i), 0, 1) - + first_window[i]; + source_right_halo_sizes[i] = + limit_window[i] - MultiplyAddDivideOffsetCalculation( + unpadded_source_shard_shape.dimensions(i), + unpadded_source_shard_shape.dimensions(i), 1); + data_left_halo_sizes[i] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) - + OffsetCalculation( + HloOpcode::kMultiply, first_window[i], + MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)); + data_right_halo_sizes[i] = + OffsetCalculation( + HloOpcode::kMultiply, limit_window[i], + MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) - + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + unpadded_data_shard_shape.dimensions(i) + wd.stride() + + wd.padding_low() - wd.size(), + 1)); + + int64 max_windows = + (limit_window[i] - first_window[i]).MaxInRange(0, shard_count); + auto first_window_hlo = + first_window[i].Calculate(partition_ordinals[i], &b_); + // Padding on the source is filled with the init value so they do not change + // the data on overlapping windows. + auto resharded_source = ExchangeHaloAndGetValidData( + source_shard_hlo, source.base_shape(), source_left_halo_sizes[i], + source_right_halo_sizes[i], 0, + limit_window[i].Calculate(shard_count - 1), max_windows, i, + hlo->sharding(), first_window_hlo, replicated_init.hlo(), + partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_); + if (!resharded_source) { + return DefaultAction(hlo); + } + source_shard_hlo = *resharded_source; + + auto offset_start_in_data = + MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1) + .Calculate(first_window_hlo, &b_); + int64 padded_data_size = + (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() + + wd.size(); + int64 data_shard_size = (max_windows - 1) * wd.stride() + wd.size(); + auto resharded_data = ExchangeHaloAndGetValidData( + data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i], + data_right_halo_sizes[i], wd.padding_low(), padded_data_size, + data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value, + partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_); + if (!resharded_data) { + return DefaultAction(hlo); + } + data_shard_hlo = *resharded_data; + } + + Window window_on_shard = hlo->window(); + for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) { + int64 shard_count = hlo->sharding().tile_assignment().dim(i); + if (shard_count == 1) { + continue; + } + auto reshard_wd = window_on_shard.mutable_dimensions(i); + // The shards are already explicitly padded. + reshard_wd->set_padding_low(0); + reshard_wd->set_padding_high(0); + } + + auto sharded_select_and_scatter = + b_.AddInstruction(HloInstruction::CreateSelectAndScatter( + data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard, + source_shard_hlo, replicated_init.hlo(), + hlo->called_computations()[1])); + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(), + shard_shape)) { + return sharded_select_and_scatter; + } + auto zero = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(shard_shape.rank(), zero); + for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) == 1) { + continue; + } + int64 pad_low = hlo->window().dimensions(i).padding_low(); + auto left_halo_size = + data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_); + if (data_left_halo_sizes[i].Calculate(0) == pad_low) { + slice_offsets[i] = left_halo_size; + } else { + auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i], + ComparisonDirection::kEq)); + auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(pad_low))); + slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary( + zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo, + left_halo_size)); + } + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_select_and_scatter, slice_offsets, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) { + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back( + GetPartitionedHlo(hlo->operand(i)) + .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i})) + .hlo()); + } + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateTuple(new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs( + HloInstruction* hlo) { + TF_RET_CHECK(hlo->opcode() == HloOpcode::kConvolution); + + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && + !rhs.sharding().IsTileMaximal()); + + const auto& dnums = hlo->convolution_dimension_numbers(); + + // Check if the operand shardings are aligned. Also we currently don't + // support partitioning non-spatial dimensions. + std::vector rhs_to_lhs_indices(hlo->shape().rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(hlo->shape().rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != + 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return DefaultAction(hlo); + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return DefaultAction(hlo); + } + lhs = lhs.PadWithValue(zero); + rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); + } + + // Reshard LHS by exchanging halo such that each shard computes the partial + // sum of the full shape result, and add AllReduce. + // + // The size of halo on each dimension can be calculated from the projection + // onto the LHS that each RHS shard i needs to read. RHS and LHS below refers + // to the shard size of RHS and LHS, WC is the number of windows, and D is the + // window dilation. + // + // * offset(i): RHS * D * i - low_padding + // * limit(i): {(RHS - 1) * D + 1} * (i + 1) + (WC - 1) * stride - low_padding + // + // Since shard i has LHS of range [i * LHS, (i + 1) * LHS) + // * left-halo: i * LHS - offset(i) + // = (LHS - RHS) * i + low_padding + // * right-halo: limit(i) - (i + 1) * LHS + // = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding + + Window window = hlo->window(); + std::vector shard_counts(dnums.input_spatial_dimensions_size()); + std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); + std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension); + auto wd = window.dimensions(i); + if (wd.base_dilation() != 1 || wd.window_reversal()) { + return DefaultAction(hlo); + } + + int64 lhs_shard_size = + CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count); + int64 rhs_shard_size = + CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count); + shard_counts[i] = shard_count; + lhs_shard_sizes[i] = lhs_shard_size; + rhs_shard_sizes[i] = rhs_shard_size; + } + + std::vector left_halo_size_functions(hlo->shape().rank()); + std::vector right_halo_size_functions(hlo->shape().rank()); + Window new_window = window; + + auto partition_ordinals = + MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); + HloInstruction* lhs_with_halo = lhs.hlo(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 lhs_shard_size = lhs_shard_sizes[i]; + int64 rhs_shard_size = rhs_shard_sizes[i]; + + if (shard_counts[i] == 1) { + continue; + } + + // Calculate the left and right halo sizes as described in the comments + // above. + auto wd = window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + int64 rhs_shard_size_dilated = + (rhs_shard_size - 1) * wd.window_dilation() + 1; + + left_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low, + 1)); + right_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size_dilated - lhs_shard_size, + rhs_shard_size_dilated - lhs_shard_size + + wd.stride() * (window_count - 1) - padding_low, + 1)); + + // Exchange halo and concatenate. + int64 dim = dnums.input_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = padding_low; + int64 shard_size_with_halo = + wd.stride() * (window_count - 1) + rhs_shard_size_dilated; + + new_window.mutable_dimensions(i)->set_padding_low(0); + new_window.mutable_dimensions(i)->set_padding_high(0); + new_window.mutable_dimensions(i)->set_size(rhs_shard_size); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). + // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. + // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation()); + int64 padded_full_shape_size = 0; + auto concat = ExchangeHaloAndGetValidData( + lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), zero, + partition_ordinals[dim], collective_ops_creator_, next_channel_id_, &b_, + /*mask_invalid_region=*/false); + if (!concat) { + return DefaultAction(hlo); + } + lhs_with_halo = *concat; + } + + SetPartitionedHlo(hlo, [&]() { + auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( + hlo->shape(), lhs_with_halo, rhs.hlo(), hlo->feature_group_count(), + hlo->batch_group_count(), new_window, + hlo->convolution_dimension_numbers(), hlo->precision_config())); + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + const HloSharding& sharding = hlo->sharding(); + const auto& dnums = hlo->convolution_dimension_numbers(); + std::vector rhs_to_lhs_indices(hlo->shape().rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(hlo->shape().rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + // Handling cases where both operands' shardings are aligned. We check that + // the LHS batch dimension is not partitioned because it is mapped to the + // output feature dimension in aligned_rhs_sharding, which are not the same + // dimension. + if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) { + if (options_.conv_halo_exchange_always_on_lhs) { + return HandleConvolutionTiledLhsAndRhs(hlo); + } else { + // Reshard RHS so that each shard computes the partial sum of the full + // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs() + // that reshards LHS. + // + // The size of halo on each dimension can be calculated from the + // projection onto the RHS that shard i needs to read. RHS and LHS below + // refers to the shard size of RHS and LHS, WC is the number of windows, + // and D is the window dilation. + // + // * offset(i): LHS * i + low_padding - (WC - 1) * stride + // * limit(i): LHS * (i + 1) + low_padding + // + // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D) + // * left-halo: i * RHS - offset(i) + // = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding + // * right-halo: limit(i) - (i + 1) * RHS + // = (i + 1) * (LHS - RHS * D) + low_pading + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + // We currently don't support partitioning input batch or output feature + // dimensions. + return lhs_sharding.tile_assignment().dim( + dnums.input_batch_dimension()) != 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return DefaultAction(hlo); + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return DefaultAction(hlo); + } + lhs = lhs.PadWithValue(zero); + rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); + } + + Window window = hlo->window(); + std::vector shard_counts(dnums.input_spatial_dimensions_size()); + std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); + std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension); + auto wd = window.dimensions(i); + if (wd.base_dilation() != 1 || wd.window_reversal()) { + return DefaultAction(hlo); + } + + int64 lhs_shard_size = CeilOfRatio( + lhs.base_shape().dimensions(lhs_dimension), shard_count); + int64 rhs_shard_size = CeilOfRatio( + rhs.base_shape().dimensions(rhs_dimension), shard_count); + shard_counts[i] = shard_count; + lhs_shard_sizes[i] = lhs_shard_size; + rhs_shard_sizes[i] = rhs_shard_size; + } + + std::vector left_halo_size_functions( + hlo->shape().rank()); + std::vector right_halo_size_functions( + hlo->shape().rank()); + Window new_window = window; + + // Data structures needed for Pad and DynamicSlice on LHS if needed. + bool need_dynamic_slice_lhs = false; + auto partition_ordinals = + MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); + std::vector zero_padding(hlo->shape().rank()); + PaddingConfig pad_config = + window_util::MakeSymmetricPadding(zero_padding); + auto zero_s32 = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector dynamic_slice_start_indices( + hlo->shape().rank(), zero_s32); + Shape dynamic_slice_shape = lhs.hlo()->shape(); + Shape pad_shape = lhs.hlo()->shape(); + + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 lhs_shard_size = lhs_shard_sizes[i]; + int64 rhs_shard_size = rhs_shard_sizes[i]; + + if (shard_counts[i] == 1) { + continue; + } + + // Calculate the left and right halo sizes as described in the comments + // above. It calculcates the halo sizes with dilation, so we apply + // CeilOfRatio({left,right}_halo_size, window_dilation). + auto wd = window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = + 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + left_halo_size_functions[rhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + (window_count - 1) * wd.stride() - padding_low + + wd.window_dilation() - 1, + wd.window_dilation())); + right_halo_size_functions[rhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), + lhs_shard_size - rhs_shard_size * wd.window_dilation() + + padding_low + wd.window_dilation() - 1, + wd.window_dilation())); + + // New RHS window size includes the maximum of both left and right + // halos. + int64 halo_size = left_halo_size_functions[rhs_dimension].MaxInRange( + 1, shard_counts[i]) + + right_halo_size_functions[rhs_dimension].MaxInRange( + 0, shard_counts[i] - 1); + int64 new_window_size = + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size; + + // The amount of new low padding could be dynamic (e.g., window_dilation + // != 1), which requires pad (to the maximum) and dynamic slice on LHS. + // + // If we consider the first window, the offset of the dilated RHS that + // aligns with the first valid LHS element for shard i is 'padding_low + + // LHS * i'. When the left halo is added to RHS, the offset of the first + // RHS element is (RHS * i - left_halo) * window_dilation. The + // difference between the two values is the amount of padding_low we + // need on LHS. + auto new_padding_low_function = + OffsetCalculation( + HloOpcode::kMultiply, left_halo_size_functions[rhs_dimension], + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, wd.window_dilation(), 1))) - + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + -padding_low, 1)); + + int64 new_padding_low_max = + new_padding_low_function.MaxInRange(0, shard_counts[i]); + int64 new_padding_low = new_padding_low_max; + int64 new_padding_high = window_count * wd.stride() + + (new_window_size - 1) * wd.window_dilation() - + new_padding_low - lhs_shard_size; + + // We do pad/dynamic-slice only when the padding is dynamic. + if (!new_padding_low_function.IsConstant()) { + need_dynamic_slice_lhs = true; + new_padding_low = 0; + pad_config.mutable_dimensions(lhs_dimension) + ->set_edge_padding_low(new_padding_low_max); + pad_config.mutable_dimensions(lhs_dimension) + ->set_edge_padding_high(new_padding_low_max); + pad_shape.set_dimensions(lhs_dimension, + lhs_shard_size + 2 * new_padding_low_max); + dynamic_slice_start_indices[lhs_dimension] = + (OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, new_padding_low_max, 1)) - + new_padding_low_function) + .Calculate(partition_ordinals[lhs_dimension], &b_); + dynamic_slice_shape.set_dimensions( + lhs_dimension, lhs_shard_size + new_padding_low_max); + } + + // Since the convolution RHS operand size increased with halos, adjust + // the window config accordingly. + new_window.mutable_dimensions(i)->set_padding_low(new_padding_low); + new_window.mutable_dimensions(i)->set_padding_high(new_padding_high); + new_window.mutable_dimensions(i)->set_size( + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size); + } + + HloInstruction* conv_lhs = lhs.hlo(); + if (need_dynamic_slice_lhs) { + auto pad = b_.AddInstruction( + HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config)); + conv_lhs = b_.AddInstruction(HloInstruction::CreateDynamicSlice( + dynamic_slice_shape, pad, dynamic_slice_start_indices, + dynamic_slice_shape.dimensions())); + } + + // Exchange halo and concatenate. + HloInstruction* rhs_with_halo = rhs.hlo(); + for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { + int64 dim = dnums.kernel_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = + left_halo_size_functions[dim].Calculate(0); + int64 shard_size_with_halo = new_window.dimensions(i).size(); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). + // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. + // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) - + left_halo_size_functions[dim]; + int64 padded_full_shape_size = + offset_on_padded_shape.Calculate(shard_counts[i] - 1) + + new_window.dimensions(i).size(); + auto concat = ExchangeHaloAndGetValidData( + rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), + zero, partition_ordinals[dim], collective_ops_creator_, + next_channel_id_, &b_, /*mask_invalid_region=*/false); + if (!concat) { + return DefaultAction(hlo); + } + rhs_with_halo = *concat; + } + + SetPartitionedHlo(hlo, [&]() { + auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( + hlo->shape(), conv_lhs, rhs_with_halo, hlo->feature_group_count(), + hlo->batch_group_count(), new_window, dnums, + hlo->precision_config())); + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + + if (!sharding.IsTileMaximal()) { + // We don't currently support sharding on output feature dimension. + if (sharding.tile_assignment().dim(dnums.output_feature_dimension()) > 1) { + return DefaultAction(hlo); + } + + // Check if the operand and the output sharding are aligned. + std::vector input_to_output_indices(hlo->shape().rank()); + input_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + input_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + input_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + auto target_operand_sharding = + hlo_sharding_util::TransposeSharding(sharding, input_to_output_indices); + lhs = lhs.Reshard(target_operand_sharding); + + // Replicate the RHS. + rhs = rhs.Reshard(HloSharding::Replicate()); + + // Convolution window config does not include batch and feature dimensions, + // whereas ReshardAsWindowedInput() expects the same number of window + // dimensions as the rank of the operand. So add two more trivial + // dimensions. + std::vector ones(hlo->shape().rank(), 1); + auto operand_window = window_util::MakeWindow(ones); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) = + hlo->window().dimensions(i); + } + + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + auto resharded_operand_and_window = lhs.ReshardAsWindowedInput( + operand_window, target_operand_sharding, zero); + if (!resharded_operand_and_window.has_value()) { + return DefaultAction(hlo); + } + Window new_window; + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *new_window.add_dimensions() = + resharded_operand_and_window->shard_window.dimensions( + dnums.input_spatial_dimensions(i)); + } + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + resharded_operand_and_window->sharded_input->shape(), + rhs.hlo()->shape(), hlo->feature_group_count(), + hlo->batch_group_count(), new_window, dnums)); + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + *sharded_conv_shape.mutable_layout() = shard_shape.layout(); + SetPartitionedHlo(hlo, [&]() { + auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve( + sharded_conv_shape, resharded_operand_and_window->sharded_input, + rhs.hlo(), hlo->feature_group_count(), hlo->batch_group_count(), + new_window, dnums, hlo->precision_config())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape())); + return sharded_conv; + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_conv, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { + DotGeneralDimsMapping mapping; + const auto& dnums = hlo->dot_dimension_numbers(); + int64 next_output_dim = 0; + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i); + mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i); + mapping.batch_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); + mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); + mapping.contracting_dims.back().output = -1; + } + for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { + continue; + } + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = i; + mapping.lhs_non_contracting_dims.back().rhs = -1; + mapping.lhs_non_contracting_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < hlo->operand(1)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { + continue; + } + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = -1; + mapping.rhs_non_contracting_dims.back().rhs = i; + mapping.rhs_non_contracting_dims.back().output = next_output_dim++; + } + auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r, + SpmdBuilder* b) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharded_dot_shape, + ShapeInference::InferDotOpShape(l->shape(), r->shape(), + hlo->dot_dimension_numbers())); + return b->AddInstruction(HloInstruction::CreateDot( + sharded_dot_shape, l, r, hlo->dot_dimension_numbers(), + hlo->precision_config())); + }; + return HandleDotHelper(hlo, mapping, create_sharded_dot); +} + +Status SpmdPartitioningVisitor::HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { + const HloSharding& lhs_sharding = hlo->operand(0)->sharding(); + const HloSharding& rhs_sharding = hlo->operand(1)->sharding(); + + // Similar to hlo_sharding_util::TransposeSharding(), but allows + // removing/adding non-partitioned dimensions. + auto transpose_sharding = + [&](const HloSharding& source, absl::Span src_to_tgt, + absl::Span tgt_to_src) -> absl::optional { + if (source.IsTileMaximal()) { + return source; + } + std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); + int64 skipped_tgt_dims = 0; + for (int64 i = 0; i < tgt_to_src.size(); ++i) { + if (tgt_to_src[i] < 0) { + skipped_tgt_dims++; + } else { + tgt_dims_skipping_new[i] = i - skipped_tgt_dims; + } + } + int64 skipped_src_dims = absl::c_count(src_to_tgt, -1); + std::vector perm(src_to_tgt.size()); + for (int64 i = 0; i < src_to_tgt.size(); ++i) { + if (src_to_tgt[i] < 0) { + if (source.tile_assignment().dim(i) > 1) { + return absl::nullopt; + } + perm[src_to_tgt.size() - skipped_src_dims] = i; + skipped_src_dims--; + } else { + perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i; + } + } + auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); + if (skipped_tgt_dims == 0) { + return tgt_sharding; + } + auto reshape_tiles = tgt_sharding.tile_assignment(); + std::vector tgt_tiles(tgt_to_src.size(), 1); + for (int64 i = 0; i < tgt_tiles.size(); ++i) { + if (tgt_to_src[i] >= 0) { + tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]); + } + } + reshape_tiles.Reshape(tgt_tiles); + return HloSharding::Tile(reshape_tiles); + }; + + std::vector lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1); + std::vector lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1); + std::vector rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1); + std::vector rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1); + std::vector output_to_lhs_indices(hlo->shape().rank(), -1); + std::vector output_to_rhs_indices(hlo->shape().rank(), -1); + auto populate_indices_mapping = + [&](const DotGeneralDimsMapping::DimsMapping& mapping) { + if (mapping.lhs >= 0) { + lhs_to_rhs_indices[mapping.lhs] = mapping.rhs; + lhs_to_output_indices[mapping.lhs] = mapping.output; + } + if (mapping.rhs >= 0) { + rhs_to_lhs_indices[mapping.rhs] = mapping.lhs; + rhs_to_output_indices[mapping.rhs] = mapping.output; + } + if (mapping.output >= 0) { + output_to_lhs_indices[mapping.output] = mapping.lhs; + output_to_rhs_indices[mapping.output] = mapping.rhs; + } + }; + for (const auto& mapping : dims_mapping.batch_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + auto lhs_sharding_transposed_to_match_rhs = + transpose_sharding(lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices); + auto rhs_sharding_transposed_to_match_lhs = + transpose_sharding(rhs_sharding, rhs_to_lhs_indices, lhs_to_rhs_indices); + auto lhs_sharding_transposed_to_match_output = transpose_sharding( + lhs_sharding, lhs_to_output_indices, output_to_lhs_indices); + auto rhs_sharding_transposed_to_match_output = transpose_sharding( + rhs_sharding, rhs_to_output_indices, output_to_rhs_indices); + auto output_sharding_transposed_to_match_lhs = transpose_sharding( + hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices); + auto output_sharding_transposed_to_match_rhs = transpose_sharding( + hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices); + + // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. + auto get_partitions_for_dims = + [&](const HloSharding& sharding, + absl::Span dims, + int lhs_rhs_or_output) { + int64 partitions = 1; + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_rhs_or_output == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else if (lhs_rhs_or_output == 1) { + partitions *= sharding.tile_assignment().dim(dim.rhs); + } else { + CHECK_EQ(lhs_rhs_or_output, 2); + partitions *= sharding.tile_assignment().dim(dim.output); + } + } + return partitions; + }; + const int64 lhs_batch_partitions = + get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0); + const int64 rhs_batch_partitions = + get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1); + const int64 output_batch_partitions = + get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2); + const int64 lhs_contracting_partitions = + get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0); + const int64 rhs_contracting_partitions = + get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1); + const int64 lhs_non_contracting_partitions = get_partitions_for_dims( + lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0); + const int64 rhs_non_contracting_partitions = get_partitions_for_dims( + rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1); + const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( + hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2); + const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( + hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2); + + auto& lhs = GetPartitionedHlo(hlo->operand(0)); + auto& rhs = GetPartitionedHlo(hlo->operand(1)); + // LHS and RHS are partitioned the same way and only partitioned in batch + // dimensions. + if (lhs_batch_partitions == rhs_batch_partitions && + rhs_batch_partitions == num_partitions_ && + lhs_sharding_transposed_to_match_rhs == rhs_sharding) { + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + dot->set_sharding(*lhs_sharding_transposed_to_match_output); + return PartitionedHlo(dot, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + + // Try emit batch-partitioned einsum with one operand resharded. Returns + // whether the attempt succeeds. If may_reshard_with_allreduce is false, + // reshard must be done using all-to-all; otherwise this attempt fails. + auto try_emit_output_batch_partitioned_einsum_with_reshard = + [&](bool may_reshard_with_allreduce) -> StatusOr { + // LHS and output are batch partitioned in the same way. + if (lhs_batch_partitions == num_partitions_ && + output_batch_partitions == num_partitions_ && + lhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (!may_reshard_with_allreduce && + !CanReshardWithAllToAll(rhs.sharding(), + *lhs_sharding_transposed_to_match_rhs)) { + return false; + } + auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return true; + } + // RHS and output are batch partitioned in the same way. + if (rhs_batch_partitions == num_partitions_ && + output_batch_partitions == num_partitions_ && + rhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (!may_reshard_with_allreduce && + !CanReshardWithAllToAll(lhs.sharding(), + *rhs_sharding_transposed_to_match_lhs)) { + return false; + } + auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return true; + } + return false; + }; + + { + // Try batch-parallel by resharding one operand, and not using all-reduce. + TF_ASSIGN_OR_RETURN( + bool emitted, + try_emit_output_batch_partitioned_einsum_with_reshard(false)); + if (emitted) { + return Status::OK(); + } + } + + // Try to emit windowed DotGeneral when one operand is partitioned in the same + // way as the output along non-contracting dimensions, but the other operand + // is tiled in other dimensions. + auto emit_windowed_dot_general = [&](int64 matching_operand, + int64 windowing_operand, + bool windowed_at_contracting_dims, + bool windowed_at_batch_dims) { + CHECK_EQ(matching_operand + windowing_operand, 1); + CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims); + auto unpadded_result_buffer_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto padded_result_buffer_shape = unpadded_result_buffer_shape; + // For windowing at batch/non-contracting dims, we produce the result one + // partition at a time, so we need to pad the shape in case of uneven + // partitioning in order to make dynamic-update-slice in-bound. + if (!windowed_at_contracting_dims) { + padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning( + padded_result_buffer_shape, + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output); + } + // Mask the padding area of the windowed operand with zero if there is + // uneven partitioning. + if (windowed_at_contracting_dims) { + auto& to_mask = windowing_operand == 0 ? lhs : rhs; + to_mask = + to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type())))); + } + auto result_buffer = CreateZero(padded_result_buffer_shape, &b_); + auto iteration = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + // Create a while loop that computes one window per iteration. During each + // iteration, each partition sends its input window to its neighbor using + // collective-permute for the next iteration. + SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_); + auto param = body_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto l = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0)); + auto r = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1)); + auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), param, 2)); + auto i = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3)); + + auto partition_id = collective_ops_creator_.create_partition_id(&body_b); + auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, partition_id)); + auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))); + data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count)); + auto dot_lhs = l; + auto dot_rhs = r; + if (windowed_at_contracting_dims || windowed_at_batch_dims) { + // Slice the matching operand according to the partitioned contracting + // dimensions on the windowed operand. We do this by treating the matching + // operand as replicated, and resharding it to match the windowed operand. + auto slice_operand = matching_operand == 0 ? l : r; + slice_operand->set_sharding(HloSharding::Replicate()); + auto state = MakePartitioningState(); + state.b = &body_b; + state.partition_id = data_partition_id; + auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state) + .Reshard(windowing_operand == 0 + ? *lhs_sharding_transposed_to_match_rhs + : *rhs_sharding_transposed_to_match_lhs) + .hlo(); + slice_operand->clear_sharding(); + if (matching_operand == 0) { + dot_lhs = slice; + } else { + dot_rhs = slice; + } + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(dot_lhs, dot_rhs, &body_b)); + if (windowed_at_contracting_dims) { + // Accumulate the partial output to the result buffer. + o = body_b.AddInstruction( + HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot)); + } else { + // The windowing operand is partitioned along batch/non-contracting + // dimensions, so we need a dynamic-update-slice to save the partial + // output in the result buffer. + auto offsets = MakePartitionOffsets( + o->shape(), + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output, + data_partition_id, &body_b); + o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + o->shape(), o, dot, offsets)); + } + + // ++i + i = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, + body_b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))))); + auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), i, + body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))), + ComparisonDirection::kLt)); + // Collective-permute for the next window. We don't need it for the last + // iteration, so we use a conditional around the collective-permute. + HloInstruction* conditional; + { + SpmdBuilder cp_b("window_collective_permute", visiting_hlo_); + { + auto p = cp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + std::vector> sd_pairs(num_partitions_); + for (int64 source = 0; source < num_partitions_; ++source) { + // 0 -> n-1, 1 -> 0, 2 -> 1, ... + sd_pairs[source] = {source, + (source - 1 + num_partitions_) % num_partitions_}; + } + collective_ops_creator_.create_cross_partition_collective_permute( + &cp_b, p, sd_pairs, (*next_channel_id_)++); + } + SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_); + { + ncp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + } + conditional = body_b.AddInstruction(HloInstruction::CreateConditional( + windowing_operand == 0 ? l->shape() : r->shape(), has_more, + windowing_operand == 0 ? l : r, + module_->AddEmbeddedComputation(cp_b.Build()), + windowing_operand == 0 ? l : r, + module_->AddEmbeddedComputation(ncp_b.Build()))); + } + if (windowing_operand == 0) { + l = conditional; + } else { + r = conditional; + } + body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i})); + + SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_); + auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement( + iteration->shape(), cond_param, 3)); + cond_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), cond_i, + cond_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))), + ComparisonDirection::kLt)); + auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile( + cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()), + module_->AddEmbeddedComputation(body_b.Build()), + b_.AddInstruction(HloInstruction::CreateTuple( + {lhs.hlo(), rhs.hlo(), result_buffer, iteration})))); + windowed_dot_general_loops_.push_back({while_loop, windowing_operand, + windowed_at_contracting_dims, + windowed_at_batch_dims}); + SetPartitionedHlo(hlo, [&] { + auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), while_loop, 2)); + if (!ShapeUtil::Compatible(padded_result_buffer_shape, + unpadded_result_buffer_shape)) { + result = b_.AddInstruction(HloInstruction::CreateSlice( + unpadded_result_buffer_shape, result, + std::vector(padded_result_buffer_shape.rank(), 0), + unpadded_result_buffer_shape.dimensions(), + std::vector(padded_result_buffer_shape.rank(), 1))); + } + return result; + }); + return Status::OK(); + }; + if (output_lhs_non_contracting_partitions == num_partitions_ && + output_sharding_transposed_to_match_lhs == lhs_sharding && + ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()) >= + options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (rhs_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, true, false); + } + if (rhs_non_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, false, false); + } + if (rhs_batch_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, false, true); + } + } + if (output_rhs_non_contracting_partitions == num_partitions_ && + output_sharding_transposed_to_match_rhs == rhs_sharding && + ShapeUtil::ByteSizeOf(hlo->operand(0)->shape()) >= + options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (lhs_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, true, false); + } + if (lhs_non_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, false, false); + } + if (lhs_batch_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, false, true); + } + } + + { + // Try batch-parallel by resharding one operand, and allowing all-reduce. + TF_ASSIGN_OR_RETURN( + bool emitted, + try_emit_output_batch_partitioned_einsum_with_reshard(true)); + if (emitted) { + return Status::OK(); + } + } + + // LHS and RHS have the same partitioned contracting dimensions. + if (lhs_contracting_partitions == rhs_contracting_partitions && + lhs_contracting_partitions == num_partitions_) { + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + // Pad both sides with zero, since NaN at one side cannot be masked by zero + // on the other side. + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + + // LHS and output have the same partitioned non-contracting dimensions. + if (lhs_non_contracting_partitions == num_partitions_ && + output_lhs_non_contracting_partitions == num_partitions_ && + lhs_sharding == hlo->sharding()) { + auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs_replicated, &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // RHS and output have the same partitioned non-contracting dimensions. + if (rhs_non_contracting_partitions == num_partitions_ && + output_rhs_non_contracting_partitions == num_partitions_ && + rhs_sharding_transposed_to_match_output == hlo->sharding()) { + auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs_replicated, rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // Output is batch partitioned. + if (output_batch_partitions == num_partitions_) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + // Output is partitioned along LHS non-contracting dimensions. + if (output_lhs_non_contracting_partitions == num_partitions_) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); + TF_ASSIGN_OR_RETURN( + auto dot, + create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + // Output is partitioned along RHS non-contracting dimensions. + if (output_rhs_non_contracting_partitions == num_partitions_) { + auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), + resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // Returns true if it is beneficial to reshard the operand at `operand_idx` + // across the contracting dimension. + const auto should_partition_contracting_dim = [&](int64 operand_idx) { + if (!hlo->sharding().IsReplicated()) { + return false; + } + + if (operand_idx == 0) { + // If LHS and output are replicated, we compare the cost of all-gather + // on RHS vs all-reduce on the output. + return (rhs_contracting_partitions == num_partitions_) && + lhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(hlo->operand(1)->shape()) > + ShapeUtil::ElementsIn(hlo->shape()); + } else { + return (lhs_contracting_partitions == num_partitions_) && + rhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(hlo->operand(0)->shape()) > + ShapeUtil::ElementsIn(hlo->shape()); + } + }; + + // When the output is replicated and one of the operands is partitioned along + // contracting dimension, align the other operand to be partitioned along + // the contracting dimensions. + if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) || + should_partition_contracting_dim(1))) { + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (should_partition_contracting_dim(0)) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo(); + }); + return Status::OK(); + } + + return DefaultAction(hlo); +} + +namespace { + +// Finds a cluster of nodes that produce the inputs for `hlo` which only depend +// on small operands, which means the cluster should start with broadcasts, +// constants and iotas. All other internal nodes must be non-side-effecting +// elemntwise ops. Returns the set of nodes, and the small operands. E.g., for +// the following graph, +// +// a -> broadcast -> multiply +// iota ---> add--/ +// constant/ +// +// FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return +// <{broadcast, iota, constant, add, multiply}, [a]>. +std::pair, std::vector> +FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) { + std::unordered_set nodes_found; + std::vector new_operands; + std::unordered_set new_operands_set; + std::vector worklist; + worklist.push_back(hlo); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (nodes_found.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast || + inst->opcode() == HloOpcode::kConstant || + inst->opcode() == HloOpcode::kIota) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + auto res = new_operands_set.emplace(o); + if (res.second) { + new_operands.push_back(o); + } + } + } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + worklist.push_back(o); + } + } else { + nodes_found.clear(); + new_operands.clear(); + break; + } + } + return {std::move(nodes_found), std::move(new_operands)}; +} + +// Moves a cluster of memory-reducing nodes into the windowed dot-general loop +// on contracting dimensions. Such a loop has a dynamic slice on the +// non-windowed operand. If we move the input nodes into the loop, the +// dynamic-slice could be merged with them by later optimization passes, which +// reduces memory. +// +// small_operands small_operands +// | | +// input_nodes loop { | +// | => input_nodes +// loop { | | +// dynamic-slice dynamic-slice +// ... ... +// } } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes. +Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + HloInstruction* loop, int64 non_windowed_operand_index) { + auto input_tuple = loop->mutable_operand(0); + auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index); + auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand); + auto to_sink = std::move(input_nodes.first); + auto new_operands = std::move(input_nodes.second); + if (to_sink.empty()) { + return Status::OK(); + } + auto computation = loop->parent(); + // Replace the old operand with a tuple of the found small operands. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_input_subtuple)); + + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto old_body_param_users = body_param->users(); + // Update all tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body->root_instruction()}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), + {non_windowed_operand_index}) = + new_input_subtuple->shape(); + } + // Now update the loop body. + auto new_operand_tuple_inside = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, non_windowed_operand_index)); + TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_operand_tuple_inside)); + + // Create nodes inside the loop body. + std::vector worklist; + std::unordered_map outside_to_inside; + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_operand_tuple_inside, i)); + add_users_if_available(new_operands[i]); + } + // HLOs to sink without operands. + std::vector nullaries_to_sink; + for (auto inst : to_sink) { + if (inst->operand_count() == 0) { + nullaries_to_sink.push_back(inst); + } + } + // Sort nullaries_to_sink to make it deterministic. + absl::c_sort(nullaries_to_sink, + [](const HloInstruction* a, const HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + for (auto inst : nullaries_to_sink) { + worklist.push_back(inst); + } + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + std::vector inst_new_operands(inst->operand_count()); + for (int64 i = 0; i < inst->operand_count(); ++i) { + inst_new_operands[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction( + inst->CloneWithNewOperands(inst->shape(), inst_new_operands)); + add_users_if_available(inst); + } + TF_RET_CHECK(outside_to_inside.count(old_operand) > 0); + for (auto ou : old_body_param_users) { + if (ou->opcode() == HloOpcode::kGetTupleElement && + ou->tuple_index() == non_windowed_operand_index) { + TF_RETURN_IF_ERROR( + ou->ReplaceAllUsesWith(outside_to_inside[old_operand])); + TF_RETURN_IF_ERROR(body->RemoveInstruction(ou)); + } + } + return Status::OK(); +} + +// Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into +// the windowed dot-general loop on non-contracting dimensions. Such a loop has +// a dynamic-update-slice at the output. If we move the user nodes into the loop +// and before the dynamic-update-slice, the user nodes can operate on smaller +// shapes, which reduces memory. +// +// small_operands small_operands +// | | => | | +// | | loop { loop { | | +// | | conv | broadcast conv +// | | | | | / +// | | dynamic-update-slice | dynamic-slice / +// | | | | | / +// | | } | | multiply----- +// |broadcast / | / +// | | / reduce +// |multiply-- | +// \ | dynamic-update-slice +// reduce } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes (broadcast). +Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + HloInstruction* loop) { + CHECK_EQ(loop->user_count(), 1); + // There should be a single direct user of the while loop, which is the + // gte for element 2, i.e., the dot output. + auto user_gte = loop->users().front(); + CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement); + CHECK_EQ(user_gte->tuple_index(), 2); + auto computation = loop->parent(); + + // Find the reduce outputs and the input nodes they depend on, if input nodes + // only have small operands. + std::unordered_set to_move; + std::vector new_operands; + std::unordered_set new_operands_set; + std::vector reduce_outputs; + std::vector worklist; + Shape padded_shape = user_gte->shape(); + Shape unpadded_shape = user_gte->shape(); + auto original_output = user_gte; + + if (user_gte->user_count() == 1 && + user_gte->users().back()->opcode() == HloOpcode::kSlice) { + original_output = user_gte->users().back(); + unpadded_shape = original_output->shape(); + } + for (auto u : original_output->users()) { + worklist.push_back(u); + } + to_move.insert(original_output); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (to_move.count(inst) > 0) { + continue; + } + // We only support reduces with simple reduction function, since we may need + // to accumulate across iterations manually. + if (inst->opcode() == HloOpcode::kReduce && + inst->to_apply()->instruction_count() == 3 && + inst->to_apply()->num_parameters() == 2 && + inst->to_apply()->root_instruction()->IsElementwise()) { + to_move.insert(inst); + auto other_operand = inst->mutable_operand(1); + auto res = new_operands_set.emplace(other_operand); + if (res.second) { + new_operands.push_back(other_operand); + } + reduce_outputs.push_back(inst); + } else if (inst != computation->root_instruction() && + inst->user_count() > 0 && inst->IsElementwise() && + !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + // For an elementwise op, we need to make sure that they depend on only + // nodes already in to_move and nodes with small operands. + bool can_include = true; + for (auto operand : inst->operands()) { + if (to_move.count(operand) > 0) { + continue; + } + auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand); + if (find_result.first.empty()) { + can_include = false; + break; + } + for (auto n : find_result.first) { + to_move.insert(n); + } + for (auto new_operand : find_result.second) { + auto res = new_operands_set.insert(new_operand); + if (res.second) { + new_operands.push_back(new_operand); + } + } + } + if (!can_include) { + to_move.clear(); + break; + } + to_move.insert(inst); + for (auto u : inst->users()) { + worklist.push_back(u); + } + } else { + to_move.clear(); + break; + } + } + // If nothing is found, to_move could contain only original_output, or cleared + // by the above code. + if (to_move.size() <= 1) { + return Status::OK(); + } + + // We will replace the original loop output with reduce-shape outputs. Create + // the initial buffers before the loop. + for (auto out : reduce_outputs) { + auto padded_out_shape = out->shape(); + int64 operand_dim = 0; + int64 output_dim = 0; + while (output_dim < padded_out_shape.rank()) { + if (absl::c_linear_search(out->dimensions(), operand_dim)) { + // Dimension colapsed. + ++operand_dim; + continue; + } + // Kept dimensions have the same size of the padded shape. + padded_out_shape.set_dimensions(output_dim, + padded_shape.dimensions(operand_dim)); + ++operand_dim; + ++output_dim; + } + auto broadcast = + computation->AddInstruction(HloInstruction::CreateBroadcast( + padded_out_shape, + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(out->shape().element_type()))), + {})); + new_operands.push_back(broadcast); + } + + auto input_tuple = loop->mutable_operand(0); + // Create the new input subtuple that contains the small operands and the + // reduce-shape result buffers. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple)); + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto body_root = body->root_instruction(); + CHECK_EQ(body_root->opcode(), HloOpcode::kTuple); + // Update tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body_root}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) = + new_input_subtuple->shape(); + } + auto new_loop_input = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, 2)); + + // Now create the moved nodes inside the loop body. + std::unordered_map outside_to_inside; + worklist.clear(); + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_loop_input, i)); + add_users_if_available(new_operands[i]); + } + // The elementwise nodes will be created with sliced shape. The original loop + // output corresponds to the dynamic-update-slice's update slice. + auto dus = body_root->mutable_operand(2); + CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice); + outside_to_inside[original_output] = dus->mutable_operand(1); + add_users_if_available(original_output); + std::vector slice_offsets(padded_shape.rank()); + for (int64 i = 0; i < slice_offsets.size(); ++i) { + slice_offsets[i] = dus->mutable_operand(i + 2); + } + auto get_slice = [&](HloInstruction* padded) { + return body->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + padded->shape().element_type()), + padded, slice_offsets, dus->operand(1)->shape().dimensions())); + }; + // Helper functions to create nodes with small operands. + auto add_broadcast = [&](const HloInstruction* broadcast) { + auto padded_operand_shape = broadcast->operand(0)->shape(); + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + padded_operand_shape.set_dimensions( + i, padded_shape.dimensions(broadcast->dimensions(i))); + } + auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)], + padded_operand_shape, nullptr, body); + outside_to_inside[broadcast] = + get_slice(body->AddInstruction(broadcast->CloneWithNewOperands( + ShapeUtil::ChangeElementType(padded_shape, + padded_operand_shape.element_type()), + {padded_operand}))); + }; + auto add_iota = [&](const HloInstruction* iota) { + outside_to_inside[iota] = + get_slice(body->AddInstruction(iota->CloneWithNewOperands( + ShapeUtil::ChangeElementType(padded_shape, + iota->shape().element_type()), + {}))); + }; + auto add_constant = [&](const HloInstruction* constant) { + outside_to_inside[constant] = body->AddInstruction(constant->Clone()); + outside_to_inside[constant] = get_slice( + PadToShape(outside_to_inside[constant], + ShapeUtil::ChangeElementType( + padded_shape, constant->shape().element_type()), + nullptr, body)); + }; + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (outside_to_inside.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast) { + add_broadcast(inst); + } else if (inst->opcode() == HloOpcode::kIota) { + add_iota(inst); + } else if (inst->opcode() == HloOpcode::kConstant) { + add_constant(inst); + } else if (inst->opcode() == HloOpcode::kReduce) { + // This is an output, for which we has special handling later. + } else { + std::vector operands_inside(inst->operand_count()); + for (int64 i = 0; i < operands_inside.size(); ++i) { + operands_inside[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + inst->shape().element_type()), + operands_inside)); + } + add_users_if_available(inst); + } + std::vector new_outputs_inside(new_operands.size()); + for (int64 i = 0; i < new_outputs_inside.size(); ++i) { + new_outputs_inside[i] = outside_to_inside[new_operands[i]]; + } + // Now create the reduce outpus inside of the loop. + for (int64 i = 0; i < reduce_outputs.size(); ++i) { + auto reduce_outside = reduce_outputs[i]; + CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce); + int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; + auto last_iter_result = outside_to_inside[new_operands[index_in_operand]]; + auto operand0 = outside_to_inside[reduce_outside->operand(0)]; + auto operand1 = outside_to_inside[reduce_outside->operand(1)]; + TF_ASSIGN_OR_RETURN(auto reduce_shape, + ShapeInference::InferReduceShape( + {&operand0->shape(), &operand1->shape()}, + reduce_outside->dimensions(), + reduce_outside->to_apply()->ComputeProgramShape())); + *reduce_shape.mutable_layout() = reduce_outside->shape().layout(); + std::vector reduce_dus_offsets; + // If any collapsed dimension is windowed, we need to accumulate with last + // iteration's result. If such a dimension has padding, we also need to mask + // off invalid data. + bool needs_accumulate = false; + std::vector dims_to_mask; + for (int64 i = 0; i < slice_offsets.size(); ++i) { + if (absl::c_linear_search(reduce_outside->dimensions(), i)) { + if (reduce_outside->operand(0)->shape().dimensions(i) != + operand0->shape().dimensions(i)) { + needs_accumulate = true; + if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) { + dims_to_mask.push_back(i); + } + } + continue; + } + reduce_dus_offsets.push_back(slice_offsets[i]); + } + // Mask off invalid data in collapsed dimensions. + for (int64 dim : dims_to_mask) { + auto iota = body->AddInstruction(HloInstruction::CreateIota( + ShapeUtil::ChangeElementType(operand0->shape(), S32), dim)); + auto add = body->AddInstruction(HloInstruction::CreateBinary( + iota->shape(), HloOpcode::kAdd, iota, + body->AddInstruction(HloInstruction::CreateBroadcast( + iota->shape(), slice_offsets[dim], {})))); + auto limit = body->AddInstruction(HloInstruction::CreateBroadcast( + iota->shape(), + body->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + reduce_outside->operand(0)->shape().dimensions(dim)))), + {})); + auto compare = body->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit, + ComparisonDirection::kLt)); + operand0 = body->AddInstruction(HloInstruction::CreateTernary( + operand0->shape(), HloOpcode::kSelect, compare, operand0, + body->AddInstruction(HloInstruction::CreateBroadcast( + operand0->shape(), operand1, {})))); + } + auto output_inside = + body->AddInstruction(reduce_outside->CloneWithNewOperands( + reduce_shape, {operand0, operand1})); + // Accumulate with previous results if needed. + if (needs_accumulate) { + auto input_slice = + body->AddInstruction(HloInstruction::CreateDynamicSlice( + output_inside->shape(), last_iter_result, reduce_dus_offsets, + output_inside->shape().dimensions())); + output_inside = body->AddInstruction(HloInstruction::CreateBinary( + output_inside->shape(), + reduce_outside->to_apply()->root_instruction()->opcode(), + output_inside, input_slice)); + } + // Dynamic-update-slice if needed. + if (!ShapeUtil::Compatible(output_inside->shape(), + last_iter_result->shape())) { + output_inside = + body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + last_iter_result->shape(), last_iter_result, output_inside, + reduce_dus_offsets)); + } + new_outputs_inside[index_in_operand] = output_inside; + } + // Body output. + auto new_output_inside = + body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside)); + TF_RETURN_IF_ERROR( + body_root->ReplaceOperandWithDifferentShape(2, new_output_inside)); + TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus)); + // Replace uses of the reduces outside the loop. + auto new_output_gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_output_inside->shape(), loop, 2)); + for (int64 i = 0; i < reduce_outputs.size(); ++i) { + int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; + auto new_output = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_outputs_inside[index_in_operand]->shape(), new_output_gte, + index_in_operand)); + if (!ShapeUtil::Compatible(new_output->shape(), + reduce_outputs[i]->shape())) { + new_output = computation->AddInstruction(HloInstruction::CreateSlice( + reduce_outputs[i]->shape(), new_output, + std::vector(new_output->shape().rank(), 0), + reduce_outputs[i]->shape().dimensions(), + std::vector(new_output->shape().rank(), 1))); + } + TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output)); + TF_RETURN_IF_ERROR( + computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i])); + } + return Status::OK(); +} + +} // namespace + +Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops( + HloComputation* computation) { + for (auto& loop : windowed_dot_general_loops_) { + if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims) { + // We have a dynamic-slice for the non-windowed operand in + // batch/contracting-dim windowed dot-general. So moving the + // broadcast/iota/elementwise ops into the loop could help reduce memory + // via fusion. + TF_RETURN_IF_ERROR( + SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + loop.while_loop, 1 - loop.windowed_operand)); + } + if (!loop.windowed_in_contracting_dims) { + // We have a dynamic-update-slice for the output in + // batch/non-contracting-dim windowed dot-general. So moving reduce ops + // into the loop could help reduce memory. + TF_RETURN_IF_ERROR( + MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + loop.while_loop)); + } + } + return Status::OK(); +} + +StatusOr SpmdPartitioningVisitor::DoPartition( + HloComputation* computation, const HloSharding& root_sharding) { + VLOG(2) << "Partitioning computation " << computation->name() << " for " + << num_replicas_ << " replicas and " << num_partitions_ + << " partitions"; + TF_RETURN_IF_ERROR(computation->Accept(this)); + + HloModule* module = computation->parent(); + auto new_root = + GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding); + auto new_computation = + module->AddEmbeddedComputation(b_.Build(new_root.hlo())); + TF_RETURN_IF_ERROR(DoCodeMotionForWindowedDotGeneralLoops(new_computation)); + + // Replace the original computation with the new SPMD computation. + std::unordered_map replacement; + replacement[computation] = new_computation; + module->ReplaceComputations(replacement); + return changed_; +} + +Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) { + return Unimplemented( + "PartitionId instruction is not supported for SPMD partitioning since " + "the meaning is ambiguous -- whether the instruction is replicated or " + "the data is replicated, and if the latter which data is replicated."); +} + +SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options) + : SpmdPartitioner( + num_partitions, num_replicas, std::move(options), + SPMDCollectiveOpsCreator{ + [](SpmdBuilder* b) { + return b->AddInstruction(HloInstruction::CreatePartitionId()); + }, + [num_replicas](SpmdBuilder* b, HloInstruction* operand, + HloComputation* reduction, int64 channel_id) { + return b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction, + CreateReplicaGroups(num_replicas), + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + }, + [](SpmdBuilder* b, HloInstruction* operand, + std::vector>& src_dst_pairs, + int64 channel_id) { + return b->AddInstruction( + HloInstruction::CreateCollectivePermute( + operand->shape(), operand, src_dst_pairs, channel_id)); + }, + [](SpmdBuilder* b, absl::Span operands, + const std::vector& replica_groups, + int64 channel_id, absl::optional split_dimension) { + std::vector shapes(operands.size(), + operands[0]->shape()); + const Shape output_shape = + (shapes.size() == 1) ? shapes[0] + : ShapeUtil::MakeTupleShape(shapes); + return b->AddInstruction(HloInstruction::CreateAllToAll( + output_shape, operands, replica_groups, + /*constrain_layout=*/false, channel_id, split_dimension)); + }, + }) {} + +StatusOr SpmdPartitioner::PartitionComputation( + HloComputation* computation, const HloSharding& root_sharding, + int64* next_channel_id, SpmdLogger* logger) { + auto visitor = + CreateVisitor(computation, num_partitions_, num_replicas_, + collective_ops_creator_, next_channel_id, logger, options_); + return visitor->DoPartition(computation, root_sharding); +} + +std::unique_ptr SpmdPartitioner::CreateVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options) { + return absl::make_unique( + computation, num_partitions, num_replicas, collective_ops_creator, + next_channel_id, logger, std::move(options), this); +} + +StatusOr SpmdPartitioner::Run(HloModule* module) { + TF_RETURN_IF_ERROR(PreprocessSharding(module)); + + XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition( + *module, options_.report_instruction_count)); + + // Add the parameters' and output's shardings to the module. + std::vector entry_params_shardings; + for (int64 i = 0; i < module->entry_computation()->num_parameters(); ++i) { + auto param = module->entry_computation()->parameter_instruction(i); + CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i; + entry_params_shardings.push_back(param->sharding()); + } + module->set_spmd_parameters_shardings(entry_params_shardings); + auto entry_root = module->entry_computation()->root_instruction(); + CHECK(entry_root->has_sharding()) << "Missing sharding in entry root."; + module->set_spmd_output_sharding(entry_root->sharding()); + + FlattenCallGraph flatten; + TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module)); + + SpmdLogger logger(options_.report_instruction_count); + auto program_shape = module->entry_computation()->ComputeProgramShape(); + int64 next_channel_id = hlo_query::NextChannelId(*module); + TF_ASSIGN_OR_RETURN( + bool partition_changed, + PartitionComputation( + module->entry_computation(), + module->entry_computation()->root_instruction()->sharding(), + &next_channel_id, &logger)); + changed |= partition_changed; + + // For the entry computation, make sure that the root instruction and the + // parameters preserve their signatures. + auto new_program_shape = module->entry_computation()->ComputeProgramShape(); + if (!options_.allow_module_signature_change) { + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + program_shape.result(), new_program_shape.result())) + << "Result shape changed for the entry computation"; + TF_RET_CHECK(program_shape.parameters_size() == + new_program_shape.parameters_size()) + << "Parameter count changed for the entry computation"; + for (int64 i = 0; i < program_shape.parameters_size(); ++i) { + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + program_shape.parameters(i), new_program_shape.parameters(i))) + << "Parameter shape changed for the entry computation"; + } + } else { + const auto& old_entry_layout = module->entry_computation_layout(); + // Shapes can change but the layout should still remain the same. + for (int64 i = 0; i < new_program_shape.parameters_size(); ++i) { + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + old_entry_layout.parameter_shape(i), + new_program_shape.mutable_parameters(i))); + } + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + old_entry_layout.result_shape(), new_program_shape.mutable_result())); + + HloModuleConfig config = module->config(); + *config.mutable_entry_computation_layout() = + ComputationLayout(new_program_shape, /*ignore_layouts=*/false); + module->set_config(config); + } + + XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition( + *module, options_.report_instruction_count)); + XLA_VLOG_LINES(1, logger.MakeReport()); + + if (changed) { + HloPassPipeline pass("spmd-cleanup"); + pass.AddPass(); + pass.AddPass(); + pass.AddPass(/*is_layout_sensitive=*/true); + pass.AddPass(); + TF_RETURN_IF_ERROR(pass.Run(module).status()); + } + + TF_RETURN_IF_ERROR(ClearShardingAttributes(module)); + return changed; +} + +Status SpmdPartitioner::PreprocessSharding(HloModule* module) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) { + TF_RET_CHECK(hlo->has_sharding()) + << "Side-effect HLO must have sharding: " << hlo->ToString(); + TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) || + hlo->opcode() == HloOpcode::kInfeed) + << "Non-infeed side-effect HLO cannot have a replicated sharding:" + << hlo->ToString(); + } + + // For unassigned HLOs, annotate with replicated sharding. + // + // Among side-effecting ops, only Rng is allowed to omit the annotation. + // In that case, we currently force it to run on core 0, since we don't + // support partitioning or replicating the Rng op (the values depend on + // the seed provided to each device). + // + // TODO(hyouklee): Should we also convert single-device shardings (without + // side-effects) into replicated? + if (!hlo->has_sharding()) { + if (hlo->opcode() == HloOpcode::kRng) { + hlo->set_sharding(HloSharding::AssignDevice(0)); + } else { + hlo->set_sharding( + HloSharding::Single(hlo->shape(), HloSharding::Replicate())); + } + } else if (!hlo->sharding().IsTileMaximal()) { + std::vector available(num_partitions_); + std::iota(available.begin(), available.end(), 0); + TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding( + hlo->sharding(), available) + .size()) + << "num_partitions:" << num_partitions_ << "\n" + << "SPMD partitioner only supports tile sharding that includes all " + "partitions. If you didn't add this sharding annotation in the " + "model, please file a bug to XLA team.\n" + << hlo->ToString(); + } + } + } + + // Entry computation's parameter and root sharding must be either all + // replicated or all on a single device. + if (!options_.allow_module_signature_change) { + const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry->root_instruction()->has_sharding()); + const HloSharding& root_sharding = entry->root_instruction()->sharding(); + TF_RET_CHECK(root_sharding.IsReplicated() || + root_sharding.UniqueDevice().has_value()) + << "Unsupported entry root sharding: " << root_sharding.ToString(); + + for (const HloInstruction* param : entry->parameter_instructions()) { + TF_RET_CHECK(param->has_sharding()); + TF_RET_CHECK(param->sharding().IsReplicated() || + param->sharding().UniqueDevice().has_value()) + << "Unsupported entry parameter sharding:" + << param->sharding().ToString(); + } + } + + return Status::OK(); +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h new file mode 100644 index 00000000000..09d2c4af908 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -0,0 +1,435 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +namespace xla { +namespace spmd { + +struct SpmdPartitionerOptions { + // Always exchange halo on LHS for all convolutions. If false, backprop filter + // convolution exchanges halo on RHS. + bool conv_halo_exchange_always_on_lhs = true; + + // The number of instructions to be reported for the highest memory profile + // instructions. + int64 report_instruction_count = 5; + + // The minimum size in MiB of an einsum operand to be considered using + // windowed implementation in an HLO loop. + int64 threshold_for_windowed_einsum_mib = 256; + + // Whether the entry computations' signature could change after partitioning. + bool allow_module_signature_change = false; +}; + +// Class to wrap the computation builder to capture information during SPMD +// transformation. +class SpmdBuilder : public HloComputation::Builder { + public: + SpmdBuilder(const std::string& name, HloInstruction* hlo) + : HloComputation::Builder(name) { + visiting_hlo_ = hlo; + } + HloInstruction* AddInstruction(std::unique_ptr instruction); + + const std::vector& derived_instructions( + HloInstruction* hlo) { + return instructions_.at(hlo); + } + + void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; } + + HloInstruction* visiting_hlo() const { return visiting_hlo_; } + + private: + // Currently visiting instruction. + HloInstruction* visiting_hlo_; + + // Map from the currently visiting (old) instruction to new instructions + // created during SPMD partitioning. + HloInstructionMap> instructions_; +}; + +// A set of functions that create the cross-partition collective ops. +struct SPMDCollectiveOpsCreator { + // Function used to create a partition ID HLO. + std::function create_partition_id; + + // Function used to create a cross-partition all-reduce HLO. + std::function + create_cross_partition_all_reduce; + + // Function used to create a cross-partition collective-permute HLO. + std::function>& src_dst_pairs, + int64 next_channel_id)> + create_cross_partition_collective_permute; + + // Function used to create a cross-partition all-to-all HLO. + std::function operands, + const std::vector& replica_groups, int64 channel_id, + absl::optional split_dimension)> + create_cross_partition_all_to_all; +}; + +// Logger to report memory usage during SPMD partitioning. +class SpmdLogger { + public: + explicit SpmdLogger(int64 report_instruction_count) + : report_instruction_count_(report_instruction_count) {} + static std::string ReportBeforePartition(const HloModule& module, + int64 report_instruction_count); + static std::string ReportAfterPartition(const HloModule& module, + int64 report_instruction_count); + + // Registers the logging for the groups of instructions created to transform + // the given hlo. + void RegisterLogEntry(HloInstruction* hlo, + const std::vector& group); + + std::string MakeReport(); + + private: + template + static std::string ReportMemoryUsage(const HloModule& module, const F& filter, + int64 report_instruction_count); + + // A vector of logging messages (one for each original HLO instruction), where + // the first integer of the pair represents the size of the HBM used. + std::vector> entries_; + + int64 report_instruction_count_; +}; + +class SpmdPartitioningVisitor; + +class SpmdPartitioner : public HloModulePass { + public: + SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options); + SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options, + SPMDCollectiveOpsCreator collective_ops_creator) + : num_partitions_(num_partitions), + num_replicas_(num_replicas), + options_(std::move(options)), + collective_ops_creator_(std::move(collective_ops_creator)) {} + absl::string_view name() const override { return "spmd-partitioning"; } + StatusOr Run(HloModule* module) override; + + // Transforms the given computation with SPMD instructions, replacing it with + // a new computation. + StatusOr PartitionComputation(HloComputation* computation, + const HloSharding& root_sharding, + int64* next_channel_id, + SpmdLogger* logger); + + protected: + virtual std::unique_ptr CreateVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options); + + private: + // Verify that the sharding of instructions in the module are valid, and also + // fill in missing sharding information. + Status PreprocessSharding(HloModule* module); + + const int64 num_partitions_; + const int64 num_replicas_; + + SpmdPartitionerOptions options_; + SPMDCollectiveOpsCreator collective_ops_creator_; +}; + +// Class describes partition state of the data represented by an HLO created +// during SPMD partitioning pass. +// +// Data on some devices may include padding region, if the base (full) shape +// could not be evenly partitioned. +class PartitionedHlo { + public: + // Return value for ReshardAsWindowedInput which describes the resharded HLO, + // the window for the user on the shard, and if necessary, the dynamic slice + // offsets to be applied to the output of the op being sharded. + struct WindowedInputShardReturnValue { + HloInstruction* sharded_input; + Window shard_window; + absl::optional> dynamic_slice_index_on_output; + }; + // A cache for resharding each partitioned HLO. + struct ReshardCache { + struct PerHloCache { + std::vector> reshard_cache; + std::vector< + std::tuple> + window_reshard_cache; + }; + std::unordered_map per_hlo_cache; + }; + struct PartitioningState { + SpmdBuilder* b; + HloModule* module; + int64 num_replicas; + HloInstruction* partition_id; + SPMDCollectiveOpsCreator collective_ops_creator; + int64* next_channel_id; + ReshardCache* reshard_cache; + }; + PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state) + : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) { + CHECK(hlo->has_sharding()) + << "PartitionedHlo is missing sharding:" << hlo->ToString(); + // If the tuple shape instruction does not have a tuple sharding, reassign + // to use the tuple sharding. Reshard() implementation assumes this. + if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) { + hlo_->set_sharding( + hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie()); + } + } + + // Reshards the current SPMD instruction to a new sharding. Could only modify + // the reshard cache. + PartitionedHlo Reshard(const HloSharding& target); + + // Pads the garbage area of the output with the provided value. + PartitionedHlo PadWithValue(HloInstruction* pad_value) const; + + // Returns the SPMD instruction. + HloInstruction* hlo() const { return hlo_; } + + // Returns the sharding of the SPMD instruction. + const HloSharding& sharding() const { return hlo_->sharding(); } + + // Original full shape of the data. + const Shape& base_shape() const { return base_shape_; } + + int64 NewChannel() const { return (*state_.next_channel_id)++; } + + // Reshards the HLO to a usable partitioned input for a windowed user. Could + // only modify the reshard cache. + absl::optional ReshardAsWindowedInput( + const Window& window, const HloSharding& target, + HloInstruction* pad_value, bool mask_invalid_region = true); + + private: + // Same as Reshard except that it does not explicitly modify the reshard + // cache, although it would indirectly modify by calling Replicate(). + PartitionedHlo ReshardNoCache(const HloSharding& target); + + // Helper function to replicate the data on all devices. Could only modify + // the reshard cache. + PartitionedHlo Replicate(); + + // Helper function to broadcast data from a single device to all devices. + PartitionedHlo Broadcast() const; + + // Helper function to reshard the tensor using AllToAll (instead of the + // default of Replicate followed by Slice). + PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const; + + // Helper function to reshard the tensor using CollectivePermute. + PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; + + // SPMD instruction. + HloInstruction* hlo_; + + // The original shape of the data before SPMD transformation is applied. + Shape base_shape_; + + PartitioningState state_; +}; + +struct DotGeneralDimsMapping { + // The dimension numbers for the operands and output corresponding to a + // logical dimension (e.g., batch, contracting, non-contracting). If an + // operand or the output doesn't have the logical dimension, it is set to + // -1. + struct DimsMapping { + int64 lhs; + int64 rhs; + int64 output; + }; + std::vector batch_dims; + std::vector contracting_dims; + std::vector lhs_non_contracting_dims; + std::vector rhs_non_contracting_dims; +}; + +class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { + public: + SpmdPartitioningVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options, SpmdPartitioner* partitioner); + + Status DefaultAction(HloInstruction* hlo) override; + Status HandleAllReduce(HloInstruction* hlo) override; + Status HandleBroadcast(HloInstruction* hlo) override; + Status HandleConstant(HloInstruction* hlo) override; + Status HandleCustomCall(HloInstruction* hlo) override; + Status HandleDot(HloInstruction* hlo) override; + Status HandleDynamicSlice(HloInstruction* hlo) override; + Status HandleDynamicUpdateSlice(HloInstruction* hlo) override; + Status HandleGather(HloInstruction* hlo) override; + Status HandleGetTupleElement(HloInstruction* hlo) override; + Status HandleInfeed(HloInstruction* hlo) override; + Status HandleOutfeed(HloInstruction* hlo) override; + Status HandlePad(HloInstruction* hlo) override; + Status HandleParameter(HloInstruction* hlo) override; + Status HandleReduce(HloInstruction* hlo) override; + Status HandleReverse(HloInstruction* hlo) override; + Status HandleWhile(HloInstruction* hlo) override; + Status HandleConditional(HloInstruction* hlo) override; + Status HandleReduceWindow(HloInstruction* hlo) override; + Status HandleSelectAndScatter(HloInstruction* hlo) override; + Status HandleTuple(HloInstruction* hlo) override; + Status HandleRng(HloInstruction* hlo) override; + Status HandleConvolution(HloInstruction* hlo) override; + Status HandleConcatenate(HloInstruction* hlo) override; + Status HandleScatter(HloInstruction* hlo) override; + Status HandleSlice(HloInstruction* hlo) override; + Status HandleSort(HloInstruction* hlo) override; + Status HandleTranspose(HloInstruction* hlo) override; + Status HandleReshape(HloInstruction* hlo) override; + Status HandleIota(HloInstruction* hlo) override; + Status HandlePartitionId(HloInstruction* hlo) override; + + // Handles convolution where both LHS and RHS operands are tiled. + Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo); + + // Implementation of dot partitioning given DotGeneralDimsMapping. + Status HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot); + + // Common handle for elementwise HLOs. + Status HandleElementwise(HloInstruction* hlo); + + // Common handle for HLOs that runs on a single device. + Status HandleSingleDevice(const HloInstruction* hlo); + + // Returns the PartitionedHlo that corresponds to the original hlo. + PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) { + CHECK_EQ(partitioned_instructions_.count(hlo), 1); + return partitioned_instructions_.find(hlo)->second; + } + + // Sets the PartitionedHlo for the original hlo. + void SetPartitionedHlo(const HloInstruction* hlo, + const PartitionedHlo& partitioned_hlo) { + CHECK_EQ(partitioned_instructions_.count(hlo), 0); + partitioned_instructions_.emplace(hlo, partitioned_hlo); + changed_ = true; + } + + // Convenient wrapper that creates PartitionedHlo from the result of the func + // and maps it to the given original hlo. + void SetPartitionedHlo(const HloInstruction* hlo, + const std::function& func) { + HloInstruction* new_hlo = func(); + new_hlo->set_sharding(hlo->sharding()); + new_hlo->set_metadata(hlo->metadata()); + SetPartitionedHlo( + hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState())); + changed_ = true; + } + + int64 NewChannel() { return (*next_channel_id_)++; } + + PartitionedHlo::PartitioningState MakePartitioningState() { + return PartitionedHlo::PartitioningState{ + .b = &b_, + .module = module_, + .num_replicas = num_replicas_, + .partition_id = partition_id_, + .collective_ops_creator = collective_ops_creator_, + .next_channel_id = next_channel_id_, + .reshard_cache = &reshard_cache_}; + } + + SpmdBuilder* builder() { return &b_; } + + StatusOr DoPartition(HloComputation* computation, + const HloSharding& root_sharding); + + private: + Status Preprocess(HloInstruction* hlo) override; + Status Postprocess(HloInstruction* hlo) override; + + // Performs code motion for windowed dot-general loops in + // windowed_dot_general_loops_. Invoked after the visitor finishes traversing + // the graph. + Status DoCodeMotionForWindowedDotGeneralLoops(HloComputation* computation); + + bool changed_; + HloModule* module_; + int64 num_partitions_; + int64 num_replicas_; + + SPMDCollectiveOpsCreator collective_ops_creator_; + + // Tracks the next channel id to use for cross-partition all-reduce. + int64* next_channel_id_; + SpmdBuilder b_; + + HloInstruction* partition_id_; + + PartitionedHlo::ReshardCache reshard_cache_; + + // Mapping from the instruction in the original computation to the new SPMD + // partitioned instruction. + ConstHloInstructionMap partitioned_instructions_; + + // Information about a loop created for windowed dot-general. Used when + // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor + // finishes traversing the graph. + struct WindowedDotGeneralLoop { + HloInstruction* while_loop; + int64 windowed_operand; + bool windowed_in_contracting_dims; + bool windowed_in_batch_dims; + }; + std::vector windowed_dot_general_loops_; + + HloInstruction* visiting_hlo_; + SpmdLogger* logger_; + const SpmdPartitionerOptions options_; + SpmdPartitioner* partitioner_; +}; + +} // namespace spmd +} // namespace xla +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc new file mode 100644 index 00000000000..7a7f2dcc807 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -0,0 +1,3191 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace spmd { +namespace { + +using ::testing::AllOf; +namespace op = xla::testing::opcode_matchers; + +class SpmdPartitioningTest : public HloTestBase { + public: + StatusOr> PartitionComputation( + const char* hlo_module, int64 num_devices, + bool conv_halo_exchange_always_on_lhs = true) { + // Some tests (BackpropFilter convs) set this flag false to test two + // different paths of the implementation. + SpmdPartitionerOptions options; + options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs; + options.allow_module_signature_change = true; + + TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( + hlo_module, GetModuleConfigForTest())); + HloPassPipeline pass("spmd-partitioning"); + pass.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + pass.AddPass(num_devices, /*num_replicas=*/1, options); + pass.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + TF_RETURN_IF_ERROR(pass.Run(module.get()).status()); + return StatusOr>(std::move(module)); + } +}; + +TEST_F(SpmdPartitioningTest, InvalidSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[8,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0, + sharding={maximal device=0} +})"; + auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4); + EXPECT_FALSE(module_status.status().ok()); + EXPECT_THAT(module_status.status().ToString(), + ::testing::HasSubstr( + "only supports tile sharding that includes all partitions")); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce( + op::Select(op::Broadcast(op::Compare()), + op::Constant(), op::Broadcast()))), + op::Shape("s32[2,3]"))); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + VLOG(1) << module->ToString(); + EXPECT_THAT(root, op::Copy(AllOf(op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare()), + op::Constant(), op::Broadcast()))), + op::Shape("s32[2,3]")))); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), + sharding={devices=[2,1]1,0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Copy(op::DynamicSlice( + op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::Constant(), op::Broadcast())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant())), + op::Shape("s32[1,3]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Copy(op::AllReduce(AllOf( + op::DynamicUpdateSlice( + op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant()), + op::Shape("s32[2,3]"))))); +} + +TEST_F(SpmdPartitioningTest, TiledToSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Copy(op::Copy(op::AllReduce(AllOf( + op::DynamicUpdateSlice( + op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant()), + op::Shape("s32[2,3]")))))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledEven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param= s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1} + ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Copy(op::Reshape(op::Transpose(op::AllToAll(AllOf( + op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]")))))), + op::Shape("s32[8,1]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledUneven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param= f32[7,31,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1} + ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Copy(op::Slice(op::Reshape(AllOf(op::Transpose(op::AllToAll( + op::Reshape(AllOf(op::Pad(), op::Shape("f32[8,16,128]"))))))))))); +} + +TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param.0 = (f32[2,3]{1,0}, u32[]) parameter(0), + sharding={{maximal device=1}, {maximal device=1}} + %gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0, + sharding={maximal device=0} + %gte.1 = u32[] get-tuple-element(%param.0), index=1, + sharding={maximal device=0} + ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1), + sharding={{maximal device=0},{maximal device=0}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Tuple()); + + EXPECT_THAT(root->operand(0), + op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::GetTupleElement(op::Parameter()), op::Broadcast())))); + EXPECT_THAT(root->operand(1), + op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::GetTupleElement(op::Parameter()), op::Broadcast())))); +} + +TEST_F(SpmdPartitioningTest, GetTupleElementTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0), + sharding={{replicated}, {replicated}} + gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0, + sharding={devices=[2,1]0,1} + gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1, + sharding={devices=[2,1]0,1} + ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1), + sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Tuple()); + + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + + EXPECT_THAT(root->operand(0), + op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, + op::Constant())); + EXPECT_THAT(root->operand(1), + op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, + op::Constant())); +} + +TEST_F(SpmdPartitioningTest, TiledInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[8,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0, + sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), + op::GetTupleElement( + AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[9,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[5,2]"), op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), + op::AfterAll(), op::AfterAll())))); + EXPECT_THAT( + root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("(f32[5,2], token[])"), op::Infeed(op::Parameter()))); + auto second_infeed = + AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter())); + EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("(f32[5,2], token[])"), + op::Tuple(op::Pad(op::GetTupleElement(second_infeed), + op::Constant()), + op::GetTupleElement(second_infeed)))); +} + +TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}} + ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed), + index=0, sharding={{devices=[2,1]0,1}, {replicated}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("(f32[5,2], f32[2])"), + op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), op::AfterAll(), + op::AfterAll())))); + EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("((f32[5,2], f32[2]), token[])"), + op::Infeed(op::Parameter()))); + auto second_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"), + op::Infeed(op::Parameter())); + EXPECT_THAT( + root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("((f32[5,2], f32[2]), token[])"), + op::Tuple(op::Tuple(op::Pad(op::GetTupleElement( + op::GetTupleElement(second_infeed)), + op::Constant()), + op::GetTupleElement( + op::GetTupleElement(second_infeed))), + op::GetTupleElement(second_infeed)))); +} + +TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1}, + to_apply=sum, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::AllReduce(op::Reduce( + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Reshape())), + op::Broadcast(op::Constant())), + AllOf(op::Shape("f32[2,3]{1,0}"), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())), + op::Broadcast(op::Constant())), + op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, TiledElementwise) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}), + sharding={replicated} + multiply = f32[3,3]{1,0} multiply(constant, constant.1), + sharding={devices=[2,1]0,1} + ROOT add = f32[3,3]{1,0} add(multiply, constant.1), + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Shape("f32[2,3]{1,0}"), + op::Add(op::Multiply( + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant()), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, TiledAllReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1} + ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum, + replica_groups={}, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0)))); +} + +TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,4,3]{2,1,0}"), + op::Broadcast(op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"), + op::Broadcast(op::DynamicSlice( + op::Constant(), op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,2,3]{2,1,0}"), + op::Broadcast(AllOf(op::Shape("f32[2,3]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), + op::Constant()))))); +} + +TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}), + sharding={devices=[2,1]0,1} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"), + op::Broadcast(op::DynamicSlice( + op::Constant(), op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token.0 = token[] after-all() + data = f32[1024]{0} parameter(0), sharding={maximal device=0} + outfeed = token[] outfeed(data, token.0), sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("token[]"), + op::Conditional( + op::Compare(op::PartitionId(), op::Constant()), + op::Tuple(op::Parameter(0), op::AfterAll()), + op::Tuple(op::Parameter(0), op::AfterAll())))); + + HloInstruction* root_b0 = root->branch_computation(0)->root_instruction(); + EXPECT_THAT(root_b0, + AllOf(op::Shape("token[]"), + op::Outfeed(op::GetTupleElement(op::Parameter(), 0), + op::GetTupleElement(op::Parameter(), 1)))); + + HloInstruction* root_b1 = root->branch_computation(1)->root_instruction(); + EXPECT_THAT(root_b1, AllOf(op::Shape("token[]"), op::AfterAll())); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}), + sharding={replicated} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1), + window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow( + op::DynamicSlice(AllOf(op::Shape("f32[9,2]{1,0}"), + op::Pad(op::Constant(), op::Constant())), + op::Multiply(op::Reshape(), op::Constant()), + op::Constant()), + op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1), + window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = op::DynamicSlice( + AllOf( + op::Shape("f32[6,2]{1,0}"), + op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())), + op::Reshape(), op::Constant()); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = + op::Select(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[9,2]{1,0} constant( + {{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}), + sharding={devices=[3,1]0,1,2} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1), + window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum, + sharding={devices=[3,1]0,1,2} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/3)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = op::DynamicSlice( + AllOf( + op::Shape("f32[7,2]{1,0}"), + op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())), + op::Reshape(), op::Constant()); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = op::Select( + op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + op::Compare(index_in_padded, op::Broadcast(op::Constant()))), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1), + window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto left_halo = AllOf(op::Shape("f32[1,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto right_halo = AllOf(op::Shape("f32[1,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = AllOf( + op::Shape("f32[5,2]{1,0}"), + op::DynamicSlice( + AllOf(op::Shape("f32[6,2]{1,0}"), + op::Pad(op::Concatenate(left_halo, sharded_input, right_halo), + op::Constant())), + op::Reshape(), op::Constant())); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = op::Select( + op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + op::Compare(index_in_padded, op::Broadcast(op::Constant()))), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0), + sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}} + infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0, + sharding={devices=[2,2,1,1]0,1,2,3} + constant = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant), + window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum, + sharding={devices=[2,2,1,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"), + op::GetTupleElement(op::Infeed())); + auto dim0_left_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto dim0_right_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto dim0_pre_masking = op::DynamicSlice( + AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"), + op::Pad( + op::Concatenate(dim0_left_halo, sharded_input, dim0_right_halo), + op::Constant())), + op::Reshape(), op::Constant(), op::Constant(), op::Constant()); + auto dim0_index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto dim0_masked = op::Select( + op::And(op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant())), + op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant()))), + dim0_pre_masking, op::Broadcast(op::Constant())); + auto dim0_resharded = AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"), dim0_masked); + auto dim1_left_halo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(dim0_resharded))); + auto dim1_right_halo = + AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(dim0_resharded))); + auto dim1_pre_masking = op::DynamicSlice( + AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"), + op::Pad(op::Concatenate(dim1_left_halo, dim0_resharded, + dim1_right_halo), + op::Constant())), + op::Constant(), op::Reshape(), op::Constant(), op::Constant()); + auto dim1_index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto dim1_masked = op::Select( + op::And(op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant())), + op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant()))), + dim1_pre_masking, op::Broadcast(op::Constant())); + auto dim1_resharded = AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"), dim1_masked); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"), + op::ReduceWindow(dim1_resharded, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs), + sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution( + f32[128,224,224,3] %lhs.copy, + f32[7,7,3,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT(root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs), + sharding={devices=[2,1,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs), + sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution( + f32[128,224,224,3] %lhs.copy, + f32[7,7,3,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,224,224,3]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[64,2,112,224,3]")); + auto reshard_lhs = AllOf(op::Reshape(op::Transpose(all_to_all)), + op::Shape("f32[128,112,224,3]")); + + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, reshard_lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[224,224,3,128] parameter(0) + %lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=01fb_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[112,224,3,128]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[3,224,3,128]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[2,224,3,128]")); + EXPECT_THAT(root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +// (stride * per_shard_window_count) % dilation == 0 +TEST_F(SpmdPartitioningTest, + ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,7,7,512] parameter(0) + %lhs.copy = f32[128,7,7,512] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[3,3,512,512] parameter(1) + %rhs.copy = f32[3,3,512,512] copy(%rhs), + sharding={replicated} + ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy), + window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + // There is no halo exchange, and because the last element in the shard is not + // needed (stride == 4), the LHS will be just a slice. + auto sliced_lhs = + AllOf(op::Slice(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant()))), + op::Shape("f32[128,3,7,512]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs), + op::Shape("f32[128,2,4,512]"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 1); +} + +// (stride * per_shard_window_count) % dilation != 0 but stride == 1 +TEST_F(SpmdPartitioningTest, + ConvolutionBaseDilationStride1LhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,7,7,512] parameter(0) + %lhs.copy = f32[128,7,7,512] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[3,3,512,512] parameter(1) + %rhs.copy = f32[3,3,512,512] copy(%rhs), + sharding={replicated} + ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy), + window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[128,4,7,512]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,1,7,512]")); + auto start_window = op::Multiply(op::Reshape(), op::Constant()); + auto start_input_element = op::Divide(start_window, op::Constant()); + auto dynamic_offset_for_padded_concat = op::Subtract( + op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()), + start_input_element)); + auto pre_masking = + AllOf(op::Shape("f32[128,5,7,512]"), + op::DynamicSlice( + AllOf(op::Shape("f32[128,6,7,512]"), + op::Pad(op::Concatenate(left_halo, lhs), op::Constant())), + op::Constant(), dynamic_offset_for_padded_concat, + op::Constant(), op::Constant())); + auto masked = op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)), + op::Broadcast(op::Constant())), + pre_masking, op::Broadcast(op::Constant())); + auto dynamic_offset_on_output = op::Subtract( + start_window, op::Multiply(start_input_element, op::Constant())); + EXPECT_THAT(root, + AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs), + op::Shape("f32[128,8,14,512]")), + op::Constant(), dynamic_offset_on_output, + op::Constant(), op::Constant()), + op::Shape("f32[128,7,14,512]"))); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[4,1]0,1,2,3} + constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto source = + AllOf(op::Shape("f32[1,2]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto masked_data = AllOf( + op::Shape("f32[3,4]{1,0}"), + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply( + op::Reshape(), op::Constant()))), + op::Broadcast(op::Constant())), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Reshape(), op::Constant())), + op::Broadcast(op::Constant()))); + + EXPECT_THAT(root, + AllOf(op::SelectAndScatter(masked_data, source, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[1,4]0,1,2,3} + constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto source = + AllOf(op::Shape("f32[1,2]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto operand = AllOf(op::Copy(op::DynamicSlice( + op::Parameter(0), op::Constant(), op::Reshape())), + op::Shape("f32[11,1]")); + auto reshard_operand = op::Reshape(op::Transpose( + op::AllToAll(op::Reshape(op::Pad(operand, op::Constant()))))); + auto masked_data = AllOf( + op::Shape("f32[3,4]{1,0}"), + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply( + op::Reshape(), op::Constant()))), + op::Broadcast(op::Constant())), + reshard_operand, op::Broadcast(op::Constant()))); + + EXPECT_THAT(root, + AllOf(op::SelectAndScatter(masked_data, source, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[4,1]0,1,2,3} + constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + + auto source_shard = + AllOf(op::Shape("f32[2,2]{1,0}"), + op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant())); + // Max halo size is the same as the shard size, so slice is not needed. + auto source_left_halo = op::CollectivePermute(source_shard); + auto required_source_shard_start = + op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant()); + auto source_with_halo = op::DynamicSlice( + AllOf(op::Shape("f32[5,2]{1,0}"), + op::Pad(op::Concatenate(source_left_halo, source_shard), + op::Constant())), + op::Subtract(op::Constant(), + op::Subtract(op::Multiply(op::Reshape(), op::Constant()), + required_source_shard_start)), + op::Constant()); + auto masked_source_with_halo = AllOf( + AllOf(op::Shape("f32[3,2]{1,0}")), + op::Select( + op::Compare( + op::Add(op::Iota(), op::Broadcast(required_source_shard_start)), + op::Broadcast(op::Constant())), + source_with_halo, op::Broadcast(op::Constant()))); + + auto data_shard = + AllOf(op::Shape("f32[3,4]{1,0}"), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Reshape(), op::Constant()))); + auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"), + op::CollectivePermute(op::Slice(data_shard))); + auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"), + op::CollectivePermute(op::Slice(data_shard))); + auto required_data_start_on_padded = + op::Multiply(required_source_shard_start, op::Constant()); + auto left_halo_size = op::Subtract( + op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()), + required_data_start_on_padded); + auto data_with_halo = + AllOf(op::Shape("f32[7,4]{1,0}"), + op::DynamicSlice( + AllOf(op::Shape("f32[8,4]{1,0}"), + op::Pad(op::Concatenate(data_left_halo, data_shard, + data_right_halo), + op::Constant())), + op::Subtract(op::Constant(), left_halo_size), op::Constant())); + auto index_on_padded = + op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded)); + auto masked_data_with_halo = op::Select( + op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())), + op::Compare(index_on_padded, op::Broadcast(op::Constant()))), + data_with_halo, op::Broadcast(op::Constant())); + + EXPECT_THAT( + root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo, + masked_source_with_halo, + op::Constant()), + left_halo_size, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,64] parameter(0) + %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,56,56,256] parameter(1) + %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy), + window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[1,1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,64] parameter(0) + %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,56,56,256] parameter(1) + %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy), + window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,56,56,256]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]")); + auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all))); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)), + op::Shape("f32[1,1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,512] parameter(0) + %lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,64] parameter(1) + %rhs.copy = f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, + dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,28,28,64]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(rhs)), op::Shape("f32[64,2,14,28,64]")); + auto reshard = op::Reshape(op::Transpose(all_to_all)); + + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), reshard)), + op::Shape("f32[1,1,512,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,28,28,128] parameter(0) + %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,28,28,64] parameter(1) + %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,128]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[32,1,28,64]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[32,1,28,64]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo), + op::Shape("f32[32,16,28,64]")))), + op::Shape("f32[3,3,128,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,112,112,64] parameter(1) + %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy), + window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,56,112,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,2,112,64]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,2,112,64]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo), + op::Shape("f32[128,60,112,64]")))), + op::Shape("f32[7,7,3,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,14,28,512]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[1,1,256,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,14,14,512] parameter(0) + %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,7,7,512] parameter(1) + %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,7,14,512]")); + auto rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[128,4,7,512]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,1,7,512]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + AllOf(op::DynamicSlice(op::Pad(lhs, op::Constant()), + op::Constant(), op::Subtract(), + op::Constant(), op::Constant()), + op::Shape("f32[128,10,14,512]")), + AllOf(op::Concatenate(left_halo, rhs), + op::Shape("f32[128,5,7,512]")))), + op::Shape("f32[3,3,512,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,28,28,128] parameter(0) + %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,28,28,64] parameter(1) + %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,128]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[32,1,28,128]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[32,1,28,128]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::Concatenate(left_halo, lhs, right_halo), + op::Shape("f32[32,16,28,128]")), + rhs)), + op::Shape("f32[3,3,128,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,112,112,64] parameter(1) + %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy), + window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,56,112,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::Concatenate(left_halo, lhs, right_halo), + op::Shape("f32[128,117,224,3]")), + rhs)), + op::Shape("f32[7,7,3,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,14,28,512]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), rhs)), + op::Shape("f32[1,1,256,512]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,14,14,512] parameter(0) + %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,7,7,512] parameter(1) + %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,7,14,512]")); + auto rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[128,4,7,512]")); + + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,1,14,512]")); + EXPECT_THAT( + root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::DynamicSlice( + AllOf(op::Pad(op::Concatenate(lhs, right_halo), + op::Constant()), + op::Shape("f32[128,10,14,512]")), + op::Constant(), op::Reshape(), op::Constant(), + op::Constant()), + op::Shape("f32[128,9,14,512]")), + rhs)), + op::Shape("f32[3,3,512,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0) + %param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1} + %param1 = f32[14,116] parameter(1) + %param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1} + ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + dimensions={1}, sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[7,257]")); + auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[7,116]")); + EXPECT_THAT(root, + AllOf(op::Concatenate(param0, param1), op::Shape("f32[7,373]"))); +} + +TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0) + %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1} + %param1 = f32[14,116] parameter(1) + %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1} + ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + dimensions={1}, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Reshape())), + op::Shape("f32[14,129]")); + auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[14,58]")); + EXPECT_THAT(root, AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::DynamicUpdateSlice( + op::DynamicUpdateSlice( + op::Broadcast(), param0, + op::Constant(), op::Multiply()), + param1, op::Constant(), op::Add())), + op::Shape("f32[14,374]")), + op::Constant(), op::Multiply()), + op::Shape("f32[14,187]"))); +} + +TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + %const = f32[] constant(0) + ROOT %pad = f32[128,17,257] pad(%param0.copy, %const), padding=0_0x1_2x0_0, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Pad(param0, op::Constant()), + op::Shape("f32[128,17,129]"))); +} + +TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + ROOT %slice = f32[128,11,257] slice(%param0.copy), + slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]"))); +} + +TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + ROOT %slice = f32[63,14,251] slice(%param0.copy), + slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT( + root, + AllOf(op::Slice(AllOf( + op::DynamicSlice( + AllOf(op::Concatenate( + param0, + AllOf(op::CollectivePermute(op::Slice(param0)), + op::Shape("f32[128,14,2]"))), + op::Shape("f32[128,14,131]")), + op::Constant(), op::Constant(), op::Add()), + op::Shape("f32[128,14,126]"))), + op::Shape("f32[63,14,126]"))); +} + +TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ge { + p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated} + bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated} + constant = s32[]{:T(256)} constant(0), sharding={replicated} + compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated} + constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated} + bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated} + subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated} + bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated} + select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated} + p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated} + bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated} + compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated} + bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated} + subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated} + bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated} + select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated} + compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated} + compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated} + compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated} + p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated} + p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated} + compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated} + ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated} +} + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1} + %param1 = s32[128,14,257] parameter(1) + %param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1} + ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)}) + sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true, + to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[128,7,257]")); + auto param1 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("s32[128,7,257]")); + EXPECT_THAT(root, AllOf(op::Sort(param0, param1), + op::Shape("(f32[128,7,257], s32[128,7,257])"))); +} + +TEST_F(SpmdPartitioningTest, PartitionCustomCall) { + const char* const hlo_string = R"( +HloModule cluster_2013453984438090939__.47 + +ENTRY %cluster_2013453984438090939__.47 + (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1} + %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0}) + custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK" + %get-tuple-element = bf16[2,2000]{1,0} + get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call), + index=0, sharding={replicated} + %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0}, + s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated} + ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0}) + tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0} + %get-tuple-element.1), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto custom_call = FindInstruction(module.get(), "custom-call.1"); + EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000); +} + +TEST_F(SpmdPartitioningTest, ShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,19,38,4]")); + EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]"))); +} + +TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), + sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7} + ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy), + dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[4,19,38,4]")); + EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,4,38]"))); +} + +TEST_F(SpmdPartitioningTest, NonShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto resahrd = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))), + op::Shape("f32[16,38,38,2]")); + EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]"))); +} + +TEST_F(SpmdPartitioningTest, ShardableReshape) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[38,38,324] parameter(0) + %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1} + ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy), + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[19,38,324]")); + EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]"))); +} + +TEST_F(SpmdPartitioningTest, NonShardableReshape) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[38,38,324] parameter(0) + %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1} + ROOT %transpose = f32[38,38,4,81] reshape(%param0.copy), + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::DynamicSlice( + AllOf(op::Pad( + AllOf(op::Reshape(AllOf(op::AllReduce(), + op::Shape("f32[38,38,324]"))), + op::Shape("f32[38,38,4,81]")), + op::Constant()), + op::Shape("f32[38,38,4,82]")), + op::Constant(), op::Constant(), op::Constant(), op::Reshape()), + op::Shape("f32[38,38,4,41]"))); +} + +// Produces an invalid module after transformation. +TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[128,5,5,768] parameter(0) + %param0.copy = f32[128,5,5,768] copy(%param0), + sharding={devices=[1,4,1,1]0,1,2,3} + %constant.1 = f32[] constant(0), sharding={replicated} + ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1), + window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1}, + to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input_shard = op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())); + auto id_mul4_add1 = + op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()); + auto id_mul5 = op::Multiply(op::Reshape(), op::Constant()); + auto id_mul5_add1_div3 = + op::Divide(op::Add(id_mul5, op::Constant()), op::Constant()); + auto before_masking = AllOf( + op::Shape("f32[128,3,5,768]"), + op::DynamicSlice( + AllOf( + op::Shape("f32[128,4,5,768]"), + op::Concatenate(op::CollectivePermute(input_shard), input_shard)), + op::Constant(), + op::Subtract(op::Constant(), + op::Subtract(id_mul4_add1, id_mul5_add1_div3)), + op::Constant(), op::Constant())); + auto masked = op::Select( + op::And(op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)), + op::Broadcast(op::Constant())), + op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)), + op::Broadcast(op::Constant()))), + before_masking, op::Broadcast(op::Constant())); + auto rw = AllOf(op::Shape("f32[128,7,17,768]"), + op::ReduceWindow(masked, op::Constant())); + auto final_slice_index = op::Subtract( + id_mul5, + op::Add(op::Multiply(id_mul5_add1_div3, op::Constant()), op::Constant())); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Shape("f32[128,5,17,768]"), + op::DynamicSlice(rw, op::Constant(), final_slice_index, + op::Constant(), op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,32,32,128] parameter(0) + %param0.copy = f32[4,32,32,128] copy(%param0), + sharding={devices=[1,1,1,2]0,1} + %constant.1 = f32[] constant(0), sharding={replicated} + %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2}, + to_apply=%sum, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[4,32,32,64]")); + + EXPECT_THAT(root, + AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) { + const char* const hlo_string = R"( +HloModule module + +%minmax_func { + %lhs_value = f32[] parameter(0) + %rhs_value = f32[] parameter(2) + %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT + %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value) + %lhs_index = s32[] parameter(1) + %rhs_index = s32[] parameter(3) + %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index) + ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5) +} + +ENTRY %main { + %param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1} + %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1} + %init0 = f32[] parameter(2) + %init1 = s32[] parameter(3) + ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1), + dimensions={1}, to_apply=%minmax_func, + sharding={{devices=[2]0,1}, {devices=[2]0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Reduce(op::Parameter(0), op::Parameter(1), + op::Parameter(2), op::Parameter(3)), + op::Shape("(f32[14], s32[14])"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,32,32,128] parameter(0) + %param0.copy = f32[4,32,32,128] copy(%param0), + sharding={devices=[1,2,1,1]0,1} + %constant.1 = f32[] constant(0), sharding={replicated} + %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2}, + to_apply=%sum, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[4,16,32,128]")); + + EXPECT_THAT(root, + AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::Reduce(param0, op::Constant())), + op::Shape("f32[128]")), + op::Reshape()), + op::Shape("f32[64]"))); +} + +TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = s32[16,80,91] iota(), iota_dimension=1, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Iota(), op::Shape("s32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = s32[16,80,91] iota(), iota_dimension=2, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()), + op::Shape("s32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = u32[16,80,91] iota(), iota_dimension=2, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()), + op::Shape("u32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, Conditional) { + const char* const hlo_string = R"( +HloModule module + +Negate { + x = f32[4,5] parameter(0), sharding={replicated} + ROOT negate = f32[4,5] negate(x), sharding={replicated} +} + +Identity { + y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1} + ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1} +} + +ENTRY entry { + %param.0 = pred[] parameter(0) + %param.0.copy = pred[] copy(%param.0), sharding={maximal device=0} + %param.1 = f32[4,5] parameter(1) + %param.1.copy = f32[4,5] copy(%param.1), sharding={replicated} + %param.2 = f32[4,5] parameter(2) + %param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1} + ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy), + true_computation=Negate, false_computation=Identity, + sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto param0 = AllOf(op::Copy(op::Copy(op::Parameter()), op::Shape("pred[]"))); + auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]")); + auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[2,5]")); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Conditional(op::AllReduce(), param1, param2), + op::Shape("f32[2,5]"))); + + auto then_branch_root = root->branch_computation(0)->root_instruction(); + EXPECT_THAT(then_branch_root, + AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(), + op::Constant()), + op::Shape("f32[2,5]"))); + + auto else_branch_root = root->branch_computation(1)->root_instruction(); + EXPECT_THAT(else_branch_root, + AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]"))); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param.0 = f32[32,128,384,64] parameter(0) + %param.0.copy = f32[32,128,384,64] copy(%param.0), + sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} + %param.1 = f32[32,64,192,64] parameter(1) + %param.1.copy = f32[32,64,192,64] copy(%param.1), + sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy, + %param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1}, + select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto source = AllOf( + op::Shape("f32[32,8,192,64]"), + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(), + op::Constant(), op::Constant()))); + auto data = AllOf( + op::Shape("f32[32,16,384,64]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant()))); + + EXPECT_THAT(root, op::SelectAndScatter(data, source, op::Constant())); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, TiledDot) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,64] parameter(0) + %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1} + %rhs = f32[64,256] parameter(1) + %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy), + dim_labels=bf_io->bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[128,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[32,256]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[128,256]"))); +} + +TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,64] parameter(0) + %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1} + %rhs = f32[64,256] parameter(1) + %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy), + dim_labels=bf_io->bf, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[128,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[32,256]")); + EXPECT_THAT(root, AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[128,256]")), + op::Constant(), op::Reshape()), + op::Shape("f32[128,128]"))); +} + +TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,256,256] parameter(0) + %lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1} + %rhs = f32[256,8,1] parameter(1) + %rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated} + ROOT %conv = f32[128,256,8] convolution(%lhs.copy, %rhs.copy), + window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[128,128,256]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[256,8,1]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[128,128,8]"))); +} + +TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,64] parameter(0) + %lhs.copy = f32[24,64] copy(%lhs), sharding={replicated} + %rhs = f32[39296,64] parameter(1) + %rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[24,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant())), + op::Shape("f32[19648,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,24,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,24,64]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, op::DynamicSlice(rhs, op::Reshape(), + op::Constant(), + op::Constant())), + op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[32,12,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,39296,64]")); + auto lhs_reshard = op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs)))); + EXPECT_THAT(root, + AllOf(op::Dot(lhs_reshard, rhs), op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs_slice = + AllOf(op::DynamicSlice(op::Copy(op::Parameter(0)), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[16,24,64]")); + auto rhs_slice = + AllOf(op::DynamicSlice(op::Copy(op::Parameter(1)), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[16,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs_slice, rhs_slice), + op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Constant(), op::Reshape(), op::Reshape())), + op::Shape("f32[32,24,32,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), + op::Constant(), op::Reshape(), op::Reshape())), + op::Shape("f32[32,39296,32,64]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)), + op::Shape("f32[32,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[32,12,64,64]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,12,64,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3} + ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[1,1,2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[32,19648,64,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,24,19648,64]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]")); + auto rhs = + AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]")); + EXPECT_THAT( + root, + AllOf(op::Dot(AllOf(op::DynamicSlice(lhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[32,12,64,128]")), + rhs), + op::Shape("f32[32,12,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]")); + auto rhs = + AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]")); + EXPECT_THAT(root, + AllOf(op::Dot(lhs, AllOf(op::DynamicSlice( + rhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[32,19648,64,128]"))), + op::Shape("f32[32,24,19648]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,12,64,128]")); + auto rhs = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()), + op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,19648,64,128]")); + EXPECT_THAT( + root, + AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple( + lhs, rhs, op::Broadcast(), op::Constant()))), + op::Shape("f32[32,12,39296]"))), + op::Shape("f32[32,12,39295]"))); + auto while_loop = root->operand(0)->operand(0); + // Check loop condition. + EXPECT_THAT( + while_loop->while_condition()->root_instruction(), + op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant())); + + // Check loop body. + auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()); + auto window = op::Conditional(op::Compare(next_i, op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + auto partial_output = op::Dot(op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + EXPECT_THAT( + while_loop->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(0)), window, + op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)), + partial_output, op::Constant(), + op::Constant(), op::Reshape()), + next_i)); + + // Check the conditional that contains the collective permute. + auto cp_conditional = + while_loop->while_body()->root_instruction()->operand(1); + EXPECT_THAT(cp_conditional->true_computation()->root_instruction(), + op::CollectivePermute(op::Parameter(0))); + EXPECT_THAT(cp_conditional->false_computation()->root_instruction(), + op::Parameter(0)); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,63,128] parameter(0) + %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39296,63,128] parameter(1) + %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,12,63,128]")); + auto rhs = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()), + op::Constant(), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[32,39296,32,128]")); + auto masked_rhs = + op::Select(op::Compare(), rhs, op::Broadcast(op::Constant())); + EXPECT_THAT(root, + AllOf(op::GetTupleElement(op::While(op::Tuple( + lhs, masked_rhs, op::Broadcast(), op::Constant()))), + op::Shape("f32[32,12,39296]"))); + auto while_loop = root->operand(0); + // Check loop condition. + EXPECT_THAT( + while_loop->while_condition()->root_instruction(), + op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant())); + + // Check loop body. + auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()); + auto window = op::Conditional(op::Compare(next_i, op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + auto partial_output = op::Dot( + op::DynamicSlice( + op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()), + op::Constant(), op::Constant(), op::Reshape(), op::Constant()), + op::GetTupleElement(op::Parameter(0))); + EXPECT_THAT( + while_loop->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(0)), window, + op::Add(op::GetTupleElement(op::Parameter(0)), partial_output), + next_i)); + + // Check the conditional that contains the collective permute. + auto cp_conditional = + while_loop->while_body()->root_instruction()->operand(1); + EXPECT_THAT(cp_conditional->true_computation()->root_instruction(), + op::CollectivePermute(op::Parameter(0))); + EXPECT_THAT(cp_conditional->false_computation()->root_instruction(), + op::Parameter(0)); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} + %constant = f32[] constant(0) + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1]0,1} + %multiply = f32[32,24,39295] multiply(%dot, %broadcast), + sharding={devices=[1,2,1]0,1} + ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2}, + to_apply=sum, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} + %constant = f32[] constant(0) + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1]0,1} + %multiply = f32[32,24,39295] multiply(%dot, %broadcast), + sharding={devices=[1,2,1]0,1} + ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1}, + to_apply=sum, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %rhs = f32[32,39296,63,128] parameter(0) + %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1} + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1,1]0,1} + %add = f32[32,24,63,128] add(%broadcast, %broadcast), + sharding={devices=[1,2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, ReplicatedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0) + %lhs.copy = s32[] copy(%lhs), sharding={replicated} + %rhs = s32[] parameter(1) + %rhs.copy = s32[] copy(%rhs), sharding={replicated} + ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy), + distribution=rng_uniform, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("s32[]")); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::Rng(), op::Broadcast(op::Constant()))), + op::Shape("s32[4]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0) + %lhs.copy = s32[] copy(%lhs), sharding={replicated} + %rhs = s32[] parameter(1) + %rhs.copy = s32[] copy(%rhs), sharding={maximal device=1} + ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy), + distribution=rng_uniform, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]")); + auto rhs = AllOf(op::Copy(op::Copy(op::Parameter(1))), op::Shape("s32[]")); + EXPECT_THAT(root, AllOf(op::Rng(lhs, op::AllReduce(op::Select( + op::Broadcast(op::Compare()), rhs, + op::Broadcast(op::Constant())))), + op::Shape("s32[2]"))); +} + +TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[128,64] parameter(0) + %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1} + %index = s32[] parameter(1) + %constant = s32[] constant(0) + ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input.copy, %constant, %index), + dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant())), + op::Shape("s32[64,64]")); + EXPECT_THAT(root, + AllOf(op::DynamicSlice(input, op::Constant(), op::Parameter(1)), + op::Shape("s32[64,2]"))); +} + +TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[128,64] parameter(0) + %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1} + %index = s32[] parameter(1) + %constant = s32[] constant(0) + %update = s32[128,2] parameter(2) + %update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1} + ROOT %dynamic-update-slice = s32[128,64] + dynamic-update-slice(%input.copy, %update.copy, %constant, %index), + sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant())), + op::Shape("s32[64,64]")); + auto update = AllOf(op::Copy(op::DynamicSlice(op::Parameter(2), op::Reshape(), + op::Constant())), + op::Shape("s32[64,2]")); + EXPECT_THAT(root, AllOf(op::DynamicUpdateSlice(input, update, op::Constant(), + op::Parameter(1)), + op::Shape("s32[64,64]"))); +} + +TEST_F(SpmdPartitioningTest, PassthroughGather) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[3,5]"))); +} + +TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1} + %indices = s32[2,3] parameter(1), sharding={replicated} + ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, + slice_sizes={1,9}, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); + auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), + op::Shape("s32[2,3]")); + auto clamp = op::Clamp(min, op::Parameter(1), max); + auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); + auto mask = + op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + auto masked = + op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]"))); +} + +TEST_F(SpmdPartitioningTest, PassthroughScatter) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), + op::Parameter(2)), + op::Shape("f32[2,5]"))); +} + +TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1} + %indices = s32[2,3] parameter(1), sharding={replicated} + %updates = f32[2,3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto indices = op::Subtract( + op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)), + op::Shape("f32[9,9]"))); +} + +TEST_F(SpmdPartitioningTest, TiledReverse) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1}, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"), + op::Reverse(op::DynamicSlice( + op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[8,2] parameter(0), sharding={devices=[2,1]0,1} + to_shard = f32[4,2] custom-call(param), custom_call_target="SPMDFullToShardShape", sharding={replicated} + add = f32[4,2] add(to_shard, to_shard), sharding={replicated} + to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1} + ROOT mul = f32[8,2] multiply(to_full, param), sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + auto to_shard = op::Copy(op::Parameter(0)); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"), + op::Multiply(op::Copy(op::Add(to_shard, to_shard)), + op::Parameter(0)))); +} + +} // namespace +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc new file mode 100644 index 00000000000..207f854cd9f --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -0,0 +1,662 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace spmd { + +bool HasReplicatedSharding(const HloSharding& sharding) { + if (sharding.IsTuple()) { + return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding); + } + return sharding.IsReplicated(); +} + +HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + elements.push_back( + CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b)); + } + return b->AddInstruction(HloInstruction::CreateTuple(elements)); + } + + if (shape.IsToken()) { + return b->AddInstruction(HloInstruction::CreateToken()); + } + auto zero = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {})); +} + +HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { + HloComputation::Builder sum_b("add"); + auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x")); + auto y = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y")); + if (type == PRED) { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y)); + } else { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y)); + } + HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build()); + return reduction; +} + +bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { + if (sharding.IsTuple()) { + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))) { + return false; + } + } + } + + if (sharding.IsTileMaximal()) { + return sharding.IsReplicated(); + } + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { + return false; + } + } + return true; +} + +Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { + if (sharding.IsTuple()) { + std::vector subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + subshapes.push_back( + MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))); + } + return ShapeUtil::MakeTupleShape(subshapes); + } + return sharding.TileShape(shape); +} + +Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, + const HloSharding& sharding, + int64 partition_id) { + if (sharding.IsTuple()) { + std::vector subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + subshapes.push_back(MakeNonPaddedShapeForGivenPartition( + ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}), partition_id)); + } + return ShapeUtil::MakeTupleShape(subshapes); + } + + auto partition_shape = shape; + std::vector tile_offset = + sharding.TileOffsetForDevice(shape, partition_id); + std::vector tile_limit = + sharding.TileLimitForDevice(shape, partition_id); + for (int64 i = 0; i < tile_offset.size(); ++i) { + if (sharding.UsesDevice(partition_id)) { + partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]); + } else { + partition_shape.set_dimensions(i, 0); + } + } + return partition_shape; +} + +std::vector MakePartitionOffsets(const Shape& shape, + const HloSharding& sharding, + HloInstruction* partition_id, + SpmdBuilder* b) { + CHECK(!shape.IsTuple()); + + Array2D offset_array( + {sharding.tile_assignment().num_elements(), shape.rank()}); + offset_array.Each([&](int64 i, int64 j, int32* value) { + *value = sharding.TileOffsetForDevice(shape, i)[j]; + }); + auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(offset_array))); + std::vector offsets; + for (int64 i = 0; i < shape.rank(); ++i) { + if (sharding.tile_assignment().dim(i) == 1) { + offsets.push_back(b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); + } else { + auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1, 1}), offset_table, + {partition_id, b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(i)))}, + {1, 1})); + offsets.push_back(b->AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index))); + } + } + return offsets; +} + +std::vector MakeTiledPartitionOrdinals( + const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { + CHECK(!sharding.IsTileMaximal()); + auto table_shape = + ShapeUtil::MakeShape(S32, sharding.tile_assignment().dimensions()); + return MakePartitionOffsets(table_shape, sharding, partition_id, b); +} + +HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, + SpmdBuilder* b, HloComputation* computation) { + CHECK(b == nullptr || computation == nullptr); + if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) { + return hlo; + } + PaddingConfig padding_config; + for (int64 i = 0; i < padded_shape.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_interior_padding(0); + padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - + hlo->shape().dimensions(i)); + } + auto add_hlo = [&](std::unique_ptr to_add) { + if (b == nullptr) { + return computation->AddInstruction(std::move(to_add)); + } + return b->AddInstruction(std::move(to_add)); + }; + auto zero = add_hlo(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + return add_hlo( + HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config)); +} + +Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return base_shape; + } + if (EvenlyPartitions(base_shape, sharding)) { + return base_shape; + } + auto shard_shape = MakePartitionedShape(base_shape, sharding); + Shape padded_base_shape = base_shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + return padded_base_shape; +} + +HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( + HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { + auto padded_base_shape = + GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding); + if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) { + return hlo; + } + return PadToShape(hlo, padded_base_shape, b); +} + +absl::optional UniqueTiledDim(const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return absl::nullopt; + } + int64 dim = -1; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (sharding.tile_assignment().dim(i) > 1) { + if (dim != -1) { + return absl::nullopt; + } + dim = i; + } + } + CHECK_NE(dim, -1); + return dim; +} + +MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation( + int64 multiplier, int64 offset, int64 divisor) + : multiplier_(multiplier), offset_(offset), divisor_(divisor) { + CHECK_GT(divisor_, 0); + Simplify(); +} + +OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-( + const MultiplyAddDivideOffsetCalculation& other) const { + if (divisor_ == 1 && other.divisor_ == 1) { + return OffsetCalculation(MultiplyAddDivideOffsetCalculation( + multiplier_ - other.multiplier_, offset_ - other.offset_, 1)); + } + return OffsetCalculation(HloOpcode::kSubtract, *this, other); +} + +void MultiplyAddDivideOffsetCalculation::Simplify() { + // We could simplify the calculation when multiplier is a multiple of + // divisor_. However, when offset_ is not a multiple of divisor_, we must + // make sure that offset_ and multiplier_ are both non-negative or both + // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1. + if (divisor_ != 1 && multiplier_ % divisor_ == 0 && + (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) { + multiplier_ /= divisor_; + offset_ /= divisor_; + divisor_ = 1; + } +} + +int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const { + return (shard_ordinal * multiplier_ + offset_) / divisor_; +} + +HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate( + HloInstruction* shard_ordinal, SpmdBuilder* b) const { + auto scalar_shape = ShapeUtil::MakeShape(S32, {}); + if (multiplier_ == 0) { + return b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(offset_ / divisor_))); + } + HloInstruction* result = shard_ordinal; + if (multiplier_ != 1) { + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMultiply, shard_ordinal, + b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(multiplier_))))); + } + if (offset_ != 0) { + auto offset = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(offset_))); + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, result, offset)); + } + if (divisor_ != 1) { + auto divisor = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(divisor_))); + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kDivide, result, divisor)); + } + return result; +} + +int64 MultiplyAddDivideOffsetCalculation::MaxInRange( + int64 start_ordinal, int64 limit_ordinal) const { + int64 max = Calculate(start_ordinal); + for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { + max = std::max(max, Calculate(i)); + } + return max; +} + +OffsetCalculation& OffsetCalculation::operator=( + const OffsetCalculation& other) { + opcode_ = other.opcode_; + copy_from_ = other.copy_from_; + if (opcode_ != HloOpcode::kCopy) { + lhs_ = absl::make_unique(*other.lhs_); + rhs_ = absl::make_unique(*other.rhs_); + } + return *this; +} + +bool OffsetCalculation::IsConstant() const { + if (opcode_ == HloOpcode::kCopy) { + return copy_from_.IsConstant(); + } + if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) { + return true; + } + return lhs_->IsConstant() && rhs_->IsConstant(); +} + +OffsetCalculation OffsetCalculation::operator-( + const OffsetCalculation& other) const { + if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) { + return copy_from_ - other.copy_from_; + } + return OffsetCalculation(HloOpcode::kSubtract, *this, other); +} + +bool OffsetCalculation::operator==(const OffsetCalculation& other) const { + if (opcode_ != other.opcode_) { + return false; + } + if (opcode_ == HloOpcode::kCopy) { + return copy_from_ == other.copy_from_; + } + return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_; +} + +int64 OffsetCalculation::Calculate(int64 shard_ordinal) const { + switch (opcode_) { + case HloOpcode::kCopy: + return copy_from_.Calculate(shard_ordinal); + case HloOpcode::kSubtract: + return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal); + case HloOpcode::kMultiply: + return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal); + default: + LOG(FATAL) << "Should not happen"; + } +} + +HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const { + if (opcode_ == HloOpcode::kCopy) { + return copy_from_.Calculate(shard_ordinal, b); + } + auto lhs = lhs_->Calculate(shard_ordinal, b); + auto rhs = rhs_->Calculate(shard_ordinal, b); + return b->AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs)); +} + +int64 OffsetCalculation::MaxInRange(int64 start_ordinal, + int64 limit_ordinal) const { + if (IsConstant()) { + return Calculate(start_ordinal); + } + if (opcode_ == HloOpcode::kCopy) { + return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1)); + } + int64 max = Calculate(start_ordinal); + for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { + max = std::max(max, Calculate(i)); + } + return max; +} + +absl::optional ExchangeHalo( + HloInstruction* hlo, const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, int64 dim, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b) { + int64 input_shard_size = hlo->shape().dimensions(dim); + int64 shard_count = target.tile_assignment().dim(dim); + + std::vector concat_pieces; + + int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); + if (max_left_halo_size > input_shard_size) { + VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor."; + return absl::nullopt; + } + if (max_left_halo_size > 0) { + std::vector> source_target_pairs; + target.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + if (indices[dim] > 0) { + std::vector source_indices(indices.begin(), indices.end()); + source_indices[dim] -= 1; + source_target_pairs.emplace_back( + target.tile_assignment()(source_indices), device); + } + }); + auto halo_shape = hlo->shape(); + auto source_halo_slice = hlo; + if (max_left_halo_size != hlo->shape().dimensions(dim)) { + halo_shape.set_dimensions(dim, max_left_halo_size); + std::vector halo_start_indices(halo_shape.rank(), 0); + halo_start_indices[dim] = + hlo->shape().dimensions(dim) - max_left_halo_size; + std::vector halo_slice_strides(halo_shape.rank(), 1); + + source_halo_slice = b->AddInstruction( + hlo->CreateSlice(halo_shape, hlo, halo_start_indices, + hlo->shape().dimensions(), halo_slice_strides)); + } + auto left_halo = + collective_ops_creator.create_cross_partition_collective_permute( + b, source_halo_slice, source_target_pairs, (*next_channel_id)++); + concat_pieces.push_back(left_halo); + } + + concat_pieces.push_back(hlo); + + // Right halo. + int64 max_right_halo_size = + right_halo_size_function.MaxInRange(0, shard_count - 1); + if (max_right_halo_size > input_shard_size) { + VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor."; + return absl::nullopt; + } + if (max_right_halo_size > 0) { + std::vector> source_target_pairs; + target.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + if (indices[dim] > 0) { + std::vector target_indices(indices.begin(), indices.end()); + target_indices[dim] -= 1; + source_target_pairs.emplace_back( + device, target.tile_assignment()(target_indices)); + } + }); + auto halo_shape = hlo->shape(); + halo_shape.set_dimensions(dim, max_right_halo_size); + std::vector halo_start_indices(halo_shape.rank(), 0); + std::vector halo_slice_strides(halo_shape.rank(), 1); + + auto source_halo_slice = b->AddInstruction( + hlo->CreateSlice(halo_shape, hlo, halo_start_indices, + halo_shape.dimensions(), halo_slice_strides)); + auto right_halo = + collective_ops_creator.create_cross_partition_collective_permute( + b, source_halo_slice, source_target_pairs, (*next_channel_id)++); + concat_pieces.push_back(right_halo); + } + + auto concat = hlo; + // Concat with halos/padding. + if (concat_pieces.size() > 1) { + auto concat_shape = hlo->shape(); + int64 concat_dim_size = 0; + for (auto piece : concat_pieces) { + concat_dim_size += piece->shape().dimensions(dim); + } + concat_shape.set_dimensions(dim, concat_dim_size); + concat = b->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim)); + } + + return concat; +} + +absl::optional ExchangeHalo( + HloInstruction* hlo, + std::vector left_halo_size_functions, + std::vector right_halo_size_functions, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b) { + CHECK(left_halo_size_functions.size() == hlo->shape().rank()); + CHECK(right_halo_size_functions.size() == hlo->shape().rank()); + + HloInstruction* visiting_hlo = hlo; + for (int dim = 0; dim < hlo->shape().rank(); ++dim) { + auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim], + right_halo_size_functions[dim], dim, target, + collective_ops_creator, next_channel_id, b); + if (!concat) { + return absl::nullopt; + } + visiting_hlo = *concat; + } + return visiting_hlo; +} + +absl::optional ExchangeHaloAndGetValidData( + HloInstruction* hlo, const Shape& base_shape, + const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, + int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, + int64 shard_size_with_halo, int64 dim, const HloSharding& target, + HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, + HloInstruction* partition_ordinal, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) { + auto halo_exchange_result = + ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim, + target, collective_ops_creator, next_channel_id, b); + if (!halo_exchange_result) { + return absl::nullopt; + } + auto concat = *halo_exchange_result; + int64 shard_count = target.tile_assignment().dim(dim); + int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); + + // Now we determine if we need extra padding after the concat. + // + // The max of halo size or the first shard's explicit left padding. + int64 max_left_halo_or_padding_size = + std::max(std::max(int64{0}, max_left_halo_size), + explicit_left_padding_on_full_shape); + // The calculation that returns the dynamic slice index for a shard on the + // padded concat, which is the difference between + // max_left_halo_or_padding_size and its left halo size. + auto start_offset_on_padded_concat_calculation = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, max_left_halo_or_padding_size, 1)) - + left_halo_size_function; + + // See if we need to pad the concat before dynamic slice. + int64 extra_left_padding = + std::max(int64{0}, max_left_halo_or_padding_size - + std::max(int64{0}, max_left_halo_size)); + int64 extra_right_padding = + start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) + + shard_size_with_halo - concat->shape().dimensions(dim) - + extra_left_padding; + extra_right_padding = std::max(int64{0}, extra_right_padding); + if (extra_left_padding > 0 || extra_right_padding > 0) { + PaddingConfig padding_config; + auto padded_concat_shape = concat->shape(); + for (int64 i = 0; i < base_shape.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_interior_padding(0); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_edge_padding_high(0); + if (i != dim) { + continue; + } + padding_config_dim->set_edge_padding_low(extra_left_padding); + padding_config_dim->set_edge_padding_high(extra_right_padding); + padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) + + extra_left_padding + + extra_right_padding); + } + concat = b->AddInstruction(HloInstruction::CreatePad( + padded_concat_shape, concat, pad_value, padding_config)); + } + + auto valid_slice = concat; + if (shard_size_with_halo != concat->shape().dimensions(dim)) { + // Concat is bigger than the shard shape, so we need a dynamic slice. + CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim)); + auto slice_shape = concat->shape(); + slice_shape.set_dimensions(dim, shard_size_with_halo); + + if (left_halo_size_function.IsConstant() && + left_halo_size_function.Calculate(0) == + explicit_left_padding_on_full_shape) { + std::vector start_indices(slice_shape.rank(), 0); + std::vector strides(slice_shape.rank(), 1); + valid_slice = b->AddInstruction( + HloInstruction::CreateSlice(slice_shape, concat, start_indices, + slice_shape.dimensions(), strides)); + } else { + auto zero = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(base_shape.rank(), zero); + slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( + partition_ordinal, b); + valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, concat, slice_offsets, slice_shape.dimensions())); + } + } + + if (!mask_invalid_region) { + return valid_slice; + } + + int64 total_right_padding = padded_full_shape_size - + base_shape.dimensions(dim) - + explicit_left_padding_on_full_shape; + // Mask off garbage data due to uneven partition or low/high padding. + if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) { + auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32); + auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); + auto broadcast_start_index_in_padded_shape = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, offset_on_padded_shape, {})); + auto index_in_padded_shape = b->AddInstruction( + HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota, + broadcast_start_index_in_padded_shape)); + auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); + std::vector predicates; + if (explicit_left_padding_on_full_shape > 0) { + auto valid_index_start = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, + b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + explicit_left_padding_on_full_shape))), + {})); + predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_padded_shape, valid_index_start, + ComparisonDirection::kGe))); + } + if (total_right_padding > 0) { + auto valid_index_limit = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, + b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + base_shape.dimensions(dim) + + explicit_left_padding_on_full_shape))), + {})); + predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_padded_shape, valid_index_limit, + ComparisonDirection::kLt))); + } + CHECK(!predicates.empty()); + auto is_valid = + predicates.size() == 2 + ? b->AddInstruction(HloInstruction::CreateBinary( + mask_shape, HloOpcode::kAnd, predicates[0], predicates[1])) + : predicates[0]; + auto masking_value = b->AddInstruction( + HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {})); + valid_slice = b->AddInstruction( + HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect, + is_valid, valid_slice, masking_value)); + } + return valid_slice; +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h new file mode 100644 index 00000000000..f96b23d7073 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -0,0 +1,229 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +namespace xla { +namespace spmd { + +// Returns true if the given sharding contains any replicated sharding. +bool HasReplicatedSharding(const HloSharding& sharding); + +// Creates zero value instructions of the given shape. +HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b); + +template +HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value, + SpmdBuilder* b) { + auto literal = LiteralUtil::CreateR0(value) + .ConvertToShape(ShapeUtil::MakeShape(type, {})) + .ValueOrDie(); + return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal))); +} + +// Create a binary add computation of the given type and add to the module. +HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module); + +// Returns true if the shape can be evenly partitioned for the given sharding. +// All tile sharded dimensions should be evenly divisible and there should be no +// single-device sharding. Replicate sharding is considered even partition. +bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding); + +// Returns the shard shape of the given shape when it is partitioned for the +// target sharding. +Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding); + +// Returns the shard shape for a partition without padding due to uneven +// sharding. +Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, + const HloSharding& sharding, + int64 partition_id); + +// Generates the HLO instructions that represent the dimension offsets on any +// device. The size of the returned vector is the rank of the given shape. +std::vector MakePartitionOffsets(const Shape& shape, + const HloSharding& sharding, + HloInstruction* partition_id, + SpmdBuilder* b); + +// Returns the offsets of the partition in the tile assignment. +std::vector MakeTiledPartitionOrdinals( + const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b); + +// Pads hlo to the desired shape using high padding. Either a builder or a +// computation needs to be supplied, but not both. +HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, + SpmdBuilder* b, + HloComputation* computation = nullptr); + +// Returns the padded shape when combining all partitions. +Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, + const HloSharding& sharding); + +// Pads the HLO (with base shape) for uneven tiled partition to make it evenly +// partitionable. +HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( + HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b); + +// Returns the index of the unique tile dimension. Returns absl::nullopt if the +// given sharding is not tiled or tiled along multiple dimensions. +absl::optional UniqueTiledDim(const HloSharding& sharding); + +// Utilities for symbolic offset calculation and halo exchange. +class OffsetCalculation; + +// Represents a calculation over integers: +// (shard_ordinal * multiplier + offset) / divisor +class MultiplyAddDivideOffsetCalculation { + public: + MultiplyAddDivideOffsetCalculation() + : multiplier_(0), offset_(0), divisor_(1) {} + MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset, + int64 divisor); + + OffsetCalculation operator-( + const MultiplyAddDivideOffsetCalculation& other) const; + + bool operator==(const MultiplyAddDivideOffsetCalculation& other) const { + return multiplier_ == other.multiplier_ && offset_ == other.offset_ && + divisor_ == other.divisor_; + } + + bool IsConstant() const { return multiplier_ == 0; } + void Simplify(); + int64 Calculate(int64 shard_ordinal) const; + HloInstruction* Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const; + + // Returns the maximum result for shard ordinals in the range + // [start_ordinal, limit_ordinal). + int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const; + + private: + int64 multiplier_; + int64 offset_; + int64 divisor_; +}; + +// Represents a calculation over integers based on results of other calculations +// defined by an opcode. If the opcode is kCopy, it simply wraps an +// MultiplyAddDivideOffsetCalculation. +class OffsetCalculation { + public: + OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {} + explicit OffsetCalculation( + const MultiplyAddDivideOffsetCalculation& copy_from) + : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {} + OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; } + OffsetCalculation(HloOpcode opcode, + const MultiplyAddDivideOffsetCalculation& lhs, + const MultiplyAddDivideOffsetCalculation& rhs) + : opcode_(opcode), + lhs_(absl::make_unique(lhs)), + rhs_(absl::make_unique(rhs)) {} + OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs, + const OffsetCalculation& rhs) + : opcode_(opcode), + lhs_(absl::make_unique(lhs)), + rhs_(absl::make_unique(rhs)) {} + + OffsetCalculation& operator=(const OffsetCalculation& other); + + // Returns whether the calculation returns the same value for all shards. This + // is conservative and could return false even if it is actually constant. + bool IsConstant() const; + + OffsetCalculation operator-(const OffsetCalculation& other) const; + bool operator==(const OffsetCalculation& other) const; + int64 Calculate(int64 shard_ordinal) const; + HloInstruction* Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const; + + // Returns the maximum result for shard ordinals in the range + // [start_ordinal, limit_ordinal). + int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const; + + private: + HloOpcode opcode_; + std::unique_ptr lhs_; + std::unique_ptr rhs_; + MultiplyAddDivideOffsetCalculation copy_from_; +}; + +// Performs halo exchange on the given dimension based on the provided +// left/right halo size functions. Returns nullopt if the halo is beyond the +// direct neighbor of the shard. +absl::optional ExchangeHalo( + HloInstruction* hlo, const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, int64 dim, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b); + +// Exchange halo on all dimensions of the HLO. Returns nullopt if any one of the +// dimensions fails to exchange halo (halo is beyond the neighbor shard). +absl::optional ExchangeHalo( + HloInstruction* hlo, + std::vector left_halo_size_functions, + std::vector right_halo_size_functions, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b); + +// Exchanges halos and performs pad/dynamic-slice on the concatenated data such +// that the result starts with the first needed element on each shard. It also +// masks off invalid data due to padding. +// Arguments: +// hlo: the HLO op before halo exchange +// explicit_left_padding_on_full_shape: the amount of left padding to be added +// explicitly by this function on the base shape before partitioning. Without +// base dilation, this is usually set to the window's padding_low so that the +// sharded op do not need to add padding_low on the window; however, with base +// dilation, this could only be set to a custom size. +// padded_full_shape_size: the size of the padded full shape on the given +// dimension, which includes explicit_left_padding_on_full_shape and required +// right padding to make the shape evenly shardable. +// shard_size_with_halo: the shard size on the dimension after halo exchange. +// If different shards have different sizes, use the maximum size. +// offset_on_padded_shape: the offset HLO (S32) that represents the start of +// each shard on the padded full shape. +// pad_value: the padding value used on the full shape. +absl::optional ExchangeHaloAndGetValidData( + HloInstruction* hlo, const Shape& base_shape, + const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, + int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, + int64 shard_size_with_halo, int64 dim, const HloSharding& target, + HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, + HloInstruction* partition_ordinal, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true); + +} // namespace spmd +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_