Sets num_tasks value for use in TPUReplicate rewrite.

PiperOrigin-RevId: 358048719
Change-Id: Ic396a70ab1bd14d760fa120548bb01dc3a2dbe56
This commit is contained in:
Tayo Oguntebi 2021-02-17 15:39:17 -08:00 committed by TensorFlower Gardener
parent 607ffbc56c
commit 24f8591d6e

View File

@ -3934,6 +3934,7 @@ DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
device_set, tpu_compilation_device,
&num_tpus_per_task, &tpu_devices));
*num_tasks = tpu_devices.size();
string topology;
TF_RETURN_IF_ERROR(