Update TPU cluster formation pass to cast tf.TPUReplicatedInput op to access index attribute.

The class accessor can be used instead of a manual lookup, as inputs ops are already determined to be `tf.TPUReplicatedInput` prior.

PiperOrigin-RevId: 285307397
Change-Id: I74b08472db4f5df5c45a7fbbfc14f6f5cce48e0a
This commit is contained in:
Andy Ly 2019-12-12 17:32:21 -08:00 committed by TensorFlower Gardener
parent 4748c9e52b
commit 2d9c178f4e

View File

@ -59,7 +59,6 @@ constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
constexpr char kDeviceAttr[] = "device";
constexpr char kNameAttr[] = "name";
constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kIndexAttr[] = "index";
constexpr char kBadTPUReplicateAttrMsg[] =
"requires '_tpu_replicate' string attribute";
@ -271,9 +270,8 @@ LogicalResult SortTPUReplicatedInputsByIndex(
int last_index = input_size - 1;
for (Operation* input : inputs) {
int64_t index = -1;
if (auto index_attr = input->getAttrOfType<IntegerAttr>(kIndexAttr))
index = index_attr.getInt();
int64_t index =
llvm::cast<TF::TPUReplicatedInputOp>(input).index().getLimitedValue();
if (index >= input_size || index < -1)
return input->emitError() << "'" << input->getName().getStringRef()