Infer sharding from neighbors only for maximal sharding

PiperOrigin-RevId: 283624742
Change-Id: I290f54778c2c5406d4161045f6e8d3df39ce96b1
This commit is contained in:
HyoukJoong Lee 2019-12-03 14:37:30 -08:00 committed by TensorFlower Gardener
parent 8cf05f47a5
commit 0365083580

View File

@ -503,8 +503,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
ParseShardingFromDevice(
*possible_match,
/*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
if (sharding.has_value()) {
TF_RET_CHECK(sharding.value().type() == xla::OpSharding::MAXIMAL);
if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
const int core_annotation = sharding.value().tile_assignment_devices(0);
if (core == -1 || core > core_annotation) {
core = core_annotation;