Clean up the sharding unique device API.

PiperOrigin-RevId: 206885051
Author: A. Unique TensorFlower, 2018-07-31 23:18:58 -07:00; committed by TensorFlower Gardener
parent 92279f8bfa
commit 26ba623dcc
7 changed files with 47 additions and 53 deletions
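For illustration, a minimal sketch of how a call site migrates to the cleaned-up API. The function and helper names below (PlaceIfUniqueDevice, PlaceOnDevice) are hypothetical and not part of this commit.

    // Before this change, UniqueDevice() returned StatusOr<int64>, so callers
    // had to thread Status handling even when "no unique device" was an
    // ordinary outcome:
    //
    //   TF_ASSIGN_OR_RETURN(int64 device, sharding.UniqueDevice());
    //   PlaceOnDevice(device);  // hypothetical helper
    //
    // After this change, UniqueDevice() returns tensorflow::gtl::optional<int64>
    // and GetUniqueDevice() CHECK-fails when no unique device exists.

    #include "tensorflow/compiler/xla/service/hlo_sharding.h"

    void PlaceIfUniqueDevice(const xla::HloSharding& sharding) {  // hypothetical
      if (auto device = sharding.UniqueDevice()) {
        // ... place the computation on device *device ...
      }
    }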


@@ -137,9 +137,9 @@ ENTRY entry {
     if (instruction->opcode() == HloOpcode::kParameter) {
       continue;
     }
-    ASSERT_TRUE(instruction->has_sharding());
-    TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice());
-    EXPECT_EQ(device, 1);
+    auto device = instruction->sharding_unique_device();
+    ASSERT_TRUE(device);
+    EXPECT_EQ(*device, 1);
   }
 }


@@ -1014,9 +1014,7 @@ class HloInstruction {
     if (sharding_ == nullptr) {
       return tensorflow::gtl::optional<int64>();
     }
-    auto device = sharding_->UniqueDevice();
-    return device.ok() ? device.ValueOrDie()
-                       : tensorflow::gtl::optional<int64>();
+    return sharding_->UniqueDevice();
   }
   // Sets the sharding of this operator. Should only be called by HloModule or
   // HloComputation methods.


@@ -127,15 +127,15 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
   if (IsTuple()) {
     for (auto& tuple_element_sharding : tuple_elements()) {
       auto unique_device = tuple_element_sharding.UniqueDevice();
-      if (unique_device.ok()) {
-        device_map[unique_device.ValueOrDie()] += 1;
+      if (unique_device) {
+        device_map[*unique_device] += 1;
       }
     }
     element_count = tuple_elements().size();
   } else {
     auto unique_device = UniqueDevice();
-    if (unique_device.ok()) {
-      device_map[unique_device.ValueOrDie()] += 1;
+    if (unique_device) {
+      device_map[*unique_device] += 1;
     }
   }
   if (count != nullptr) {
@@ -238,40 +238,31 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
   return Tuple(ShapeTree<HloSharding>(shape, *this));
 }
 
-StatusOr<int64> HloSharding::UniqueDevice() const {
+tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
   if (IsTuple()) {
     if (tuple_elements_.empty()) {
-      return tensorflow::errors::InvalidArgument(
-          "UniqueDevice() called on empty tuple");
+      return tensorflow::gtl::nullopt;
     }
-    std::vector<StatusOr<int64>> results;
-    std::transform(tuple_elements_.begin(), tuple_elements_.end(),
-                   std::back_inserter(results),
-                   [](const HloSharding& s) { return s.UniqueDevice(); });
-    if (std::all_of(results.begin(), results.end(),
-                    [&](const StatusOr<int64>& s) {
-                      return s.ok() && results[0].ok() &&
-                             s.ValueOrDie() == results[0].ValueOrDie();
-                    })) {
-      return results[0];
-    } else {
-      return tensorflow::errors::InvalidArgument(
-          "Tuple did not contain a unique device");
+    tensorflow::gtl::optional<int64> unique_device;
+    for (auto& tuple_sharding : tuple_elements_) {
+      auto device = tuple_sharding.UniqueDevice();
+      if (!device || (unique_device && *device != *unique_device)) {
+        return tensorflow::gtl::nullopt;
+      }
+      unique_device = device;
     }
+    return unique_device;
   }
-  if (!replicated_ && maximal_ && !IsTuple()) {
+  if (!replicated_ && maximal_) {
     return static_cast<int64>(*tile_assignment_.begin());
   }
-  return tensorflow::errors::InvalidArgument(
-      "UniqueDevice() called on sharding that executes on multiple devices");
+  return tensorflow::gtl::nullopt;
 }
 
-bool HloSharding::HasUniqueDevice() const {
-  if (IsTuple()) {
-    return UniqueDevice().status().ok();
-  } else {
-    return !IsReplicated() && IsTileMaximal();
-  }
+int64 HloSharding::GetUniqueDevice() const {
+  auto device = UniqueDevice();
+  CHECK(device) << "Sharding does not have a unique device: " << *this;
+  return *device;
 }
 
 Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {


@@ -158,12 +158,17 @@ class HloSharding {
   // REQUIRES: !IsTuple()
   std::vector<int64> TileLimitForDevice(int64 device) const;
 
-  // Returns the single device this op operates on.
-  // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
-  StatusOr<int64> UniqueDevice() const;
+  // Returns the single device this op operates on. If the sharding does not
+  // span a single device, the return value will be empty.
+  // In order for a sharding to span a single device, every leaf sharding must
+  // be maximal and not replicated, and the used device must match.
+  tensorflow::gtl::optional<int64> UniqueDevice() const;
+
+  // Retrieves the unique device or fails with a CHECK.
+  int64 GetUniqueDevice() const;
 
   // Returns true if this op only uses a single device.
-  bool HasUniqueDevice() const;
+  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }
 
   // Returns the ShapeTree containing the shardings for each element of this
   // tuple, if IsTuple, or a ShapeTree with a single element containing this
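A small illustrative sketch of the contract documented in the new header comments; it assumes hlo_sharding.h and TensorFlow's CHECK macros, and is not part of this commit.

    #include "tensorflow/compiler/xla/service/hlo_sharding.h"
    #include "tensorflow/core/platform/logging.h"

    void UniqueDeviceContract() {
      // A maximal, non-replicated sharding spans exactly one device.
      xla::HloSharding on_five = xla::HloSharding::AssignDevice(5);
      CHECK(on_five.HasUniqueDevice());
      CHECK_EQ(on_five.GetUniqueDevice(), 5);  // CHECK-fails if not unique.

      // A replicated sharding runs on every device, so no unique device exists.
      xla::HloSharding replicated = xla::HloSharding::Replicate();
      CHECK(!replicated.HasUniqueDevice());
      CHECK(!replicated.UniqueDevice().has_value());
    }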


@@ -51,7 +51,7 @@ TEST_F(HloShardingTest, Replicate) {
   EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
                                  /*num_devices=*/2));
-  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+  EXPECT_FALSE(sharding.HasUniqueDevice());
 }
 
 TEST_F(HloShardingTest, DevicePlacement) {
@@ -60,7 +60,7 @@ TEST_F(HloShardingTest, DevicePlacement) {
   EXPECT_TRUE(sharding.IsTileMaximal());
   EXPECT_FALSE(sharding.UsesDevice(0));
   EXPECT_TRUE(sharding.UsesDevice(5));
-  EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie());
+  EXPECT_EQ(5, sharding.GetUniqueDevice());
 
   HloSharding other = HloSharding::Replicate();
   EXPECT_NE(other, sharding);
@@ -123,7 +123,7 @@ TEST_F(HloShardingTest, Tile) {
   EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
   EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
-  EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+  EXPECT_FALSE(sharding.HasUniqueDevice());
 }
 }


@@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
     }
   };
   string node_name;
-  if (debug_options_.xla_hlo_tfgraph_device_scopes() &&
-      instruction->has_sharding() &&
-      instruction->sharding().HasUniqueDevice()) {
-    node_name = StrCat(
-        "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie());
+  if (debug_options_.xla_hlo_tfgraph_device_scopes()) {
+    auto device = instruction->sharding_unique_device();
+    if (device) {
+      node_name = StrCat("dev", *device);
+    }
   }
   // If an instruction is fused, put it in the subgraph of the fusion;
   // otherwise, put it in the computation subgraph.
@@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
   NodeDef* node_def = graph_def_.add_node();
   node_def->set_name(GetNodeNameForInstruction(instruction));
   node_def->set_op(GetOpDefName(instruction));
-  if (instruction->has_sharding() &&
-      instruction->sharding().HasUniqueDevice()) {
-    TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
-    node_def->set_device(GetDeviceName(device));
+
+  auto device = instruction->sharding_unique_device();
+  if (device) {
+    node_def->set_device(GetDeviceName(*device));
   }
   SetNodeAttrs(instruction, node_def);
   if (instruction->opcode() == HloOpcode::kFusion) {


@@ -874,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
     // HostCompute module.
     // Otherwise it is preferable to leave the new instruction without device,
    // and let the automatic device placer to choose the best location.
-    if (!sharding.HasUniqueDevice() ||
-        HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) {
+    auto device = sharding.UniqueDevice();
+    if (!device || HloSharding::IsReservedDevice(*device)) {
       copy->set_sharding(sharding);
     }
   }