Cleanup the sharding unique device API.
PiperOrigin-RevId: 206885051
commit 26ba623dcc
parent 92279f8bfa
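In short: HloSharding::UniqueDevice() now returns tensorflow::gtl::optional<int64> instead of StatusOr<int64>; a CHECK-failing GetUniqueDevice() accessor is added; HasUniqueDevice() becomes an inline wrapper; and call sites move to HloInstruction::sharding_unique_device(), whose body reduces to a plain forward. A minimal before/after sketch of a call site (illustrative only; instr and UseDevice are hypothetical names, not from this commit):

    // Before: "no unique device" arrived as an error Status.
    StatusOr<int64> device_or = instr->sharding().UniqueDevice();
    if (device_or.ok()) {
      UseDevice(device_or.ValueOrDie());  // UseDevice is a placeholder
    }

    // After: it is an ordinary empty optional.
    auto device = instr->sharding_unique_device();
    if (device) {
      UseDevice(*device);
    }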
@@ -137,9 +137,9 @@ ENTRY entry {
     if (instruction->opcode() == HloOpcode::kParameter) {
       continue;
     }
     ASSERT_TRUE(instruction->has_sharding());
-    TF_ASSERT_OK_AND_ASSIGN(int device,
-                            instruction->sharding().UniqueDevice());
-    EXPECT_EQ(device, 1);
+    auto device = instruction->sharding_unique_device();
+    ASSERT_TRUE(device);
+    EXPECT_EQ(*device, 1);
   }
 }

@@ -1014,9 +1014,7 @@ class HloInstruction {
     if (sharding_ == nullptr) {
       return tensorflow::gtl::optional<int64>();
     }
-    auto device = sharding_->UniqueDevice();
-    return device.ok() ? device.ValueOrDie()
-                       : tensorflow::gtl::optional<int64>();
+    return sharding_->UniqueDevice();
   }
   // Sets the sharding of this operator. Should only be called by HloModule or
   // HloComputation methods.

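With UniqueDevice() itself returning an optional, the status-to-optional translation above disappears. Assembled from this hunk, the accessor presumably now reads as follows (a sketch; the surrounding header declaration is reconstructed from context, not verbatim):

    tensorflow::gtl::optional<int64> sharding_unique_device() const {
      if (sharding_ == nullptr) {
        return tensorflow::gtl::optional<int64>();  // unsharded: no device
      }
      return sharding_->UniqueDevice();  // already an optional<int64>
    }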
@@ -127,15 +127,15 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
   if (IsTuple()) {
     for (auto& tuple_element_sharding : tuple_elements()) {
       auto unique_device = tuple_element_sharding.UniqueDevice();
-      if (unique_device.ok()) {
-        device_map[unique_device.ValueOrDie()] += 1;
+      if (unique_device) {
+        device_map[*unique_device] += 1;
       }
     }
     element_count = tuple_elements().size();
   } else {
     auto unique_device = UniqueDevice();
-    if (unique_device.ok()) {
-      device_map[unique_device.ValueOrDie()] += 1;
+    if (unique_device) {
+      device_map[*unique_device] += 1;
     }
   }
   if (count != nullptr) {

@@ -238,40 +238,31 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
   return Tuple(ShapeTree<HloSharding>(shape, *this));
 }
 
-StatusOr<int64> HloSharding::UniqueDevice() const {
+tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
   if (IsTuple()) {
     if (tuple_elements_.empty()) {
-      return tensorflow::errors::InvalidArgument(
-          "UniqueDevice() called on empty tuple");
+      return tensorflow::gtl::nullopt;
     }
-    std::vector<StatusOr<int64>> results;
-    std::transform(tuple_elements_.begin(), tuple_elements_.end(),
-                   std::back_inserter(results),
-                   [](const HloSharding& s) { return s.UniqueDevice(); });
-    if (std::all_of(results.begin(), results.end(),
-                    [&](const StatusOr<int64>& s) {
-                      return s.ok() && results[0].ok() &&
-                             s.ValueOrDie() == results[0].ValueOrDie();
-                    })) {
-      return results[0];
-    } else {
-      return tensorflow::errors::InvalidArgument(
-          "Tuple did not contain a unique device");
-    }
+    tensorflow::gtl::optional<int64> unique_device;
+    for (auto& tuple_sharding : tuple_elements_) {
+      auto device = tuple_sharding.UniqueDevice();
+      if (!device || (unique_device && *device != *unique_device)) {
+        return tensorflow::gtl::nullopt;
+      }
+      unique_device = device;
+    }
+    return unique_device;
   }
-  if (!replicated_ && maximal_ && !IsTuple()) {
+  if (!replicated_ && maximal_) {
     return static_cast<int64>(*tile_assignment_.begin());
   }
-  return tensorflow::errors::InvalidArgument(
-      "UniqueDevice() called on sharding that executes on multiple devices");
+  return tensorflow::gtl::nullopt;
 }
 
-bool HloSharding::HasUniqueDevice() const {
-  if (IsTuple()) {
-    return UniqueDevice().status().ok();
-  } else {
-    return !IsReplicated() && IsTileMaximal();
-  }
+int64 HloSharding::GetUniqueDevice() const {
+  auto device = UniqueDevice();
+  CHECK(device) << "Sharding does not have a unique device: " << *this;
+  return *device;
 }
 
 Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {

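The rewritten tuple branch replaces the std::transform/std::all_of pair with a single early-exit loop: the first leaf fixes the candidate device, and any leaf that is empty or disagrees returns nullopt immediately. Expected behavior, sketched with the existing AssignDevice and Replicate factories (these lines are illustrative, not tests from this commit):

    // {AssignDevice(3), AssignDevice(3)} -> UniqueDevice() == 3
    // {AssignDevice(3), AssignDevice(5)} -> nullopt (devices disagree)
    // {AssignDevice(3), Replicate()}     -> nullopt (replicated leaf)
    // empty tuple                        -> nullopt
    // GetUniqueDevice() CHECK-fails exactly where UniqueDevice() is empty.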
@@ -158,12 +158,17 @@ class HloSharding {
   // REQUIRES: !IsTuple()
   std::vector<int64> TileLimitForDevice(int64 device) const;
 
-  // Returns the single device this op operates on.
-  // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
-  StatusOr<int64> UniqueDevice() const;
+  // Returns the single device this op operates on. If the sharding does not
+  // span a single device, the return value will be empty.
+  // In order for a sharding to span a single device, every leaf sharding must
+  // be maximal and not replicated, and the used device must match.
+  tensorflow::gtl::optional<int64> UniqueDevice() const;
+
+  // Retrieves the unique device or fails with a CHECK.
+  int64 GetUniqueDevice() const;
 
   // Returns true if this op only uses a single device.
-  bool HasUniqueDevice() const;
+  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }
 
   // Returns the ShapeTree containing the shardings for each element of this
   // tuple, if IsTuple, or a ShapeTree with a single element containing this

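A hedged usage sketch of the resulting surface (AssignDevice is the existing maximal single-device factory; the values shown are what the new contracts imply):

    const HloSharding sharding = HloSharding::AssignDevice(5);

    if (auto device = sharding.UniqueDevice()) {  // optional query, safe anywhere
      // *device == 5 here
    }
    bool unique = sharding.HasUniqueDevice();     // true; inline wrapper
    int64 device = sharding.GetUniqueDevice();    // 5; CHECK-fails if not unique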
@@ -51,7 +51,7 @@ TEST_F(HloShardingTest, Replicate) {
 
   EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
                                  /*num_devices=*/2));
-  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+  EXPECT_FALSE(sharding.HasUniqueDevice());
 }
 
 TEST_F(HloShardingTest, DevicePlacement) {

@@ -60,7 +60,7 @@ TEST_F(HloShardingTest, DevicePlacement) {
   EXPECT_TRUE(sharding.IsTileMaximal());
   EXPECT_FALSE(sharding.UsesDevice(0));
   EXPECT_TRUE(sharding.UsesDevice(5));
-  EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie());
+  EXPECT_EQ(5, sharding.GetUniqueDevice());
 
   HloSharding other = HloSharding::Replicate();
   EXPECT_NE(other, sharding);

@@ -123,7 +123,7 @@ TEST_F(HloShardingTest, Tile) {
   EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
   EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
 
-  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+  EXPECT_FALSE(sharding.HasUniqueDevice());
 }
 }
 
@@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
     }
   };
   string node_name;
-  if (debug_options_.xla_hlo_tfgraph_device_scopes() &&
-      instruction->has_sharding() &&
-      instruction->sharding().HasUniqueDevice()) {
-    node_name = StrCat(
-        "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie());
+  if (debug_options_.xla_hlo_tfgraph_device_scopes()) {
+    auto device = instruction->sharding_unique_device();
+    if (device) {
+      node_name = StrCat("dev", *device);
+    }
   }
   // If an instruction is fused, put it in the subgraph of the fusion;
   // otherwise, put it in the computation subgraph.

@@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
   NodeDef* node_def = graph_def_.add_node();
   node_def->set_name(GetNodeNameForInstruction(instruction));
   node_def->set_op(GetOpDefName(instruction));
-  if (instruction->has_sharding() &&
-      instruction->sharding().HasUniqueDevice()) {
-    TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
-    node_def->set_device(GetDeviceName(device));
+
+  auto device = instruction->sharding_unique_device();
+  if (device) {
+    node_def->set_device(GetDeviceName(*device));
   }
   SetNodeAttrs(instruction, node_def);
   if (instruction->opcode() == HloOpcode::kFusion) {

@@ -874,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
   // HostCompute module.
   // Otherwise it is preferable to leave the new instruction without device,
   // and let the automatic device placer to choose the best location.
-  if (!sharding.HasUniqueDevice() ||
-      HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) {
+  auto device = sharding.UniqueDevice();
+  if (!device || HloSharding::IsReservedDevice(*device)) {
     copy->set_sharding(sharding);
   }
 }
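Note the recurring call-site shape across the last three hunks: the old code queried twice (HasUniqueDevice(), then UniqueDevice().ValueOrDie() or ConsumeValueOrDie()), while the new code queries once and keeps the optional, as in:

    auto device = sharding.UniqueDevice();  // one lookup instead of two
    if (!device || HloSharding::IsReservedDevice(*device)) {
      // handle the "no usable unique device" case
    }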