Sets num_tasks value for use in TPUReplicate rewrite.
PiperOrigin-RevId: 358048719 Change-Id: Ic396a70ab1bd14d760fa120548bb01dc3a2dbe56
This commit is contained in:
parent
607ffbc56c
commit
24f8591d6e
@ -3934,6 +3934,7 @@ DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
|
||||
TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
|
||||
device_set, tpu_compilation_device,
|
||||
&num_tpus_per_task, &tpu_devices));
|
||||
*num_tasks = tpu_devices.size();
|
||||
|
||||
string topology;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
Loading…
Reference in New Issue
Block a user