Sets num_tasks value for use in TPUReplicate rewrite.
PiperOrigin-RevId: 358048719 Change-Id: Ic396a70ab1bd14d760fa120548bb01dc3a2dbe56
This commit is contained in:
parent
607ffbc56c
commit
24f8591d6e
@ -3934,6 +3934,7 @@ DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
|
|||||||
TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
|
TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
|
||||||
device_set, tpu_compilation_device,
|
device_set, tpu_compilation_device,
|
||||||
&num_tpus_per_task, &tpu_devices));
|
&num_tpus_per_task, &tpu_devices));
|
||||||
|
*num_tasks = tpu_devices.size();
|
||||||
|
|
||||||
string topology;
|
string topology;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
|
Loading…
Reference in New Issue
Block a user