From 24f8591d6ed36a61470d4132e949b282243b0d23 Mon Sep 17 00:00:00 2001 From: Tayo Oguntebi Date: Wed, 17 Feb 2021 15:39:17 -0800 Subject: [PATCH] Sets num_tasks value for use in TPUReplicate rewrite. PiperOrigin-RevId: 358048719 Change-Id: Ic396a70ab1bd14d760fa120548bb01dc3a2dbe56 --- .../core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index aed0add2be3..a183c3dc522 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -3934,6 +3934,7 @@ DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes( TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(), device_set, tpu_compilation_device, &num_tpus_per_task, &tpu_devices)); + *num_tasks = tpu_devices.size(); string topology; TF_RETURN_IF_ERROR(