Cleanup the sharding unique device API.

PiperOrigin-RevId: 206885051
Authored by A. Unique TensorFlower on 2018-07-31 23:18:58 -07:00; committed by TensorFlower Gardener
parent 92279f8bfa
commit 26ba623dcc
7 changed files with 47 additions and 53 deletions
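
In effect, the commit swaps the StatusOr-based unique-device accessors for optional-based ones and adds a convenience getter. A minimal caller-side sketch of the migration, assuming only the signatures visible in the hunks below; HandleSharding and use_device are illustrative names, not part of this commit:

    // Hypothetical caller-side sketch (not part of the commit): shows the
    // migration from the StatusOr-based accessor to the optional-based one.
    #include <functional>
    #include "tensorflow/compiler/xla/service/hlo_sharding.h"

    void HandleSharding(const xla::HloSharding& sharding,
                        const std::function<void(tensorflow::int64)>& use_device) {
      // Before this commit:
      //   StatusOr<int64> device_or = sharding.UniqueDevice();
      //   if (device_or.ok()) use_device(device_or.ValueOrDie());

      // After this commit: an empty optional means "no single unique device".
      tensorflow::gtl::optional<tensorflow::int64> device = sharding.UniqueDevice();
      if (device) {
        use_device(*device);
      }
      // Or, when the sharding is known to be single-device:
      //   use_device(sharding.GetUniqueDevice());  // CHECK-fails otherwise.
    }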

View File

@@ -137,9 +137,9 @@ ENTRY entry {
     if (instruction->opcode() == HloOpcode::kParameter) {
       continue;
     }
-    ASSERT_TRUE(instruction->has_sharding());
-    TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice());
-    EXPECT_EQ(device, 1);
+    auto device = instruction->sharding_unique_device();
+    ASSERT_TRUE(device);
+    EXPECT_EQ(*device, 1);
   }
 }

View File

@@ -1014,9 +1014,7 @@ class HloInstruction {
     if (sharding_ == nullptr) {
       return tensorflow::gtl::optional<int64>();
     }
-    auto device = sharding_->UniqueDevice();
-    return device.ok() ? device.ValueOrDie()
-                       : tensorflow::gtl::optional<int64>();
+    return sharding_->UniqueDevice();
   }
   // Sets the sharding of this operator. Should only be called by HloModule or
   // HloComputation methods.

View File

@@ -127,15 +127,15 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
   if (IsTuple()) {
     for (auto& tuple_element_sharding : tuple_elements()) {
       auto unique_device = tuple_element_sharding.UniqueDevice();
-      if (unique_device.ok()) {
-        device_map[unique_device.ValueOrDie()] += 1;
+      if (unique_device) {
+        device_map[*unique_device] += 1;
       }
     }
     element_count = tuple_elements().size();
   } else {
     auto unique_device = UniqueDevice();
-    if (unique_device.ok()) {
-      device_map[unique_device.ValueOrDie()] += 1;
+    if (unique_device) {
+      device_map[*unique_device] += 1;
     }
   }
   if (count != nullptr) {
@@ -238,40 +238,31 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
   return Tuple(ShapeTree<HloSharding>(shape, *this));
 }
 
-StatusOr<int64> HloSharding::UniqueDevice() const {
+tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
   if (IsTuple()) {
     if (tuple_elements_.empty()) {
-      return tensorflow::errors::InvalidArgument(
-          "UniqueDevice() called on empty tuple");
+      return tensorflow::gtl::nullopt;
     }
-    std::vector<StatusOr<int64>> results;
-    std::transform(tuple_elements_.begin(), tuple_elements_.end(),
-                   std::back_inserter(results),
-                   [](const HloSharding& s) { return s.UniqueDevice(); });
-    if (std::all_of(results.begin(), results.end(),
-                    [&](const StatusOr<int64>& s) {
-                      return s.ok() && results[0].ok() &&
-                             s.ValueOrDie() == results[0].ValueOrDie();
-                    })) {
-      return results[0];
-    } else {
-      return tensorflow::errors::InvalidArgument(
-          "Tuple did not contain a unique device");
+    tensorflow::gtl::optional<int64> unique_device;
+    for (auto& tuple_sharding : tuple_elements_) {
+      auto device = tuple_sharding.UniqueDevice();
+      if (!device || (unique_device && *device != *unique_device)) {
+        return tensorflow::gtl::nullopt;
+      }
+      unique_device = device;
     }
+    return unique_device;
   }
-  if (!replicated_ && maximal_ && !IsTuple()) {
+  if (!replicated_ && maximal_) {
     return static_cast<int64>(*tile_assignment_.begin());
   }
-  return tensorflow::errors::InvalidArgument(
-      "UniqueDevice() called on sharding that executes on multiple devices");
+  return tensorflow::gtl::nullopt;
 }
 
-bool HloSharding::HasUniqueDevice() const {
-  if (IsTuple()) {
-    return UniqueDevice().status().ok();
-  } else {
-    return !IsReplicated() && IsTileMaximal();
-  }
+int64 HloSharding::GetUniqueDevice() const {
+  auto device = UniqueDevice();
+  CHECK(device) << "Sharding does not have a unique device: " << *this;
+  return *device;
 }
 
 Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {

View File

@@ -158,12 +158,17 @@ class HloSharding {
   // REQUIRES: !IsTuple()
   std::vector<int64> TileLimitForDevice(int64 device) const;
 
-  // Returns the single device this op operates on.
-  // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
-  StatusOr<int64> UniqueDevice() const;
+  // Returns the single device this op operates on. If the sharding does not
+  // span a single device, the return value will be empty.
+  // In order for a sharding to span a single device, every leaf sharding must
+  // be maximal and not replicated, and the used device must match.
+  tensorflow::gtl::optional<int64> UniqueDevice() const;
+
+  // Retrieves the unique device or fails with a CHECK.
+  int64 GetUniqueDevice() const;
 
   // Returns true if this op only uses a single device.
-  bool HasUniqueDevice() const;
+  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }
 
   // Returns the ShapeTree containing the shardings for each element of this
   // tuple, if IsTuple, or a ShapeTree with a single element containing this
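
Taken together, the new header exposes three related accessors: an optional-returning UniqueDevice(), a CHECK-ing GetUniqueDevice(), and HasUniqueDevice() as a thin wrapper over the first. A hedged usage sketch; AssignDevice and Replicate are the existing HloSharding factories, assumed unchanged by this commit:

    // Illustrative only; relies on the pre-existing AssignDevice/Replicate factories.
    xla::HloSharding assigned = xla::HloSharding::AssignDevice(5);
    CHECK(assigned.HasUniqueDevice());        // maximal and not replicated
    CHECK_EQ(*assigned.UniqueDevice(), 5);    // optional<int64> holding 5
    CHECK_EQ(assigned.GetUniqueDevice(), 5);  // convenience form; CHECK-fails if empty

    xla::HloSharding replicated = xla::HloSharding::Replicate();
    CHECK(!replicated.HasUniqueDevice());     // replicated: no unique device
    CHECK(!replicated.UniqueDevice().has_value());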

View File

@@ -51,7 +51,7 @@ TEST_F(HloShardingTest, Replicate) {
   EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
                                  /*num_devices=*/2));
 
-  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+  EXPECT_FALSE(sharding.HasUniqueDevice());
 }
 
 TEST_F(HloShardingTest, DevicePlacement) {
@@ -60,7 +60,7 @@ TEST_F(HloShardingTest, DevicePlacement) {
   EXPECT_TRUE(sharding.IsTileMaximal());
   EXPECT_FALSE(sharding.UsesDevice(0));
   EXPECT_TRUE(sharding.UsesDevice(5));
-  EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie());
+  EXPECT_EQ(5, sharding.GetUniqueDevice());
 
   HloSharding other = HloSharding::Replicate();
   EXPECT_NE(other, sharding);
@@ -123,7 +123,7 @@ TEST_F(HloShardingTest, Tile) {
   EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
   EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
 
-  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+  EXPECT_FALSE(sharding.HasUniqueDevice());
 }
 
 }

View File

@@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
     }
   };
   string node_name;
-  if (debug_options_.xla_hlo_tfgraph_device_scopes() &&
-      instruction->has_sharding() &&
-      instruction->sharding().HasUniqueDevice()) {
-    node_name = StrCat(
-        "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie());
+  if (debug_options_.xla_hlo_tfgraph_device_scopes()) {
+    auto device = instruction->sharding_unique_device();
+    if (device) {
+      node_name = StrCat("dev", *device);
+    }
   }
   // If an instruction is fused, put it in the subgraph of the fusion;
   // otherwise, put it in the computation subgraph.
@@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
   NodeDef* node_def = graph_def_.add_node();
   node_def->set_name(GetNodeNameForInstruction(instruction));
   node_def->set_op(GetOpDefName(instruction));
-  if (instruction->has_sharding() &&
-      instruction->sharding().HasUniqueDevice()) {
-    TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
-    node_def->set_device(GetDeviceName(device));
+  auto device = instruction->sharding_unique_device();
+  if (device) {
+    node_def->set_device(GetDeviceName(*device));
   }
+
   SetNodeAttrs(instruction, node_def);
   if (instruction->opcode() == HloOpcode::kFusion) {

View File

@@ -874,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
     // HostCompute module.
     // Otherwise it is preferable to leave the new instruction without device,
    // and let the automatic device placer to choose the best location.
-    if (!sharding.HasUniqueDevice() ||
-        HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) {
+    auto device = sharding.UniqueDevice();
+    if (!device || HloSharding::IsReservedDevice(*device)) {
      copy->set_sharding(sharding);
    }
  }