hlo_sharding.cc. PiperOrigin-RevId: 306997392 Change-Id: I86d006f8004e9ebb93baf1cd79bac9443b8dabbc
		
			
				
	
	
		
			528 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			528 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| ==============================================================================*/
 | |
| 
 | |
| #include "tensorflow/compiler/xla/service/hlo_sharding.h"
 | |
| 
 | |
| #include "absl/container/flat_hash_set.h"
 | |
| #include "absl/strings/str_cat.h"
 | |
| #include "absl/strings/str_join.h"
 | |
| #include "tensorflow/compiler/xla/overflow_util.h"
 | |
| #include "tensorflow/core/lib/core/errors.h"
 | |
| 
 | |
| namespace xla {
 | |
| 
 | |
| using absl::StrCat;
 | |
| using absl::StrJoin;
 | |
| 
 | |
| HloSharding HloSharding::AssignDevice(int64 device_id) {
 | |
|   return HloSharding(device_id);
 | |
| }
 | |
| 
 | |
| HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
 | |
|   CHECK_EQ(1, input_shape.rank());
 | |
|   CHECK_GT(num_tiles, 1);
 | |
|   std::vector<int64> dimensions(1, num_tiles);
 | |
|   Array<int64> assignment(dimensions);
 | |
|   std::iota(assignment.begin(), assignment.end(), 0);
 | |
|   return HloSharding(assignment);
 | |
| }
 | |
| 
 | |
| HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
 | |
|   std::vector<HloSharding> flattened_list;
 | |
|   flattened_list.reserve(sub_shardings.leaf_count());
 | |
|   for (const auto& index_to_sharding : sub_shardings.leaves()) {
 | |
|     flattened_list.push_back(index_to_sharding.second);
 | |
|   }
 | |
|   if (flattened_list.empty()) {
 | |
|     // Empty tuple sharding ends up having no leaves, but we want to allow
 | |
|     // empty tuple HLO instruction results to have sharding, so we fetch the
 | |
|     // root ({}) sharding value from the ShapeTree.
 | |
|     // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
 | |
|     // init as value at its root.
 | |
|     flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
 | |
|   }
 | |
|   return HloSharding(flattened_list);
 | |
| }
 | |
| 
 | |
| HloSharding HloSharding::Tuple(const Shape& tuple_shape,
 | |
|                                absl::Span<const HloSharding> shardings) {
 | |
|   CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
 | |
|   for (auto& sharding : shardings) {
 | |
|     CHECK(!sharding.IsTuple()) << sharding.ToString();
 | |
|   }
 | |
|   std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
 | |
|   CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
 | |
|       << "Flat list has " << flattened_list.size() << ", required "
 | |
|       << RequiredLeaves(tuple_shape);
 | |
|   return HloSharding(flattened_list);
 | |
| }
 | |
| 
 | |
| HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
 | |
|                                      const HloSharding& sharding) {
 | |
|   CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
 | |
|   CHECK(!sharding.IsTuple()) << sharding.ToString();
 | |
|   int64 leaf_count = RequiredLeaves(tuple_shape);
 | |
|   std::vector<HloSharding> flattened_list;
 | |
|   flattened_list.resize(leaf_count, sharding);
 | |
|   return HloSharding(flattened_list);
 | |
| }
 | |
| 
 | |
| HloSharding HloSharding::Single(const Shape& shape,
 | |
|                                 const HloSharding& sharding) {
 | |
|   return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding;
 | |
| }
 | |
| 
 | |
| string HloSharding::ToString() const {
 | |
|   if (IsTuple()) {
 | |
|     std::vector<string> parts;
 | |
|     parts.reserve(tuple_elements_.size());
 | |
|     for (const HloSharding& element : tuple_elements_) {
 | |
|       parts.push_back(element.ToString());
 | |
|     }
 | |
|     return StrCat("{", absl::StrJoin(parts, ", "), "}");
 | |
|   }
 | |
| 
 | |
|   if (replicated_) {
 | |
|     return "{replicated}";
 | |
|   }
 | |
|   if (maximal_) {
 | |
|     return StrCat(
 | |
|         "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
 | |
|   }
 | |
|   return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
 | |
|                 StrJoin(tile_assignment_, ","), "}");
 | |
| }
 | |
| 
 | |
| bool HloSharding::UsesDevice(int64 device) const {
 | |
|   if (IsTuple()) {
 | |
|     return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
 | |
|       return s.UsesDevice(device);
 | |
|     });
 | |
