commit 26ba623dcc (parent 92279f8bfa)

Cleanup the sharding unique device API.

PiperOrigin-RevId: 206885051
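This change replaces the StatusOr<int64>-based unique-device accessors with optional-based ones: HloSharding::UniqueDevice() now returns tensorflow::gtl::optional<int64>, HloSharding::GetUniqueDevice() is the new CHECK-failing variant, and HloInstruction::sharding_unique_device() forwards to the instruction's sharding when one is set. The sketch below shows what a call site looks like against the new API; it is illustrative only — the helper name and fallback logic are assumptions, while the method names come from this commit.

// A minimal sketch, assuming the XLA headers at this commit's revision.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Hypothetical helper: returns the sharding's unique device, or `fallback`
// when the sharding is replicated or spans multiple devices.
int64 UniqueDeviceOr(const HloSharding& sharding, int64 fallback) {
  // New API: "no unique device" is an empty optional rather than an error
  // Status, so call sites no longer need StatusOr plumbing.
  auto device = sharding.UniqueDevice();
  if (device) {
    return *device;
  }
  // Callers that require a single device can instead guard with
  // sharding.HasUniqueDevice() and call sharding.GetUniqueDevice(), which
  // CHECK-fails when the sharding does not name a unique device.
  return fallback;
}

}  // namespace xla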
@@ -137,9 +137,9 @@ ENTRY entry {
     if (instruction->opcode() == HloOpcode::kParameter) {
       continue;
     }
-    ASSERT_TRUE(instruction->has_sharding());
-    TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice());
-    EXPECT_EQ(device, 1);
+    auto device = instruction->sharding_unique_device();
+    ASSERT_TRUE(device);
+    EXPECT_EQ(*device, 1);
   }
 }
 
@@ -1014,9 +1014,7 @@ class HloInstruction {
     if (sharding_ == nullptr) {
       return tensorflow::gtl::optional<int64>();
     }
-    auto device = sharding_->UniqueDevice();
-    return device.ok() ? device.ValueOrDie()
-                       : tensorflow::gtl::optional<int64>();
+    return sharding_->UniqueDevice();
   }
   // Sets the sharding of this operator. Should only be called by HloModule or
   // HloComputation methods.
@@ -127,15 +127,15 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
   if (IsTuple()) {
     for (auto& tuple_element_sharding : tuple_elements()) {
       auto unique_device = tuple_element_sharding.UniqueDevice();
-      if (unique_device.ok()) {
-        device_map[unique_device.ValueOrDie()] += 1;
+      if (unique_device) {
+        device_map[*unique_device] += 1;
       }
     }
     element_count = tuple_elements().size();
   } else {
     auto unique_device = UniqueDevice();
-    if (unique_device.ok()) {
-      device_map[unique_device.ValueOrDie()] += 1;
+    if (unique_device) {
+      device_map[*unique_device] += 1;
     }
   }
   if (count != nullptr) {
@@ -238,40 +238,31 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
   return Tuple(ShapeTree<HloSharding>(shape, *this));
 }
 
-StatusOr<int64> HloSharding::UniqueDevice() const {
+tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
   if (IsTuple()) {
     if (tuple_elements_.empty()) {
-      return tensorflow::errors::InvalidArgument(
-          "UniqueDevice() called on empty tuple");
+      return tensorflow::gtl::nullopt;
     }
-    std::vector<StatusOr<int64>> results;
-    std::transform(tuple_elements_.begin(), tuple_elements_.end(),
-                   std::back_inserter(results),
-                   [](const HloSharding& s) { return s.UniqueDevice(); });
-    if (std::all_of(results.begin(), results.end(),
-                    [&](const StatusOr<int64>& s) {
-                      return s.ok() && results[0].ok() &&
-                             s.ValueOrDie() == results[0].ValueOrDie();
-                    })) {
-      return results[0];
-    } else {
-      return tensorflow::errors::InvalidArgument(
-          "Tuple did not contain a unique device");
+    tensorflow::gtl::optional<int64> unique_device;
+    for (auto& tuple_sharding : tuple_elements_) {
+      auto device = tuple_sharding.UniqueDevice();
+      if (!device || (unique_device && *device != *unique_device)) {
+        return tensorflow::gtl::nullopt;
+      }
+      unique_device = device;
     }
+    return unique_device;
   }
-  if (!replicated_ && maximal_ && !IsTuple()) {
+  if (!replicated_ && maximal_) {
     return static_cast<int64>(*tile_assignment_.begin());
   }
-  return tensorflow::errors::InvalidArgument(
-      "UniqueDevice() called on sharding that executes on multiple devices");
+  return tensorflow::gtl::nullopt;
 }
 
-bool HloSharding::HasUniqueDevice() const {
-  if (IsTuple()) {
-    return UniqueDevice().status().ok();
-  } else {
-    return !IsReplicated() && IsTileMaximal();
-  }
+int64 HloSharding::GetUniqueDevice() const {
+  auto device = UniqueDevice();
+  CHECK(device) << "Sharding does not have a unique device: " << *this;
+  return *device;
 }
 
 Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
@@ -158,12 +158,17 @@ class HloSharding {
   // REQUIRES: !IsTuple()
   std::vector<int64> TileLimitForDevice(int64 device) const;
 
-  // Returns the single device this op operates on.
-  // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
-  StatusOr<int64> UniqueDevice() const;
+  // Returns the single device this op operates on. If the sharding does not
+  // span a single device, the return value will be empty.
+  // In order for a sharding to span a single device, every leaf sharding must
+  // be maximal and not replicated, and the used device must match.
+  tensorflow::gtl::optional<int64> UniqueDevice() const;
 
+  // Retrieves the unique device or fails with a CHECK.
+  int64 GetUniqueDevice() const;
+
   // Returns true if this op only uses a single device.
-  bool HasUniqueDevice() const;
+  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }
 
   // Returns the ShapeTree containing the shardings for each element of this
   // tuple, if IsTuple, or a ShapeTree with a single element containing this
@@ -51,7 +51,7 @@ TEST_F(HloShardingTest, Replicate) {
 
   EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
                                  /*num_devices=*/2));
-  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+  EXPECT_FALSE(sharding.HasUniqueDevice());
 }
 
 TEST_F(HloShardingTest, DevicePlacement) {
@@ -60,7 +60,7 @@ TEST_F(HloShardingTest, DevicePlacement) {
   EXPECT_TRUE(sharding.IsTileMaximal());
   EXPECT_FALSE(sharding.UsesDevice(0));
   EXPECT_TRUE(sharding.UsesDevice(5));
-  EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie());
+  EXPECT_EQ(5, sharding.GetUniqueDevice());
 
   HloSharding other = HloSharding::Replicate();
   EXPECT_NE(other, sharding);
@@ -123,7 +123,7 @@ TEST_F(HloShardingTest, Tile) {
   EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
   EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
 
-  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+  EXPECT_FALSE(sharding.HasUniqueDevice());
 }
 }
 
@@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
     }
   };
   string node_name;
-  if (debug_options_.xla_hlo_tfgraph_device_scopes() &&
-      instruction->has_sharding() &&
-      instruction->sharding().HasUniqueDevice()) {
-    node_name = StrCat(
-        "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie());
+  if (debug_options_.xla_hlo_tfgraph_device_scopes()) {
+    auto device = instruction->sharding_unique_device();
+    if (device) {
+      node_name = StrCat("dev", *device);
+    }
   }
   // If an instruction is fused, put it in the subgraph of the fusion;
   // otherwise, put it in the computation subgraph.
@@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
   NodeDef* node_def = graph_def_.add_node();
   node_def->set_name(GetNodeNameForInstruction(instruction));
   node_def->set_op(GetOpDefName(instruction));
-  if (instruction->has_sharding() &&
-      instruction->sharding().HasUniqueDevice()) {
-    TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
-    node_def->set_device(GetDeviceName(device));
+
+  auto device = instruction->sharding_unique_device();
+  if (device) {
+    node_def->set_device(GetDeviceName(*device));
   }
   SetNodeAttrs(instruction, node_def);
   if (instruction->opcode() == HloOpcode::kFusion) {
@@ -874,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
     // HostCompute module.
    // Otherwise it is preferable to leave the new instruction without device,
     // and let the automatic device placer to choose the best location.
-    if (!sharding.HasUniqueDevice() ||
-        HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) {
+    auto device = sharding.UniqueDevice();
+    if (!device || HloSharding::IsReservedDevice(*device)) {
      copy->set_sharding(sharding);
     }
   }