diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 40aff06678e..9e80e69885e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -306,8 +306,8 @@ void IdentifyXlaShardingForTPUComputation( xla::sharding_builder::AssignDevice(0).SerializeAsString(); bool use_spmd = false; - if (auto use_spmd_attr = - cluster_func.getAttrOfType("use_spmd_for_xla_partitioning")) + if (auto use_spmd_attr = cluster_func->getAttrOfType( + "use_spmd_for_xla_partitioning")) use_spmd = use_spmd_attr.getValue(); IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,