|   }
 | |
|   const auto& devices = tile_assignment_;
 | |
|   return replicated_ || absl::c_linear_search(devices, device);
 | |
| }
 | |
| 
 | |
| std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
 | |
|   int64 element_count = 1;
 | |
|   std::map<int64, int64> device_map;
 | |
|   if (IsTuple()) {
 | |
|     for (auto& tuple_element_sharding : tuple_elements()) {
 | |
|       auto unique_device = tuple_element_sharding.UniqueDevice();
 | |
|       if (unique_device) {
 | |
|         device_map[*unique_device] += 1;
 | |
|       }
 | |
|     }
 | |
|     element_count = tuple_elements().size();
 | |
|   } else {
 | |
|     auto unique_device = UniqueDevice();
 | |
|     if (unique_device) {
 | |
|       device_map[*unique_device] += 1;
 | |
|     }
 | |
|   }
 | |
|   if (count != nullptr) {
 | |
|     *count = element_count;
 | |
|   }
 | |
|   return device_map;
 | |
| }
 | |
| 
 | |
| std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
 | |
|   CHECK(!maximal_);
 | |
|   CHECK(!IsTuple());
 | |
|   std::vector<int64> ret_index;
 | |
|   tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
 | |
|     if (d == device) {
 | |
|       ret_index = {index.begin(), index.end()};
 | |
|     }
 | |
|   });
 | |
|   CHECK(!ret_index.empty());
 | |
|   return ret_index;
 | |
| }
 | |
| 
 | |
| int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
 | |
|   CHECK(!replicated_);
 | |
|   CHECK(!IsTuple());
 | |
|   if (maximal_) {
 | |
|     return *tile_assignment_.begin();
 | |
|   }
 | |
|   return tile_assignment_(index);
 | |
| }
 | |
| 
 | |
| std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
 | |
|                                                     int64 device) const {
 | |
|   CHECK(!IsTuple());
 | |
| 
 | |
|   if (maximal_) {
 | |
|     return std::vector<int64>(shape.dimensions_size(), 0);
 | |
|   }
 | |
| 
 | |
|   CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
 | |
|   std::vector<int64> index = TileIndexForDevice(device);
 | |
|   for (int64 i = 0; i < index.size(); ++i) {
 | |
|     const int64 shape_dim = shape.dimensions(i);
 | |
|     index[i] = std::min(
 | |
|         index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
 | |
|   }
 | |
|   return index;
 | |
| }
 | |
| 
 | |
| std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
 | |
|                                                    int64 device) const {
 | |
|   CHECK(!IsTuple());
 | |
| 
 | |
|   if (maximal_) {
 | |
|     return std::vector<int64>(shape.dimensions().begin(),
 | |
|                               shape.dimensions().end());
 | |
|   }
 | |
| 
 | |
|   CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
 | |
|   std::vector<int64> index = TileIndexForDevice(device);
 | |
|   for (int64 i = 0; i < index.size(); ++i) {
 | |
|     const int64 shape_dim = shape.dimensions(i);
 | |
|     index[i] = std::min(
 | |
|         (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
 | |
|         shape_dim);
 | |
|   }
 | |
|   return index;
 | |
| }
 | |
| 
 | |
| int64 HloSharding::RequiredLeaves(const Shape& shape) {
 | |
|   // Empty tuples (with arbitrary nesting) have no leaf nodes as far as
 | |
|   // ShapeUtil and ShapeTree are concerned, but they do have a single
 | |
|   // tuple_elements_ entry since we want to allow empty tuple results to
 | |
|   // have sharding.
 | |
|   const int64 leaf_count = ShapeUtil::GetLeafCount(shape);
 | |
|   return (leaf_count == 0) ? 1 : leaf_count;
 | |
| }
 | |
| 
 | |
| Status HloSharding::CheckLeafCount(const Shape& shape) const {
 | |
|   int64 shape_leaves = RequiredLeaves(shape);
 | |
|   TF_RET_CHECK(shape_leaves == tuple_elements_.size())
 | |
|       << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
 | |
|       << " leaf nodes while this sharding has " << tuple_elements_.size();
 | |
|   return Status::OK();
 | |
| }
 | |
| 
 | |
| StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
 | |
|     const Shape& shape) const {
 | |
|   if (IsTuple()) {
 | |
|     ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
 | |
|     TF_RETURN_IF_ERROR(CheckLeafCount(shape));
 | |
|     auto it = tuple_elements_.begin();
 | |
|     for (auto& index_to_sharding : result.leaves()) {
 | |
|       index_to_sharding.second = *it++;
 | |
|     }
 | |
|     if (ShapeUtil::IsEmptyTuple(shape)) {
 | |
|       // Empty tuples have no leaves, but we want to assign them a sharding
 | |
|       // anyway, so we use the root element sharding.
 | |
|       *result.mutable_element(ShapeIndex({})) = *it;
 | |
|     }
 | |
|     return std::move(result);
 | |
|   } else {
 | |
|     return ShapeTree<HloSharding>(shape, *this);
 | |
|   }
 | |
| }
 | |
