Enable NcclGather and NcclBroadcast.
PiperOrigin-RevId: 355859246 Change-Id: I80831dbecb4858073c8bf61906fac9b651868bf9
This commit is contained in:
parent
844002ed99
commit
9922e83047
@ -59,13 +59,13 @@ namespace {
|
|||||||
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
|
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
|
||||||
switch (cp->instance.type) {
|
switch (cp->instance.type) {
|
||||||
case BROADCAST_COLLECTIVE:
|
case BROADCAST_COLLECTIVE:
|
||||||
return "HierarchicalTreeBroadcast";
|
return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";
|
||||||
|
|
||||||
case REDUCTION_COLLECTIVE:
|
case REDUCTION_COLLECTIVE:
|
||||||
return nccl ? "NcclReduce" : "RingReduce";
|
return nccl ? "NcclReduce" : "RingReduce";
|
||||||
|
|
||||||
case GATHER_COLLECTIVE:
|
case GATHER_COLLECTIVE:
|
||||||
return "RingGather";
|
return nccl ? "NcclGather" : "RingGather";
|
||||||
|
|
||||||
case PERMUTE_COLLECTIVE:
|
case PERMUTE_COLLECTIVE:
|
||||||
return "Permute";
|
return "Permute";
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@ -952,5 +953,6 @@ def _setup_context():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
os.environ['NCCL_DEBUG'] = 'INFO'
|
||||||
v2_compat.enable_v2_behavior()
|
v2_compat.enable_v2_behavior()
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user