Update TPU cluster formation pass to cast tf.TPUReplicatedInput
op to access index
attribute.
The class accessor can be used instead of a manual lookup, as inputs ops are already determined to be `tf.TPUReplicatedInput` prior. PiperOrigin-RevId: 285307397 Change-Id: I74b08472db4f5df5c45a7fbbfc14f6f5cce48e0a
This commit is contained in:
parent
4748c9e52b
commit
2d9c178f4e
@ -59,7 +59,6 @@ constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
constexpr char kNameAttr[] = "name";
|
||||
constexpr char kNumReplicasAttr[] = "num_replicas";
|
||||
constexpr char kIndexAttr[] = "index";
|
||||
|
||||
constexpr char kBadTPUReplicateAttrMsg[] =
|
||||
"requires '_tpu_replicate' string attribute";
|
||||
@ -271,9 +270,8 @@ LogicalResult SortTPUReplicatedInputsByIndex(
|
||||
int last_index = input_size - 1;
|
||||
|
||||
for (Operation* input : inputs) {
|
||||
int64_t index = -1;
|
||||
if (auto index_attr = input->getAttrOfType<IntegerAttr>(kIndexAttr))
|
||||
index = index_attr.getInt();
|
||||
int64_t index =
|
||||
llvm::cast<TF::TPUReplicatedInputOp>(input).index().getLimitedValue();
|
||||
|
||||
if (index >= input_size || index < -1)
|
||||
return input->emitError() << "'" << input->getName().getStringRef()
|
||||
|
Loading…
Reference in New Issue
Block a user