| 
 | |
| StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
 | |
|   if (IsTuple()) {
 | |
|     TF_RETURN_IF_ERROR(CheckLeafCount(shape));
 | |
|     return *this;
 | |
|   }
 | |
|   return Tuple(ShapeTree<HloSharding>(shape, *this));
 | |
| }
 | |
| 
 | |
| absl::optional<int64> HloSharding::UniqueDevice() const {
 | |
|   if (IsTuple()) {
 | |
|     if (tuple_elements_.empty()) {
 | |
|       return absl::nullopt;
 | |
|     }
 | |
|     absl::optional<int64> unique_device;
 | |
|     for (auto& tuple_sharding : tuple_elements_) {
 | |
|       auto device = tuple_sharding.UniqueDevice();
 | |
|       if (!device || (unique_device && *device != *unique_device)) {
 | |
|         return absl::nullopt;
 | |
|       }
 | |
|       unique_device = device;
 | |
|     }
 | |
|     return unique_device;
 | |
|   }
 | |
|   if (!replicated_ && maximal_) {
 | |
|     return static_cast<int64>(*tile_assignment_.begin());
 | |
|   }
 | |
|   return absl::nullopt;
 | |
| }
 | |
| 
 | |
| int64 HloSharding::GetUniqueDevice() const {
 | |
|   auto device = UniqueDevice();
 | |
|   CHECK(device) << "Sharding does not have a unique device: " << *this;
 | |
|   return *device;
 | |
| }
 | |
| 
 | |
| Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
 | |
|   if (!shape.IsTuple()) {
 | |
|     return tensorflow::errors::InvalidArgument(
 | |
|         StrCat("Sharding is tuple-shaped but validation shape is not."));
 | |
|   }
 | |
|   TF_RETURN_IF_ERROR(CheckLeafCount(shape));
 | |
| 
 | |
|   // Now we've validated the number of tuple elements, it's safe to request a
 | |
|   // shape tree.
 | |
|   ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
 | |
|   for (const auto& index_to_sharding : shape_tree.leaves()) {
 | |
|     Status status = index_to_sharding.second.ValidateNonTuple(
 | |
|         ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
 | |
|     if (!status.ok()) {
 | |
|       tensorflow::errors::AppendToMessage(
 | |
|           &status, StrCat("Note: While validating sharding tuple element ",
 | |
|                           index_to_sharding.first.ToString(), " which is ",
 | |
|                           index_to_sharding.second.ToString()));
 | |
|       return status;
 | |
|     }
 | |
|   }
 | |
|   return Status::OK();
 | |
| }
 | |
| 
 | |
| Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
 | |
|   Status status = IsTuple() ? ValidateTuple(shape, num_devices)
 | |
|                             : ValidateNonTuple(shape, num_devices);
 | |
|   if (!status.ok()) {
 | |
|     tensorflow::errors::AppendToMessage(
 | |
|         &status, StrCat("Note: While validating sharding ", ToString(),
 | |
|                         " against shape ", ShapeUtil::HumanString(shape)));
 | |
|   }
 | |
|   return status;
 | |
| }
 | |
| 
 | |
