diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index aed0add2be3..a183c3dc522 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -3934,6 +3934,7 @@ DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes( TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(), device_set, tpu_compilation_device, &num_tpus_per_task, &tpu_devices)); + *num_tasks = tpu_devices.size(); string topology; TF_RETURN_IF_ERROR(