| Status HloSharding::ValidateNonTuple(const Shape& shape,
 | |
|                                      int64 num_devices) const {
 | |
|   if (shape.IsTuple()) {
 | |
|     return tensorflow::errors::InvalidArgument(
 | |
|         StrCat("Validation shape is a tuple but sharding is not."));
 | |
|   }
 | |
|   if (replicated_) {
 | |
|     return Status::OK();
 | |
|   }
 | |
| 
 | |
|   // All tile assignments must be less than the number of available cores and
 | |
|   // unique.
 | |
|   Status status = Status::OK();
 | |
|   absl::flat_hash_set<int64> seen_cores;
 | |
|   tile_assignment_.Each(
 | |
|       [&](absl::Span<const int64> indices, int32 core) {
 | |
|         // Don't overwrite a bad status, so we report the first error.
 | |
|         if (status.ok()) {
 | |
|           if (core >= num_devices) {
 | |
|             status = tensorflow::errors::InvalidArgument(StrCat(
 | |
|                 "core ", core, " > ", num_devices, " in tile assignment"));
 | |
|           } else if (seen_cores.contains(core)) {
 | |
|             status = tensorflow::errors::InvalidArgument(
 | |
|                 StrCat("core ", core, " is not unique in tile assignment"));
 | |
|           }
 | |
|           seen_cores.insert(core);
 | |
|         }
 | |
|       });
 | |
|   if (!status.ok()) {
 | |
|     return status;
 | |
|   }
 | |
| 
 | |
|   if (IsTileMaximal()) {
 | |
|     return Status::OK();
 | |
|   }
 | |
| 
 | |
|   // The tile assignment tensor must have the same rank as the input.
 | |
|   if (shape.rank() != tile_assignment_.num_dimensions()) {
 | |
|     return tensorflow::errors::InvalidArgument(
 | |
|         "Number of tile assignment dimensions is different to the input rank. "
 | |
|         "sharding=",
 | |
|         ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
 | |
|   }
 | |
| 
 | |
|   // The correct constructor has to be used to create tile maximal shardings.
 | |
|   if (tile_assignment_.num_elements() == 1) {
 | |
|     return tensorflow::errors::InvalidArgument(
 | |
|         "Tile assignment only contains a single device. If a replicated "
 | |
|         "sharding was intended, use HloSharding::Replicated(). If a device "
 | |
|         "placement was intended, use HloSharding::AssignDevice()");
 | |
|   }
 | |
|   return Status::OK();
 | |
| }
 | |
| 
 | |
| /*static*/ StatusOr<HloSharding> HloSharding::FromProto(
 | |
|     const OpSharding& proto) {
 | |
|   if (proto.type() == OpSharding::TUPLE) {
 | |
|     std::vector<HloSharding> tuple_shardings;
 | |
|     tuple_shardings.reserve(proto.tuple_shardings().size());
 | |
|     for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
 | |
|       TF_ASSIGN_OR_RETURN(HloSharding sharding,
 | |
|                           HloSharding::FromProto(tuple_sharding_proto));
 | |
|       tuple_shardings.push_back(sharding);
 | |
|     }
 | |
|     return HloSharding(tuple_shardings);
 | |
|   } else if (proto.type() == OpSharding::REPLICATED) {
 | |
|     return Replicate();
 | |
|   } else if (proto.tile_assignment_devices().size() == 1) {
 | |
|     return HloSharding(proto.tile_assignment_devices(0));
 | |
|   }
 | |
| 
 | |
|   TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)
 | |
|       << "Maximal sharding is expected to have single device assignment, but "
 | |
|       << proto.tile_assignment_devices().size() << " has provided.";
 | |
| 
 | |
|   TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
 | |
|   TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
 | |
| 
 | |
|   // RE: the product of tile assignment tensor dimensions must be
 | |
|   // equal to tile_assignment_devices.size().
 | |
|   int64 product_of_dimensions = 1;
 | |
|   for (auto dimension : proto.tile_assignment_dimensions()) {
 | |
|     TF_RET_CHECK(dimension > 0);
 | |
|     product_of_dimensions =
 | |
|         MultiplyWithoutOverflow(product_of_dimensions, dimension);
 | |
|     TF_RET_CHECK(product_of_dimensions > 0);
 | |
|   }
 | |
|   TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
 | |
| 
 | |
|   // Some versions of gcc cannot infer the TileAssignment constructor from a
 | |
|   // braced initializer-list, so create one manually.
 | |
|   std::vector<int64> devices(proto.tile_assignment_devices().begin(),
 | |
|                              proto.tile_assignment_devices().end());
 | |
|   Array<int64> tile_assignment(
 | |
|       std::vector<int64>(proto.tile_assignment_dimensions().begin(),
 | |
|                          proto.tile_assignment_dimensions().end()));
 | |
|   std::copy(proto.tile_assignment_devices().begin(),
 | |
|             proto.tile_assignment_devices().end(), tile_assignment.begin());
 | |
|   return HloSharding(tile_assignment);
 | |
| }
 | |
| 
 | |
| OpSharding HloSharding::ToProto() const {
 | |
|   OpSharding result;
 | |
| 
 | |
|   if (IsTuple()) {
 | |
|     for (const HloSharding& element : tuple_elements_) {
 | |
|       *result.add_tuple_shardings() = element.ToProto();
 | |
|     }
 | |
|     result.set_type(OpSharding::TUPLE);
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   for (int64 dim : tile_assignment_.dimensions()) {
 | |
|     result.add_tile_assignment_dimensions(dim);
 | |
|   }
 | |
|   for (auto device : tile_assignment_) {
 | |
|     result.add_tile_assignment_devices(device);
 | |
|   }
 | |
|   if (IsReplicated()) {
 | |
|     result.set_type(OpSharding::REPLICATED);
 | |
|   } else if (IsTileMaximal()) {
 | |
|     result.set_type(OpSharding::MAXIMAL);
 | |
|   } else {
 | |
|     result.set_type(OpSharding::OTHER);
 | |
|   }
 | |
|   return result;
 | |
| }
 | |
| 
 | |
| Shape HloSharding::TileShape(const Shape& shape) const {
 | |
|   if (IsTileMaximal()) {
 | |
|     return shape;
 | |
|   }
 | |
|   Shape result_shape = shape;
 | |
|   for (int64 i = 0; i < shape.dimensions_size(); ++i) {
 | |
|     result_shape.set_dimensions(
 | |
|         i, CeilOfRatio<int64>(shape.dimensions(i), tile_assignment_.dim(i)));
 | |
|   }
 | |
|   return result_shape;
 | |
| }
 | |
| 
 | |
| Shape HloSharding::TileShape(const Shape& shape, int64 device) const {
 | |
|   if (IsTileMaximal()) {
 | |
|     return shape;
 | |
|   }
 | |
| 
 | |
|   std::vector<int64> index = TileIndexForDevice(device);
 | |
|   Shape result_shape = shape;
 | |
|   for (int64 i = 0; i < index.size(); ++i) {
 | |
|     const int64 shape_dim = shape.dimensions(i);
 | |
|     int64 offset = std::min(
 | |
|         index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
 | |
|     int64 limit = std::min(
 | |
|         (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
 | |
|         shape_dim);
 | |
|     result_shape.set_dimensions(i, limit - offset);
 | |
|   }
 | |
|   return result_shape;
 | |
| }
 | |
| 
 | |
| HloSharding HloSharding::GetSubSharding(const Shape& shape,
 | |
|                                         const ShapeIndex& index) const {
 | |
|   CHECK(IsTuple());
 | |
|   int64 sharding_index = 0;
 | |
|   const Shape* sub_shape = &shape;
 | |
|   for (int64 idx : index) {
 | |
|     for (int64 i = 0; i < idx; ++i) {
 | |
|       sharding_index +=
 | |
|           ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
 | |
|     }
 | |
|     sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
 | |
|   }
 | |
|   if (sub_shape->IsTuple()) {
 | |
|     auto begin_it = tuple_elements_.begin() + sharding_index;
 | |
|     std::vector<HloSharding> sub_shardings(
 | |
|         begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
 | |
|     return HloSharding::Tuple(*sub_shape, sub_shardings);
 | |
|   } else {
 | |
|     return tuple_elements_[sharding_index];
 | |
|   }
 | |
| }
 | |
| 
 | |
| absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
 | |
|   if (!IsTuple()) {
 | |
|     return *this;
 | |
|   }
 | |
|   if (tuple_elements_.empty()) {
 | |
|     return absl::nullopt;
 | |
|   }
 | |
|   for (int64 i = 1; i < tuple_elements_.size(); ++i) {
 | |
|     if (tuple_elements_[0] != tuple_elements_[i]) {
 | |
|       return absl::nullopt;
 | |
|     }
 | |
|   }
 | |
|   return tuple_elements_.front();
 | |
| }
 | |
| 
 | |
| size_t HloSharding::Hash() const {
 | |
|   if (tuple_) {
 | |
|     size_t h = 0;
 | |
|     for (const auto& element : tuple_elements_) {
 | |
|       h = tensorflow::Hash64Combine(h, element.Hash());
 | |
|     }
 | |
|     return h;
 | |
|   }
 | |
|   if (replicated_) {
 | |
|     return 0;
 | |
|   }
 | |
|   size_t h = 0;
 | |
|   for (uint32 v : tile_assignment_) {
 | |
|     h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
 | |
|   }
 | |
|   return h;
 | |
| }
 | |
| 
 | |
| std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
 | |
|   out << sharding.ToString();
 | |
|   return out;
 | |
| }
 | |
| 
 | |
| }  // namespace xla
 |